# Tree Variational Autoencoders

Laura Manduchi, Moritz Vandenhirtz, Alain Ryser, Julia E. Vogt
Department of Computer Science, ETH Zurich, Switzerland
Equal contribution. Correspondence to {laura.manduchi,moritz.vandenhirtz}@inf.ethz.ch
37th Conference on Neural Information Processing Systems (NeurIPS 2023).

Abstract

We propose Tree Variational Autoencoder (Tree VAE), a new generative hierarchical clustering model that learns a flexible tree-based posterior distribution over latent variables. Tree VAE hierarchically divides samples according to their intrinsic characteristics, shedding light on hidden structures in the data. It adapts its architecture to discover the optimal tree for encoding dependencies between latent variables. The proposed tree-based generative architecture enables lightweight conditional inference and improves generative performance by utilizing specialized leaf decoders. We show that Tree VAE uncovers underlying clusters in the data and finds meaningful hierarchical relations between the different groups on a variety of datasets, including real-world imaging data. We show empirically that Tree VAE provides a more competitive log-likelihood lower bound than its sequential counterparts. Finally, due to its generative nature, Tree VAE is able to generate new samples from the discovered clusters via conditional sampling.

1 Introduction

Discovering structure and hierarchies in data has been a long-standing goal in machine learning (Bishop, 2006; Bengio et al., 2012; Jordan & Mitchell, 2015). Interpretable supervised methods, such as decision trees (Zhou & Feng, 2017; Tanno et al., 2019), have proven successful in unveiling hierarchical relationships within data. However, the expense of annotating large quantities of data has resulted in a surge of interest in unsupervised approaches (LeCun et al., 2015). Hierarchical clustering (Ward, 1963) offers an unsupervised path to find hidden groups in the data and their hierarchical relationships (Campello et al., 2015). Due to its versatility, interpretability, and ability to uncover meaningful patterns in complex data, hierarchical clustering has been widely used in a variety of applications, including phylogenetics (Sneath & Sokal, 1962), astrophysics (McConnachie et al., 2018), and federated learning (Briggs et al., 2020). Similar to how the human brain automatically categorizes and connects objects based on shared attributes, hierarchical clustering algorithms construct a dendrogram (a tree-like structure of clusters) that organizes data into nested groups based on their similarity. Despite its potential, hierarchical clustering has taken a step back in light of recent advances in self-supervised deep learning (Chen et al., 2020), and only a few deep-learning-based methods have been proposed in recent years (Goyal et al., 2017; Mautz et al., 2020). Deep latent variable models (Kingma & Welling, 2019), a class of generative models, have emerged as powerful frameworks for unsupervised learning, and they have been extensively used to uncover hidden structures in the data (Dilokthanakul et al., 2016; Manduchi et al., 2021). They leverage the flexibility of neural networks to capture complex patterns and generate meaningful representations of high-dimensional data. By incorporating latent variables, these models can uncover the underlying factors of variation of the data, making them a valuable tool for understanding and modeling complex data distributions.
In recent years, a variety of deep generative methods have been proposed to incorporate more complex posterior distributions by modeling structural sequential dependencies between latent variables (Sønderby et al., 2016; He et al., 2018; Maaløe et al., 2019; Vahdat & Kautz, 2020a), thus offering different levels of abstraction for encoding the data distribution. Our work advances the state of the art in structured VAEs by combining the complementary strengths of hierarchical clustering algorithms and deep generative models. We propose Tree VAE, a novel tree-based generative model that encodes hierarchical dependencies between latent variables; the code is publicly available at https://github.com/lauramanduchi/treevae-pytorch. We introduce a training procedure to learn the optimal tree structure to model the posterior distribution of latent variables. An example of a tree learned by Tree VAE is depicted in Fig. 1. Each edge and split is encoded by neural networks, while the circles depict latent variables. Each sample is associated with a probability distribution over paths. The resulting tree thus organizes the data into an interpretable hierarchical structure in an unsupervised fashion, optimizing the amount of shared information between samples. On CIFAR-10, for example, the method divides the vehicles and animals into two different subtrees, and similar groups (such as planes and ships) share common ancestors.

Figure 1: The hierarchical structure discovered by Tree VAE on the CIFAR-10 dataset, with a vehicles subtree (planes, ships, cars, trucks) and an animals subtree (horses, dogs, birds, cats, frogs, deers). We display random subsets of images that are probabilistically assigned to each leaf of the tree.

Our main contributions are as follows: (i) We propose a novel, deep probabilistic approach to hierarchical clustering that learns the optimal generative binary tree to mimic the hierarchies present in the data. (ii) We provide a thorough empirical assessment of the proposed approach on MNIST, Fashion-MNIST, 20Newsgroups, and Omniglot. In particular, we show that Tree VAE (a) outperforms related work on deep hierarchical clustering, (b) discovers meaningful patterns in the data and their hierarchical relationships, and (c) achieves a more competitive log-likelihood lower bound compared to VAE and Ladder VAE, its sequential counterpart. (iii) We propose an extension of Tree VAE that integrates contrastive learning into its tree structure. Relevant prior knowledge, expertise, or specific constraints can be incorporated into the generative model via augmentations, allowing for more accurate and contextually meaningful clustering. We test the contrastive version of Tree VAE on CIFAR-10, CIFAR-100, and CelebA, and we show that the proposed approach achieves competitive hierarchical clustering performance compared to the baselines.

2 Tree Variational Autoencoders

We propose Tree VAE, a novel deep generative model that learns a flexible tree-based posterior distribution over latent variables. Each sample travels through the tree from root to leaf in a probabilistic manner, as Tree VAE learns sample-specific probability distributions of paths. As a result, the data is divided in a hierarchical fashion, with more refined concepts for deeper nodes in the tree. The proposed graphical model is depicted in Fig. 2. The inference and generative models share the same top-down tree structure, enabling interaction between the bottom-up and top-down architectures, similarly to Sønderby et al. (2016).
2.1 Model Formulation

Given $H$, the maximum depth of the tree, and a dataset $X$, the model is defined by three components that are learned during training:

- the global structure of the binary tree $\mathcal{T}$, which specifies the set of nodes $\mathbb{V} = \{0, \dots, V\}$, the set of leaves $\mathcal{L} \subset \mathbb{V}$, and the set of edges $\mathcal{E}$. See Figs. 1, 4, 5, 6, and 7 for different examples of tree structures learned by the model.
- the sample-specific latent embeddings $\mathbf{z} = \{z_0, \dots, z_V\}$, which are random variables assigned to each node in $\mathbb{V}$. Each embedding is characterized by a Gaussian distribution whose parameters are a function of the realization of the parent node. The dimensions of the latent embeddings are defined by their depth, with $z_i \in \mathbb{R}^{h_{\operatorname{depth}(i)}}$, where $\operatorname{depth}(i)$ is the depth of node $i$ and $h_{\operatorname{depth}(i)}$ is the embedding dimension for that depth.
- the sample-specific decisions $\mathbf{c} = \{c_0, \dots, c_{V-|\mathcal{L}|}\}$, which are Bernoulli random variables defined by the probability of going to the right (or left) child of the underlying node. They take values $c_i \in \{0, 1\}$ for $i \in \mathbb{V} \setminus \mathcal{L}$, with $c_i = 0$ if the left child is selected.

A decision path, $\mathcal{P}_l$, indicates the path from root to leaf given the tree $\mathcal{T}$ and is defined by the nodes in the path, e.g., in Fig. 2, $\mathcal{P}_l = \{0, 1, 4, 5\}$. The probability of $\mathcal{P}_l$ is the product of the probabilities of the decisions in the path. The tree structure is shared across the entire dataset and is learned iteratively by growing the tree node-wise. The latent embeddings and the decision paths, on the other hand, are learned using variational inference by conditioning the model on the current tree structure. The generative model, the inference model, and the learning objective conditioned on $\mathcal{T}$ are explained in Secs. 2.2, 2.3, and 2.4, respectively, while in Sec. 2.5 we elaborate on the efficient growing procedure of the tree.

2.2 Generative Model

Figure 2: The proposed inference (left) and generative (right) models for Tree VAE. Circles are stochastic variables while diamonds are deterministic. The global topology of the tree is learned during training.

The generative process of Tree VAE for a given $\mathcal{T}$ is depicted in Fig. 2 (right). The generation of a new sample $x$ starts from the root. First, the latent embedding of the root node $z_0$ is sampled from a standard Gaussian $p_\theta(z_0) = \mathcal{N}(z_0 \mid 0, I)$. Then, given the sampled $z_0$, the decision of going to the left or the right node is sampled from a Bernoulli distribution $p(c_0 \mid z_0) = \operatorname{Ber}(r_{p,0}(z_0))$, where $\{r_{p,i} \mid i \in \mathbb{V} \setminus \mathcal{L}\}$ are functions parametrized by neural networks, defined as routers, which cause the splits in Fig. 2. The subscript $p$ is used to indicate the parameters of the generative model. The latent embedding of the selected child, let us assume it is $z_1$, is then sampled from a Gaussian distribution $p_\theta(z_1 \mid z_0) = \mathcal{N}\big(z_1 \mid \mu_{p,1}(z_0), \sigma^2_{p,1}(z_0)\big)$, where $\{\mu_{p,i}, \sigma_{p,i} \mid i \in \mathbb{V} \setminus \{0\}\}$ are functions parametrized by neural networks, defined as transformations. They are indicated by the top-down arrows in Fig. 2. This process continues until a leaf is reached. Let us define the set of latent variables selected by the path $\mathcal{P}_l$, which goes from the root to the leaf $l$, as $\mathbf{z}_{\mathcal{P}_l} = \{z_i \mid i \in \mathcal{P}_l\}$, the parent node of node $i$ as $\operatorname{pa}(i)$, and $p(c_{\operatorname{pa}(i) \to i} \mid z_{\operatorname{pa}(i)})$ the probability of going from $\operatorname{pa}(i)$ to $i$. Note that the path $\mathcal{P}_l$ defines the sequence of decisions. The prior probability of the latent embeddings and the path given the tree $\mathcal{T}$ can be summarized as

$$p_\theta(\mathbf{z}_{\mathcal{P}_l}, \mathcal{P}_l) = p(z_0) \prod_{i \in \mathcal{P}_l \setminus \{0\}} p(c_{\operatorname{pa}(i) \to i} \mid z_{\operatorname{pa}(i)})\, p(z_i \mid z_{\operatorname{pa}(i)}). \tag{1}$$
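To make the bookkeeping above concrete, here is a small, self-contained sketch, written by us for illustration (it is not the authors' implementation), of a binary tree whose internal nodes store router probabilities, together with the probability of a decision path as the product of its decisions. The tree layout and the numbers are invented toy values.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Node:
    """A tree node; internal nodes store the router output p(c_i = 1), i.e., 'go right'."""
    p_right: Optional[float] = None
    left: Optional["Node"] = None
    right: Optional["Node"] = None


def path_probability(root: Node, decisions: List[int]) -> float:
    """P(P_l): product over the internal nodes on the path of the chosen decision probability,
    with c_i = 0 selecting the left child and c_i = 1 the right child."""
    prob, node = 1.0, root
    for c in decisions:
        prob *= node.p_right if c == 1 else 1.0 - node.p_right
        node = node.right if c == 1 else node.left
    return prob


# Toy tree mimicking Fig. 2, where the path P_l = {0, 1, 4, 5} corresponds to the
# decisions c_0 = 0 (left), c_1 = 1 (right), c_4 = 0 (left).
tree = Node(
    p_right=0.3,
    left=Node(p_right=0.7, left=Node(), right=Node(p_right=0.1, left=Node(), right=Node())),
    right=Node(),
)
print(path_probability(tree, [0, 1, 0]))  # (1 - 0.3) * 0.7 * (1 - 0.1) = 0.441
```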
Finally, $x$ is sampled from a distribution that is conditioned on the selected leaf. If we assume that $x$ is real-valued, then

$$p_\theta(x \mid \mathbf{z}_{\mathcal{P}_l}, \mathcal{P}_l) = \mathcal{N}\big(x \mid \mu_{x,l}(z_l), \sigma^2_{x,l}(z_l)\big), \tag{2}$$

where $\{\mu_{x,l}, \sigma_{x,l} \mid l \in \mathcal{L}\}$ are functions parametrized by leaf-specific neural networks, defined as decoders.

2.3 Inference Model

The inference model is described by the variational posterior distribution of both the latent embeddings and the paths. It follows a similar structure as the prior probability defined in (1), with the difference that the probability of the root and of the decisions are now conditioned on the sample $x$:

$$q(\mathbf{z}_{\mathcal{P}_l}, \mathcal{P}_l \mid x) = q(z_0 \mid x) \prod_{i \in \mathcal{P}_l \setminus \{0\}} q(c_{\operatorname{pa}(i) \to i} \mid x)\, q(z_i \mid z_{\operatorname{pa}(i)}). \tag{3}$$

To compute the variational probability distributions of the latent embeddings $q(z_0 \mid x)$ and $q(z_i \mid z_{\operatorname{pa}(i)})$, where

$$q(z_0 \mid x) = \mathcal{N}\big(z_0 \mid \mu_{q,0}(x), \sigma^2_{q,0}(x)\big), \tag{4}$$
$$q_\phi\big(z_i \mid z_{\operatorname{pa}(i)}\big) = \mathcal{N}\big(z_i \mid \mu_{q,i}(z_{\operatorname{pa}(i)}), \sigma^2_{q,i}(z_{\operatorname{pa}(i)})\big), \quad i \in \mathcal{P}_l, \tag{5}$$

we follow a similar approach to the one proposed by Sønderby et al. (2016). Note that we use the subscript $q$ to indicate the parameters of the inference model. First, a deterministic bottom-up pass computes the node-specific approximate likelihood contributions

$$d_h = \operatorname{MLP}(d_{h+1}), \tag{6}$$
$$\hat{\mu}_{q,i} = \operatorname{Linear}\big(d_{\operatorname{depth}(i)}\big), \quad i \in \mathbb{V}, \tag{7}$$
$$\hat{\sigma}^2_{q,i} = \operatorname{Softplus}\big(\operatorname{Linear}\big(d_{\operatorname{depth}(i)}\big)\big), \quad i \in \mathbb{V}, \tag{8}$$

where $d_H$ is parametrized by a domain-specific neural network, defined as the encoder, and $\operatorname{MLP}(d_h)$ for $h \in \{1, \dots, H\}$, indicated by the bottom-up arrows in Fig. 2, are neural networks shared among the parameter predictors $\hat{\mu}_{q,i}, \hat{\sigma}^2_{q,i}$ of the same depth. They are characterized by the same architecture as the transformations defined in Sec. 2.2. A stochastic downward pass then recursively computes the approximate posteriors defined as

$$\sigma^2_{q,i} = \frac{1}{\hat{\sigma}^{-2}_{q,i} + \sigma^{-2}_{p,i}}, \qquad \mu_{q,i} = \frac{\hat{\mu}_{q,i}\,\hat{\sigma}^{-2}_{q,i} + \mu_{p,i}\,\sigma^{-2}_{p,i}}{\hat{\sigma}^{-2}_{q,i} + \sigma^{-2}_{p,i}}, \tag{9}$$

where all operations are performed elementwise. Finally, the variational distributions of the decisions $q(c_i \mid x)$ are defined as

$$q(c_i \mid x) = q(c_i \mid d_{\operatorname{depth}(i)}) = \operatorname{Ber}\big(r_{q,i}(d_{\operatorname{depth}(i)})\big), \tag{10}$$

where $\{r_{q,i} \mid i \in \mathbb{V} \setminus \mathcal{L}\}$ are functions parametrized by neural networks and are characterized by the same architecture as the routers of the generative model defined in Sec. 2.2.

2.4 Evidence Lower Bound

The parameters of both the generative model (denoted by $p$) and the inference model (denoted by $q$), consisting of the encoder $(\mu_{q,0}, \sigma_{q,0})$, the transformations $\{(\mu_{p,i}, \sigma_{p,i}), (\mu_{q,i}, \sigma_{q,i}) \mid i \in \mathbb{V} \setminus \{0\}\}$, the decoders $\{\mu_{x,l}, \sigma_{x,l} \mid l \in \mathcal{L}\}$, and the routers $\{r_{p,i}, r_{q,i} \mid i \in \mathbb{V} \setminus \mathcal{L}\}$, are learned by maximizing the Evidence Lower Bound (ELBO) (Kingma & Welling, 2014; Rezende et al., 2014). Each leaf $l$ is associated with only one path $\mathcal{P}_l$, hence we can write the data likelihood conditioned on $\mathcal{T}$ as

$$p(x \mid \mathcal{T}) = \sum_{l \in \mathcal{L}} \int p(x, \mathbf{z}_{\mathcal{P}_l}, \mathcal{P}_l)\, \mathrm{d}\mathbf{z}_{\mathcal{P}_l} = \sum_{l \in \mathcal{L}} \int p_\theta(\mathbf{z}_{\mathcal{P}_l}, \mathcal{P}_l)\, p_\theta(x \mid \mathbf{z}_{\mathcal{P}_l}, \mathcal{P}_l)\, \mathrm{d}\mathbf{z}_{\mathcal{P}_l}. \tag{11}$$

We use variational inference to derive the ELBO of the log-likelihood:

$$\mathcal{L}(x \mid \mathcal{T}) := \mathbb{E}_{q(\mathbf{z}_{\mathcal{P}_l}, \mathcal{P}_l \mid x)}\big[\log p(x \mid \mathbf{z}_{\mathcal{P}_l}, \mathcal{P}_l)\big] - \operatorname{KL}\big(q(\mathbf{z}_{\mathcal{P}_l}, \mathcal{P}_l \mid x) \,\|\, p(\mathbf{z}_{\mathcal{P}_l}, \mathcal{P}_l)\big). \tag{12}$$

The first term of the ELBO is the reconstruction term:

$$\mathcal{L}_{\mathrm{rec}} = \mathbb{E}_{q(\mathbf{z}_{\mathcal{P}_l}, \mathcal{P}_l \mid x)}\big[\log p(x \mid \mathbf{z}_{\mathcal{P}_l}, \mathcal{P}_l)\big] \tag{13}$$
$$= \sum_{l \in \mathcal{L}} \int q(z_0 \mid x) \prod_{i \in \mathcal{P}_l \setminus \{0\}} q(c_{\operatorname{pa}(i) \to i} \mid x)\, q(z_i \mid z_{\operatorname{pa}(i)})\, \log p(x \mid \mathbf{z}_{\mathcal{P}_l}, \mathcal{P}_l)\, \mathrm{d}\mathbf{z}_{\mathcal{P}_l} \tag{14}$$
$$\approx \frac{1}{M} \sum_{m=1}^{M} \sum_{l \in \mathcal{L}} P(l; \mathbf{c}) \log \mathcal{N}\big(x \mid \mu_{x,l}(z_l^{(m)}), \sigma^2_{x,l}(z_l^{(m)})\big), \tag{15}$$
$$P(i; \mathbf{c}) = \prod_{j \in \mathcal{P}_i \setminus \{0\}} q(c_{\operatorname{pa}(j) \to j} \mid x) \quad \text{for } i \in \mathbb{V}, \tag{16}$$

where $\mathcal{P}_i$ for $i \in \mathbb{V}$ is the path from the root to node $i$, $P(i; \mathbf{c})$ is the probability of reaching node $i$, which is the product over the probabilities of the decisions in the path until $i$, $z_l^{(m)}$ are the Monte Carlo (MC) samples, and $M$ is the number of MC samples.
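Two of the quantities above lend themselves to a compact numerical illustration. The following sketch, with toy tensors of our choosing (not the authors' code), computes the precision-weighted posterior parameters of Eq. (9) and the leaf-weighted Monte Carlo estimate of the reconstruction term in Eq. (15).

```python
import torch


def merge_gaussians(mu_hat, var_hat, mu_p, var_p):
    """Eq. (9): elementwise precision-weighted combination of the bottom-up
    estimate (mu_hat, var_hat) and the top-down prior (mu_p, var_p)."""
    var_q = 1.0 / (1.0 / var_hat + 1.0 / var_p)
    mu_q = (mu_hat / var_hat + mu_p / var_p) * var_q
    return mu_q, var_q


def reconstruction_term(leaf_probs, leaf_log_liks):
    """Eq. (15): sum over leaves of P(l; c) * log p(x | z_l), averaged over MC samples.
    leaf_probs: (batch, n_leaves); leaf_log_liks: (M, batch, n_leaves)."""
    return (leaf_probs.unsqueeze(0) * leaf_log_liks).sum(-1).mean(0)


# Toy example: 4-dimensional latent, two leaves, a single MC sample.
mu_q, var_q = merge_gaussians(torch.zeros(4), torch.ones(4), torch.ones(4), 4 * torch.ones(4))
leaf_probs = torch.tensor([[0.8, 0.2]])
leaf_log_liks = torch.tensor([[[-10.0, -30.0]]])
print(mu_q, var_q)                                   # mu_q = 0.2, var_q = 0.8 (elementwise)
print(reconstruction_term(leaf_probs, leaf_log_liks))  # 0.8 * (-10) + 0.2 * (-30) = -14.0
```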
Intuitively, the reconstruction loss is the sum of the leaf-wise reconstruction losses weighted by the probabilities of reaching the respective leaf. Note that here we sum over all possible paths in the tree, whose number is equal to the number of leaves. The second term of (12) is the Kullback-Leibler divergence (KL) between the prior and the variational posterior of the tree. It can be written as a sum of the KL of the root, the nodes, and the decisions:

$$\operatorname{KL}\big(q(\mathbf{z}_{\mathcal{P}_l}, \mathcal{P}_l \mid x) \,\|\, p(\mathbf{z}_{\mathcal{P}_l}, \mathcal{P}_l)\big) = \operatorname{KL}_{\mathrm{root}} + \operatorname{KL}_{\mathrm{nodes}} + \operatorname{KL}_{\mathrm{decisions}}, \tag{17}$$
$$\operatorname{KL}_{\mathrm{root}} = \operatorname{KL}\big(q(z_0 \mid x) \,\|\, p(z_0)\big), \tag{18}$$
$$\operatorname{KL}_{\mathrm{nodes}} \approx \frac{1}{M} \sum_{m=1}^{M} \sum_{i \in \mathbb{V} \setminus \{0\}} P(i; \mathbf{c})\, \operatorname{KL}\big(q(z_i^{(m)} \mid \operatorname{pa}(z_i^{(m)})) \,\|\, p(z_i^{(m)} \mid \operatorname{pa}(z_i^{(m)}))\big), \tag{19}$$
$$\operatorname{KL}_{\mathrm{decisions}} \approx \frac{1}{M} \sum_{m=1}^{M} \sum_{i \in \mathbb{V} \setminus \mathcal{L}} P(i; \mathbf{c}) \sum_{c_i \in \{0,1\}} q(c_i \mid x) \log \frac{q(c_i \mid x)}{p(c_i \mid z_i^{(m)})}, \tag{20}$$

where $M$ is the number of MC samples. We refer to Appendix A for the full derivation. $\operatorname{KL}_{\mathrm{root}}$ is the KL between the standard Gaussian prior $p(z_0)$ and the variational posterior of the root $q(z_0 \mid x)$, thus enforcing the root to be compact. $\operatorname{KL}_{\mathrm{nodes}}$ is the sum of the node-specific KLs weighted by the probability of reaching their node $i$, $P(i; \mathbf{c})$. The node-specific KL of node $i$ is the KL between the two Gaussians $q(z_i \mid \operatorname{pa}(z_i))$ and $p(z_i \mid \operatorname{pa}(z_i))$. Finally, the last term, $\operatorname{KL}_{\mathrm{decisions}}$, is the weighted sum of all the KLs of the decisions, which are Bernoulli random variables, $\operatorname{KL}\big(q(c_i \mid x) \,\|\, p(c_i \mid z_i)\big) = \sum_{c_i \in \{0,1\}} q(c_i \mid x) \log \frac{q(c_i \mid x)}{p(c_i \mid z_i)}$. The hierarchical specification of the binary tree allows encoding highly expressive models while retaining the computational efficiency of fully factorized models. The computational complexity is described in Appendix A.2.

2.5 Growing the Tree

Figure 3: The first two steps (Step 1, Step 2) of the growing process to learn the global structure of the tree during training. Highlighted in red are the trainable weights.

In the previous sections, we discussed the variational objective to learn the parameters of both the generative and the inference model given a defined tree structure $\mathcal{T}$. Here we discuss how to learn the structure of the binary tree $\mathcal{T}$ itself. Tree VAE starts by training a tree composed of a root and two leaves, see Fig. 3 (left), for $N_t$ epochs by optimizing the ELBO. Once the model has converged, a leaf is selected, e.g., $z_1$ in Fig. 3, and two children are attached to it. The leaf selection criterion can vary depending on the application and can be determined by, e.g., the reconstruction loss or the ELBO. In our experiments, we chose to select the nodes with the maximum number of samples to retain balanced leaves. The subtree composed of the new leaves and the parent node is then trained for $N_t$ epochs by freezing the weights of the rest of the model, see Fig. 3 (right), which amounts to computing the ELBO of the nodes of the subtree. For efficiency, the subtree is trained using only the subset of the data that has a high probability (higher than a threshold $t$) of being assigned to the parent node. The process is repeated until the tree reaches its maximum capacity (defined by the maximum depth) or until a condition (such as a predefined maximum number of leaves) is met. The entire model is then fine-tuned for $N_f$ epochs by unfreezing all weights. During fine-tuning, the tree is pruned by removing empty branches (those with an expected number of assigned samples lower than a threshold).
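The leaf-selection and pruning heuristics described above reduce to simple operations on the matrix of leaf-reach probabilities. The sketch below is our own illustration with made-up numbers; the function names are not taken from the released code.

```python
import numpy as np


def select_leaf_to_split(reach_probs: np.ndarray) -> int:
    """reach_probs: (N, L) matrix with the probability of each sample reaching each leaf.
    Returns the leaf with the largest expected number of assigned samples, mirroring the
    splitting criterion used in the experiments."""
    expected_counts = reach_probs.sum(axis=0)
    return int(np.argmax(expected_counts))


def prunable_leaves(reach_probs: np.ndarray, min_expected: float = 1.0) -> list:
    """Leaves whose expected number of assigned samples is below a threshold and that
    would therefore be removed during fine-tuning."""
    expected_counts = reach_probs.sum(axis=0)
    return [int(i) for i in np.where(expected_counts < min_expected)[0]]


# Toy example with 5 samples and 3 leaves.
probs = np.array([[0.7, 0.2, 0.1],
                  [0.6, 0.3, 0.1],
                  [0.1, 0.8, 0.1],
                  [0.2, 0.7, 0.1],
                  [0.5, 0.4, 0.1]])
print(select_leaf_to_split(probs))  # leaf 1: expected count 2.4 vs. 2.1 and 0.5
print(prunable_leaves(probs))       # [2]
```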
2.6 Integrating Prior Knowledge

Retrieving semantically meaningful clustering structures of real-world images is extremely challenging, as there are several underlying factors according to which the data can be clustered. Therefore, it is often crucial to integrate domain knowledge that guides the model toward desirable cluster assignments. Thus, we propose an extension of Tree VAE where we integrate recent advances in contrastive learning (van den Oord et al., 2018; Chen et al., 2020; Li et al., 2021), whereby prior knowledge on data invariances can be encoded through augmentations. For a batch $X$ with $N$ samples, we randomly augment every sample twice to obtain an augmented batch with $2N$ samples. For all positive pairs $(i, j)$, where $x_i$ and $x_j$ stem from the same original sample, we utilize the NT-Xent loss (Chen et al., 2020), which introduces the per-pair terms

$$\ell_{i,j} = -\log \frac{\exp(s_{i,j}/\tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp(s_{i,k}/\tau)},$$

where $s_{i,j}$ denotes the cosine similarity between the representations of $x_i$ and $x_j$, and $\tau$ is a temperature parameter (a self-contained sketch of this loss is given below, after Sec. 3). We integrate $\ell_{i,j}$ in both the bottom-up and the routers of Tree VAE. In the bottom-up, similar to Chen et al. (2020), we compute $\ell_{i,j}$ on the projections $g_h(d_h)$. For the routers, we directly compute the loss on the predicted probabilities $r_{q,i}(d_h)$. Finally, we average the terms over all positive pairs and add them to the negative ELBO (12) in the real-world image experiments. Implementation details can be found in Appendix E, while a loss ablation is shown in Appendix C.3.

3 Related Work

Deep latent variable models automatically learn structure from data by combining the flexibility of deep neural networks with the statistical foundations of generative models (Mattei & Frellsen, 2018). Variational autoencoders (VAEs) (Rezende et al., 2014; Kingma & Welling, 2014) are among the most used frameworks (Nasiri & Bepler, 2022; Bae et al., 2023; Bredell et al., 2023). A variety of works has been proposed to integrate more complex empirical prior distributions, thus reducing the gap between approximate and true posterior distributions (Ranganath et al., 2015; Webb et al., 2017; Klushyn et al., 2019). Among these, the most related to our work are the VAE-nCRP (Goyal et al., 2017; Shin et al., 2019) and the TMC-VAE (Vikram et al., 2018). These works use Bayesian nonparametric hierarchical clustering based on the nested Chinese restaurant process (nCRP) prior (Blei et al., 2003) and on the time-marginalized coalescent (TMC), respectively. However, even if they allow more flexible prior distributions, these models suffer from restrictive posterior distributions (Kingma et al., 2016). To overcome this issue, deep hierarchical VAEs (Gregor et al., 2015; Kingma et al., 2016) have been proposed, which employ structured approximate posteriors composed of hierarchies of conditional stochastic variables that are connected sequentially. Among a variety of proposed methods (Vahdat & Kautz, 2020b; Falck et al., 2022; T. Z. Xiao & Bamler, 2023), Ladder VAE (Sønderby et al., 2016) is the most related to Tree VAE. The authors propose to model the approximate posterior by combining a bottom-up recognition distribution with the top-down prior. Further extensions include BIVA (Maaløe et al., 2019), which introduces a bidirectional inference network, and Graph VAE (He et al., 2019), which introduces gated dependencies over a fixed number of latent variables. Contrary to the previous approaches, Tree VAE models a tree-based posterior distribution of latent variables, thus allowing hierarchical clustering of samples. For further work on hierarchical clustering and its supervised counterpart, decision trees, we refer to Appendix B.
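As a concrete reference for the contrastive term introduced in Sec. 2.6, the following is a standard SimCLR-style sketch of the NT-Xent loss, computed on generic representations; it is our own illustration and does not reproduce the exact projections and router outputs on which the loss is applied in Tree VAE.

```python
import torch
import torch.nn.functional as F


def nt_xent(z1: torch.Tensor, z2: torch.Tensor, tau: float = 0.5) -> torch.Tensor:
    """NT-Xent loss (Chen et al., 2020), averaged over all 2N positive pairs.
    z1, z2: (N, d) representations of the two augmentations of the same batch."""
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)  # (2N, d), so z @ z.T gives cosine similarities
    n = z1.shape[0]
    sim = z @ z.t() / tau
    sim.fill_diagonal_(float("-inf"))                   # implements the 1[k != i] indicator
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)])  # the positive of i is its augmentation
    return F.cross_entropy(sim, targets)


# Example: 4 samples with 16-dimensional projections.
loss = nt_xent(torch.randn(4, 16), torch.randn(4, 16), tau=0.5)
print(loss.item())
```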
4 Experimental Setup

Datasets and Metrics: We evaluate the clustering and generative performance of Tree VAE on MNIST (LeCun et al., 1998), Fashion-MNIST (H. Xiao et al., 2017), 20Newsgroups (Lang, 1995), Omniglot (Lake et al., 2015), and Omniglot-5, where only five alphabets (Braille, Glagolitic, Cyrillic, Odia, and Bengali) are selected and used as true labels. We assess the hierarchical clustering performance by computing dendrogram purity (DP) and leaf purity (LP), as defined by Kobren et al. (2017a), using the dataset labels, where we assume the number of true clusters is unknown. We also report standard clustering metrics, accuracy (ACC) and normalized mutual information (NMI), by setting the number of leaves for Tree VAE and for the baselines to the true number of clusters. In terms of generative performance, we compute the approximated true log-likelihood, calculated using 1000 importance-weighted samples, together with the ELBO (12) and the reconstruction loss. We also perform hierarchical clustering experiments on real-world imaging data, namely CIFAR-10, CIFAR-100 (Krizhevsky & Hinton, 2009) with 20 superclasses as labels, and CelebA (Z. Liu et al., 2015), using the contrastive extension (Sec. 2.6). We refer to Appendix D for more dataset details.

Table 1: Test set hierarchical clustering performances (%) of Tree VAE compared with baselines. Means and standard deviations are computed across 10 runs with different random model initialization. The star "*" indicates real-world image datasets on which contrastive approaches were applied.

| Dataset | Method | DP | LP | ACC | NMI |
|---|---|---|---|---|---|
| MNIST | Agg | 63.7 ± 0.0 | 78.6 ± 0.0 | 69.5 ± 0.0 | 71.1 ± 0.0 |
| | VAE + Agg | 79.9 ± 2.2 | 90.8 ± 1.4 | 86.6 ± 4.9 | 81.6 ± 2.0 |
| | Ladder VAE + Agg | 81.6 ± 3.9 | 90.9 ± 2.5 | 80.3 ± 5.6 | 82.0 ± 2.1 |
| | DeepECT | 74.6 ± 5.9 | 90.7 ± 3.2 | 74.9 ± 6.2 | 76.7 ± 4.2 |
| | Tree VAE (ours) | 87.9 ± 4.9 | 96.0 ± 1.9 | 90.2 ± 7.5 | 90.0 ± 4.6 |
| Fashion | Agg | 45.0 ± 0.0 | 67.6 ± 0.0 | 51.3 ± 0.0 | 52.6 ± 0.0 |
| | VAE + Agg | 44.3 ± 2.5 | 65.9 ± 2.3 | 54.9 ± 4.4 | 56.1 ± 3.2 |
| | Ladder VAE + Agg | 49.5 ± 2.3 | 67.6 ± 1.2 | 55.9 ± 3.0 | 60.7 ± 1.4 |
| | DeepECT | 44.9 ± 3.3 | 67.8 ± 1.4 | 51.8 ± 5.7 | 57.7 ± 3.7 |
| | Tree VAE (ours) | 54.4 ± 2.4 | 71.4 ± 2.0 | 63.6 ± 3.3 | 64.7 ± 1.4 |
| 20Newsgroups | Agg | 13.1 ± 0.0 | 30.8 ± 0.0 | 26.1 ± 0.0 | 27.5 ± 0.0 |
| | VAE + Agg | 7.1 ± 0.3 | 18.1 ± 0.5 | 15.2 ± 0.4 | 11.6 ± 0.3 |
| | Ladder VAE + Agg | 9.0 ± 0.2 | 20.0 ± 0.7 | 17.4 ± 0.9 | 17.8 ± 0.6 |
| | DeepECT | 9.3 ± 1.8 | 17.2 ± 3.8 | 15.6 ± 3.0 | 18.1 ± 4.1 |
| | Tree VAE (ours) | 17.5 ± 1.5 | 38.4 ± 1.6 | 32.8 ± 2.3 | 34.4 ± 1.5 |
| Omniglot-5 | Agg | 41.4 ± 0.0 | 63.7 ± 0.0 | 53.2 ± 0.0 | 33.3 ± 0.0 |
| | VAE + Agg | 46.3 ± 2.3 | 68.1 ± 1.6 | 52.9 ± 4.2 | 34.4 ± 2.9 |
| | Ladder VAE + Agg | 49.8 ± 3.9 | 71.3 ± 2.0 | 59.6 ± 4.9 | 44.2 ± 4.7 |
| | DeepECT | 33.3 ± 2.5 | 55.1 ± 2.8 | 41.1 ± 4.2 | 23.5 ± 4.3 |
| | Tree VAE (ours) | 58.8 ± 4.0 | 77.7 ± 3.9 | 63.9 ± 7.0 | 50.0 ± 5.9 |
| CIFAR-10* | VAE + Agg | 10.54 ± 0.12 | 16.33 ± 0.15 | 14.43 ± 0.19 | 1.86 ± 1.66 |
| | Ladder VAE + Agg | 12.81 ± 0.20 | 25.37 ± 0.62 | 19.29 ± 0.60 | 7.41 ± 0.42 |
| | DeepECT | 10.01 ± 0.02 | 10.30 ± 0.40 | 10.31 ± 0.39 | 0.18 ± 0.10 |
| | Tree VAE (ours) | 35.30 ± 1.15 | 53.85 ± 1.23 | 52.98 ± 1.34 | 41.44 ± 1.13 |
| CIFAR-100* | VAE + Agg | 5.27 ± 0.02 | 9.86 ± 0.19 | 8.82 ± 0.11 | 2.46 ± 0.10 |
| | Ladder VAE + Agg | 6.36 ± 0.07 | 16.08 ± 0.28 | 14.01 ± 0.41 | 8.99 ± 0.41 |
| | DeepECT | 5.28 ± 0.18 | 6.97 ± 0.69 | 6.97 ± 0.69 | 1.71 ± 0.86 |
| | Tree VAE (ours) | 10.44 ± 0.38 | 24.16 ± 0.65 | 21.82 ± 0.77 | 17.80 ± 0.42 |

Baselines: We compare the generative performance of Tree VAE to the VAE (Rezende et al., 2014; Kingma & Welling, 2014), its non-hierarchical counterpart, and the Ladder VAE (Sønderby et al., 2016), its sequential counterpart. For a fair comparison, all methods share the same architecture and hyperparameters whenever possible.
We compare Tree VAE to non-generative hierarchical clustering baselines for which the code was publicly available: Ward's minimum variance agglomerative clustering (Agg) (Ward, 1963; Murtagh & Legendre, 2014) and DeepECT (Mautz et al., 2020). We propose two additional baselines, where we perform Ward's agglomerative clustering on the latent space of the VAE (VAE + Agg) and on the last layer of the Ladder VAE (Ladder VAE + Agg). For the contrastive clustering experiments, we apply a contrastive loss similar to that of Tree VAE to the VAE and the Ladder VAE, while for DeepECT we use the contrastive loss proposed by its authors.

Implementation Details: While we believe that more complex architectures could have a substantial impact on the performance of Tree VAE, we choose to employ rather simple settings to validate the proposed approach. We set the dimension of all latent embeddings $\mathbf{z} = \{z_0, \dots, z_V\}$ to 8 for MNIST, Fashion, and Omniglot, to 4 for 20Newsgroups, and to 64 for CIFAR-10, CIFAR-100, and CelebA. The maximum depth of the tree is set to 6 for all datasets, except 20Newsgroups, where we increased the depth to 7 to capture more clusters. To compute DP and LP, we allow the tree to grow to a maximum of 30 leaves for 20Newsgroups and CIFAR-100, and 20 for the rest, while for ACC and NMI we fix the number of leaves to the number of true classes. The transformations consist of one-layer MLPs of size 128 and the routers of two-layer MLPs of size 128 for all datasets, except for the real-world imaging data, where we slightly increase the MLP size to 512. Finally, the encoder and decoders consist of simple CNNs and MLPs. The trees are trained for $N_t = 150$ epochs at each growth step, and the final tree is fine-tuned for $N_f = 200$ epochs. For the real-world imaging experiments, we set the weight of the contrastive loss to 100. See Appendix E for additional details.

Table 2: Test set generative performances of Tree VAE with 10 leaves compared with baselines. Means and standard deviations are computed across 10 runs with different random model initialization.

| Dataset | Method | LL | RL | ELBO |
|---|---|---|---|---|
| MNIST | VAE | 101.9 ± 0.2 | 87.2 ± 0.3 | 104.6 ± 0.3 |
| | Ladder VAE | 99.9 ± 0.5 | 87.8 ± 0.7 | 103.2 ± 0.7 |
| | Tree VAE (ours) | 92.9 ± 0.2 | 80.3 ± 0.2 | 96.8 ± 0.2 |
| Fashion | VAE | 242.2 ± 0.2 | 231.7 ± 0.5 | 245.4 ± 0.5 |
| | Ladder VAE | 239.4 ± 0.5 | 231.5 ± 0.6 | 243.0 ± 0.6 |
| | Tree VAE (ours) | 234.7 ± 0.1 | 226.5 ± 0.3 | 239.2 ± 0.4 |
| 20Newsgroups | VAE | 44.26 ± 0.01 | 45.52 ± 0.03 | 44.61 ± 0.01 |
| | Ladder VAE | 44.30 ± 0.03 | 43.52 ± 0.03 | 44.62 ± 0.02 |
| | Tree VAE (ours) | 51.67 ± 0.59 | 45.83 ± 0.36 | 52.79 ± 0.66 |
| Omniglot | VAE | 115.3 ± 0.3 | 101.6 ± 0.3 | 118.2 ± 0.3 |
| | Ladder VAE | 113.1 ± 0.5 | 100.7 ± 0.7 | 117.5 ± 0.6 |
| | Tree VAE (ours) | 110.4 ± 0.5 | 96.9 ± 0.5 | 114.6 ± 0.4 |

5 Results

Hierarchical Clustering Results: Table 1 shows the quantitative hierarchical clustering results averaged across 10 seeds. First, we assume the true number of clusters is unknown and report DP and LP. Second, we assume we have access to the true number of clusters $K$ and compute ACC and NMI. As can be seen, Tree VAE outperforms the baselines in both experiments. This suggests that the proposed approach successfully builds an optimal tree based on the data's intrinsic characteristics. Among the different baselines, agglomerative clustering using Ward's method (Agg) trained on the last layer of Ladder VAE shows competitive performance. To the best of our knowledge, we are the first to report these results. It is noteworthy that it consistently improves over VAE + Agg, indicating that the last layer of Ladder VAE captures more cluster information than the VAE.
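For reference, two of the label-based metrics reported above can be computed as follows. This is our own minimal sketch, assuming integer-coded cluster assignments and labels; dendrogram purity is omitted since it additionally requires the full tree structure.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def leaf_purity(leaf_ids: np.ndarray, labels: np.ndarray) -> float:
    """Weighted average, over leaves, of the fraction of samples carrying the leaf's
    majority label (cf. Kobren et al., 2017a)."""
    total, purity = len(labels), 0.0
    for leaf in np.unique(leaf_ids):
        leaf_labels = labels[leaf_ids == leaf]
        purity += np.bincount(leaf_labels).max() / total
    return purity


def clustering_accuracy(pred: np.ndarray, labels: np.ndarray) -> float:
    """Clustering accuracy under the optimal one-to-one cluster-to-class assignment."""
    k = int(max(pred.max(), labels.max())) + 1
    cost = np.zeros((k, k), dtype=np.int64)
    for p, y in zip(pred, labels):
        cost[p, y] += 1
    rows, cols = linear_sum_assignment(-cost)  # maximize the number of matched samples
    return cost[rows, cols].sum() / len(labels)


pred = np.array([0, 0, 1, 1, 2, 2])
true = np.array([1, 1, 0, 0, 2, 1])
print(leaf_purity(pred, true), clustering_accuracy(pred, true))  # both 0.833...
```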
Generative Results: In Table 2, we evaluate the generative performance of the proposed approach, Tree VAE, compared to the VAE, its non-hierarchical counterpart, and Ladder VAE, its sequential counterpart. Tree VAE outperforms the baselines on the majority of datasets, indicating that the proposed ELBO (12) can achieve a tighter lower bound of the log-likelihood. The most notable improvement is reflected in the reconstruction loss, showing the advantage of using cluster-specialized decoders. However, this improvement comes at the expense of a larger neural network architecture and an increase in the number of parameters (as Tree VAE has $|\mathcal{L}|$ decoders, one per leaf). While this requires more computational resources at training time, during deployment the tree structure of Tree VAE permits lightweight inference through conditional sampling, thus matching the inference time of Ladder VAE. It is also worth mentioning that the results differ from Sønderby et al. (2016), as we adapt their architecture to match our experimental setting and consequently use a smaller latent dimensionality. Finally, we notice that more complex methods are prone to overfitting on the 20Newsgroups dataset, hence the best performances are achieved by the VAE.

Real-world Imaging Data & Contrastive Learning: Clustering real-world imaging data is extremely difficult, as there are endless possibilities of how the data can be partitioned (such as by colors, landscape, etc.). We therefore inject prior information through augmentations to guide Tree VAE and the baselines toward semantically meaningful splits. Table 1 (bottom) shows the hierarchical clustering performance of Tree VAE and its baselines, all employing contrastive learning, on CIFAR-10 and CIFAR-100. We observe that DeepECT struggles in separating the data, as its contrastive approach leads to all samples falling into the same leaf. In Table 3, we present the leaf frequencies of various face attributes using the tree learned by Tree VAE. For all datasets, Tree VAE is able to group the data into contextually meaningful hierarchies and groups, evident from its superior performance compared to the baselines and from the distinct attribute frequencies in the leaves and subtrees.

Discovery of Hierarchies: In addition to solely clustering data, Tree VAE is able to discover meaningful hierarchical relations between the clusters, thus allowing for more insights into the dataset.

Figure 4: Hierarchical structures learned by Tree VAE on Fashion. Subtree (a) encodes tops, while (b) encodes shoes, purses, and pants.

Figure 5: Hierarchical structure learned by Tree VAE on 20Newsgroups. The leaves are labeled with topics such as comp.windows.x, comp.graphics, comp.os.ms-windows.misc, comp.sys.ibm.pc.hardware, comp.sys.mac.hardware, misc.forsale, rec.autos/motorcycles, rec.motorcycles, rec.sport.baseball, rec.sport.hockey, alt.atheism, talk.religion.misc, talk.politics.guns/misc, talk.politics.mideast, and soc.religion.christian.

Figure 6: Hierarchical structures learned by Tree VAE on Omniglot-5. Subtree (a) learns a hierarchy over Braille and the Indian alphabets, while (b) groups the Slavic alphabets.
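The two generation modes used in this section, conditional sampling along a single root-to-leaf path and unconditional generation that propagates one root sample through every leaf decoder (cf. Fig. 8 below), can be sketched as follows. The modules are toy linear layers standing in for the routers, transformations, and decoders; this is our own illustration, not the authors' implementation.

```python
import torch
import torch.nn as nn

latent_dim, x_dim = 8, 32


def edge() -> nn.Module:
    """Transformation on an edge: parent latent -> (mu, logvar) of the child latent."""
    return nn.Linear(latent_dim, 2 * latent_dim)


def make_node(depth: int, max_depth: int) -> dict:
    if depth == max_depth:                       # leaf: only a decoder
        return {"decoder": nn.Linear(latent_dim, x_dim)}
    return {"router": nn.Linear(latent_dim, 1),  # logit of the probability of going right
            "left": (edge(), make_node(depth + 1, max_depth)),
            "right": (edge(), make_node(depth + 1, max_depth))}


def sample_child(transform: nn.Module, z: torch.Tensor) -> torch.Tensor:
    mu, logvar = transform(z).chunk(2, dim=-1)
    return mu + torch.randn_like(mu) * (0.5 * logvar).exp()


def conditional_sample(node: dict, z: torch.Tensor) -> torch.Tensor:
    """Follow a single root-to-leaf path, so only one decoder is evaluated (lightweight inference)."""
    while "router" in node:
        go_right = bool(torch.bernoulli(torch.sigmoid(node["router"](z))).item())
        transform, node = node["right"] if go_right else node["left"]
        z = sample_child(transform, z)
    return node["decoder"](z)


def unconditional_samples(node: dict, z: torch.Tensor) -> list:
    """Propagate one root sample through the entire tree: one generation per leaf."""
    if "router" not in node:
        return [node["decoder"](z)]
    samples = []
    for branch in ("left", "right"):
        transform, child = node[branch]
        samples += unconditional_samples(child, sample_child(transform, z))
    return samples


tree = make_node(0, max_depth=2)                 # toy tree with four leaves
z0 = torch.randn(1, latent_dim)                  # root sample from the standard Gaussian prior
print(conditional_sample(tree, z0).shape, len(unconditional_samples(tree, z0)))  # torch.Size([1, 32]) 4
```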
Table 3: Frequency (in %) of selected attributes for each leaf of Tree VAE with eight leaves on CelebA.

| Attribute | Leaf 1 | Leaf 2 | Leaf 3 | Leaf 4 | Leaf 5 | Leaf 6 | Leaf 7 | Leaf 8 |
|---|---|---|---|---|---|---|---|---|
| Female | 97.2 | 55.0 | 97.7 | 86.6 | 23.1 | 30.7 | 46.6 | 43.7 |
| Bangs | 1.6 | 1.2 | 24.1 | 61.7 | 3.4 | 11.1 | 9.1 | 11.4 |
| Blonde | 1.1 | 3.7 | 66.7 | 2.2 | 5.8 | 2.6 | 26.1 | 7.1 |
| Makeup | 75.7 | 43.4 | 76.6 | 59.7 | 15.0 | 12.4 | 16.3 | 12.8 |
| Smiling | 54.3 | 66.6 | 66.4 | 51.2 | 54.7 | 42.4 | 37.3 | 22.4 |
| Hair Loss | 3.6 | 17.8 | 3.0 | 0.2 | 18.9 | 6.9 | 21.2 | 10.6 |
| Beard | 1.1 | 20.6 | 0.4 | 3.7 | 39.5 | 36.5 | 21.3 | 21.4 |

Figure 7: Hierarchical structure learned by Tree VAE with eight leaves on the CelebA dataset, with generated images obtained through conditional sampling. Generally, most females are in the left subtree, while most males are in the right subtree. We observe that leaf 1 is associated with dark-haired females, leaf 2 with smiling, dark-haired individuals, leaf 3 with blonde females, leaf 4 with bangs, leaf 7 with a receding hairline, and leaf 8 with non-smiling people. See Table 3 for further details.

In the introductory Fig. 1, as well as in Fig. 5 and 6, we present the hierarchical structures learned by Tree VAE, while in Fig. 4 and 7, we additionally display conditional cluster generations from the leaf-specific decoders. In Fig. 4, Tree VAE separates the fashion items into two subtrees, one containing shoes and bags, and the other containing the tops, which are further refined into long and short sleeves. In Fig. 5, we depict the most prevalent ground-truth topic label in each leaf. Tree VAE learns to separate technological and societal subjects and discovers semantically meaningful subtrees. In Fig. 6, Tree VAE learns to split the alphabets into Indian (Odia and Bengali) and Slavic (Glagolitic and Cyrillic) subtrees, while Braille is grouped with the Indian languages due to similar circle-like structures. For CelebA (Fig. 7 and Table 3), the resulting tree separates genders at the root split. Females (left) are further divided by hair color and hairstyle (bangs). Males (right) are further divided by smile intensity, beard, hair loss, and age. In Fig. 8 and Appendix C, we show how Tree VAE can additionally be used to sample unconditional generations for all clusters simultaneously, by sampling from the root and propagating through the entire tree. The generations differ across the leaves by their cluster-specific features, whereas cluster-independent properties are retained across all generations.

Figure 8: Selected unconditional generations for CelebA (columns correspond to leaf 1 through leaf 8). One row corresponds to one sample from the root, for which we depict the visualizations obtained from the 8 leaf decoders. The overall face shape, skin color, and face orientation are retained among leaves from the same row, while several properties (such as make-up, beard, mustache, glasses, and hair) vary across the different leaves.

6 Conclusion

In this paper, we introduced Tree VAE, a new generative method that leverages a tree-based posterior distribution of latent variables to capture the hierarchical structures present in the data. Tree VAE optimizes the balance between shared and specialized architecture, enhancing the learning and adaptation capabilities of generative models. Empirically, we showed that our model offers a substantial improvement in hierarchical clustering performance compared to related work, while also providing a tighter lower bound on the log-likelihood of the data.
We presented qualitatively how the hierarchical structures learned by Tree VAE enable a more comprehensive understanding of the data, thereby facilitating enhanced analysis, interpretation, and decision-making. Our findings highlight the versatility of the proposed approach, which we believe to hold significant potential for unsupervised representation learning, paving the way for exciting advancements in the field. Limitations & Future Work: Currently, Tree VAE uses a simple heuristic on which node to split that might not work on datasets with unbalanced clusters. Additionally, the contrastive losses on the routers encourage balanced clusters. Thus, more research is necessary to convert the heuristics to data-driven approaches. While deep latent variable models, such as VAEs, provide a framework for modeling explicit relationships through graphical structures, they often exhibit poor performance on synthetic image generation. However, more complex architectural design (Vahdat & Kautz, 2020a) or recent advancement in diffusion latent models (Rombach et al., 2021) present potential solutions to enhance image quality generation, thus striking an optimal balance between generating high-quality images and capturing meaningful representations. Acknowledgments and Disclosure of Funding We thank Thomas M. Sutter for the insightful discussions throughout the project, Jorge da Silva Gonçalves for providing interpretable visualizations of the Tree VAE model, and Gabriele Manduchi for the valuable feedback on the notation of the ELBO. LM is supported by the SDSC Ph D Fellowship #1-001568-037. MV is supported by the Swiss State Secretariat for Education, Research and Innovation (SERI) under contract number MB22.00047. AR is supported by the Stimu Loop grant #1-007811-002 and the Vontobel Foundation. Arenas, M., Barceló, P., Orth, M. A. R., & Subercaseaux, B. (2022). On computing probabilistic explanations for decision trees. In Neurips. Retrieved from http://papers.nips.cc/paper_files/ paper/2022/hash/b8963f6a0a72e686dfa98ac3e7260f73-Abstract-Conference.html Bae, J., Zhang, M. R., Ruan, M., Wang, E., Hasegawa, S., Ba, J., & Grosse, R. B. (2023). Multi-rate VAE: Train once, get the full rate-distortion curve. In The eleventh international conference on learning representations. Retrieved from https://openreview.net/forum?id= OJ8a Sj Ca MNK Basak, J., & Krishnapuram, R. (2005). Interpretable hierarchical clustering by constructing an unsupervised decision tree. IEEE Trans. Knowl. Data Eng., 17(1), 121 132. Bengio, Y., Courville, A. C., & Vincent, P. (2012). Representation learning: A review and new perspectives. IEEE Transactions on Pattern Analysis and Machine Intelligence, 35, 1798-1828. Bishop, C. M. (2006). Pattern recognition and machine learning (information science and statistics). Berlin, Heidelberg: Springer-Verlag. Blei, D. M., Jordan, M. I., Griffiths, T. L., & Tenenbaum, J. B. (2003). Hierarchical topic models and the nested chinese restaurant process. In Proceedings of the 16th international conference on neural information processing systems (p. 17 24). Cambridge, MA, USA: MIT Press. Blockeel, H., & De Raedt, L. (1998). Top-down induction of clustering trees. In Proceedings of the fifteenth international conference on machine learning (pp. 55 63). Bredell, G., Flouris, K., Chaitanya, K., Erdil, E., & Konukoglu, E. (2023). Explicitly minimizing the blur error of variational autoencoders. In The eleventh international conference on learning representations. 
Retrieved from https://openreview.net/forum?id=9krn Q-ue9M Breiman, L., Friedman, J. H., Olshen, R. A., & Stone, C. J. (1984). Classification and regression trees. Wadsworth. Briggs, C., Fan, Z., & András, P. (2020). Federated learning with hierarchical clustering of local updates to improve training on non-iid data. 2020 International Joint Conference on Neural Networks (IJCNN), 1-9. Campello, R. J., Moulavi, D., Zimek, A., & Sander, J. (2015). Hierarchical density estimates for data clustering, visualization, and outlier detection. ACM Transactions on Knowledge Discovery from Data (TKDD), 10(1), 1 51. Campello, R. J. G. B., Moulavi, D., Zimek, A., & Sander, J. (2015). Hierarchical density estimates for data clustering, visualization, and outlier detection. ACM Transactions on Knowledge Discovery from Data (TKDD), 10, 1 - 51. Chen, T., Kornblith, S., Norouzi, M., & Hinton, G. (2020). A simple framework for contrastive learning of visual representations. In Proceedings of the 37th international conference on machine learning. JMLR.org. Dilokthanakul, N., Mediano, P. A. M., Garnelo, M., Lee, M. C. H., Salimbeni, H., Arulkumaran, K., & Shanahan, M. (2016). Deep unsupervised clustering with Gaussian mixture variational autoencoders. (ar Xiv:1611.02648) Ester, M., Kriegel, H., Sander, J., & Xu, X. (1996). A density-based algorithm for discovering clusters in large spatial databases with noise. In E. Simoudis, J. Han, & U. M. Fayyad (Eds.), Proceedings of the second international conference on knowledge discovery and data mining (kdd-96), portland, oregon, USA (pp. 226 231). AAAI Press. Retrieved from http://www.aaai.org/Library/ KDD/1996/kdd96-037.php Falck, F., Williams, C., Danks, D., Deligiannidis, G., Yau, C., Holmes, C. C., . . . Willetts, M. (2022). A multi-resolution framework for u-nets with applications to hierarchical VAEs. In A. H. Oh, A. Agarwal, D. Belgrave, & K. Cho (Eds.), Advances in neural information processing systems. Retrieved from https://openreview.net/forum?id=PQFr7Fb Gb O Fraiman, R., Ghattas, B., & Svarc, M. (2013). Interpretable clustering using unsupervised binary trees. Adv. Data Anal. Classif., 7(2), 125 145. Retrieved from https://doi.org/10.1007/ s11634-013-0129-3 doi: 10.1007/s11634-013-0129-3 Frosst, N., & Hinton, G. E. (2017). Distilling a neural network into a soft decision tree. In T. R. Besold & O. Kutz (Eds.), Proceedings of the first international workshop on comprehensibility and explanation in AI and ML 2017 co-located with 16th international conference of the italian association for artificial intelligence (ai*ia 2017), bari, italy, november 16th and 17th, 2017 (Vol. 2071). CEUR-WS.org. Retrieved from https://ceur-ws.org/Vol-2071/CEx AIIA_2017 _paper_3.pdf Ghojogh, B., Ghodsi, A., Karray, F., & Crowley, M. (2021). Uniform manifold approximation and projection (UMAP) and its variants: Tutorial and survey. Co RR, abs/2109.02508. Retrieved from https://arxiv.org/abs/2109.02508 Goyal, P., Hu, Z., Liang, X., Wang, C., Xing, E. P., & Mellon, C. (2017). Nonparametric variational auto-encoders for hierarchical representation learning. 2017 IEEE International Conference on Computer Vision (ICCV), 5104-5112. Gregor, K., Danihelka, I., Graves, A., Rezende, D., & Wierstra, D. (2015). Draw: A recurrent neural network for image generation. In F. Bach & D. Blei (Eds.), Proceedings of the 32nd international conference on machine learning (Vol. 37, pp. 1462 1471). Lille, France: PMLR. 
Retrieved from https://proceedings.mlr.press/v37/gregor15.html He, J., Gong, Y., Marino, J., Mori, G., & Lehrmann, A. M. (2018). Variational autoencoders with jointly optimized latent dependency structure. In International conference on learning representations. He, J., Gong, Y., Marino, J., Mori, G., & Lehrmann, A. M. (2019). Variational autoencoders with jointly optimized latent dependency structure. In Iclr. Heller, K. A., & Ghahramani, Z. (2005). Bayesian hierarchical clustering. In Proceedings of the 22nd international conference on machine learning (pp. 297 304). Jordan, M. I., & Mitchell, T. M. (2015). Machine learning: Trends, perspectives, and prospects. Science, 349(6245), 255-260. Retrieved from https://www.science.org/doi/abs/10.1126/ science.aaa8415 doi: 10.1126/science.aaa8415 Kingma, D. P., Salimans, T., Jozefowicz, R., Chen, X., Sutskever, I., & Welling, M. (2016). Improved variational inference with inverse autoregressive flow. In D. Lee, M. Sugiyama, U. Luxburg, I. Guyon, & R. Garnett (Eds.), Advances in neural information processing systems (Vol. 29). Curran Associates, Inc. Retrieved from https://proceedings.neurips.cc/paper_files/ paper/2016/file/ddeebdeefdb7e7e7a697e1c3e3d8ef54-Paper.pdf Kingma, D. P., & Welling, M. (2014). Auto-encoding variational bayes. ar Xiv preprint ar Xiv:1312.6114. Kingma, D. P., & Welling, M. (2019). An introduction to variational autoencoders. Found. Trends Mach. Learn., 12, 307-392. Klushyn, A., Chen, N., Kurle, R., Cseke, B., & Smagt, P. v. d. (2019). Learning hierarchical priors in vaes. In Proceedings of the 33rd international conference on neural information processing systems. Red Hook, NY, USA: Curran Associates Inc. Kobren, A., Monath, N., Krishnamurthy, A., & Mc Callum, A. (2017a). A hierarchical algorithm for extreme clustering. In Proceedings of the 23rd ACM SIGKDD international conference on knowledge discovery and data mining, halifax, ns, canada, august 13 - 17, 2017 (pp. 255 264). ACM. Retrieved from https://doi.org/10.1145/3097983.3098079 doi: 10.1145/ 3097983.3098079 Kobren, A., Monath, N., Krishnamurthy, A., & Mc Callum, A. (2017b). A hierarchical algorithm for extreme clustering. In Proceedings of the 23rd acm sigkdd international conference on knowledge discovery and data mining (pp. 255 264). Krizhevsky, A., & Hinton, G. (2009). Learning multiple layers of features from tiny images (Tech. Rep. No. 0). Toronto, Ontario: University of Toronto. Lake, B. M., Salakhutdinov, R., & Tenenbaum, J. B. (2015). Human-level concept learning through probabilistic program induction. Science, 350(6266), 1332-1338. Retrieved from https:// www.science.org/doi/abs/10.1126/science.aab3050 doi: 10.1126/science.aab3050 Lang, K. (1995). Newsweeder: Learning to filter netnews. In Proceedings of the twelfth international conference on machine learning (p. 331-339). Laptev, D., & Buhmann, J. M. (2014). Convolutional decision trees for feature learning and segmentation. In Pattern recognition: 36th german conference, gcpr 2014, münster, germany, september 2-5, 2014, proceedings 36 (pp. 95 106). Le Cun, Y., Bengio, Y., & Hinton, G. (2015). Deep learning. Nature, 521(7553), 436 444. Le Cun, Y., Bottou, L., Bengio, Y., & Haffner, P. (1998). Gradient-based learning applied to document recognition. Proc. IEEE, 86, 2278-2324. Li, Y., Hu, P., Liu, J. Z., Peng, D., Zhou, J. T., & Peng, X. (2021). Contrastive clustering. 
In Thirtyfifth AAAI conference on artificial intelligence, AAAI 2021, thirty-third conference on innovative applications of artificial intelligence, IAAI 2021, the eleventh symposium on educational advances in artificial intelligence, EAAI 2021, virtual event, february 2-9, 2021 (pp. 8547 8555). AAAI Press. Retrieved from https://ojs.aaai.org/index.php/AAAI/article/view/17037 Li, Y., Yang, M., Peng, D., Li, T., Huang, J., & Peng, X. (2022). Twin contrastive learning for online clustering. International Journal of Computer Vision, 130(9), 2205 2221. Liu, B., Xia, Y., & Yu, P. S. (2000). Clustering through decision tree construction. In International conference on information and knowledge management. Liu, Z., Luo, P., Wang, X., & Tang, X. (2015). Deep learning face attributes in the wild. In Proceedings of international conference on computer vision (iccv). Maaløe, L., Fraccaro, M., Liévin, V., & Winther, O. (2019). Biva: A very deep hierarchy of latent variables for generative modeling. In Neurips. Manduchi, L., Chin-Cheong, K., Michel, H., Wellmann, S., & Vogt, J. E. (2021). Deep conditional gaussian mixture model for constrained clustering. In Neural information processing systems. Mathieu, E., Le Lan, C., Maddison, C. J., Tomioka, R., & Teh, Y. W. (2019). Continuous hierarchical representations with poincaré variational auto-encoders. Advances in neural information processing systems, 32. Mattei, P.-A., & Frellsen, J. (2018). Leveraging the exact likelihood of deep latent variable models. In S. Bengio, H. Wallach, H. Larochelle, K. Grauman, N. Cesa-Bianchi, & R. Garnett (Eds.), Advances in neural information processing systems (Vol. 31). Curran Associates, Inc. Retrieved from https://proceedings.neurips.cc/paper_files/paper/2018/file/ 0609154fa35b3194026346c9cac2a248-Paper.pdf Mautz, D., Plant, C., & Böhm, C. (2020). Deepect: The deep embedded cluster tree. Data Science and Engineering, 5, 419 - 432. Mc Connachie, A. W., Ibata, R. A., Martin, N., Ferguson, A. M. N., Collins, M. L. M., Gwyn, S. D. J., ... Widrow, L. M. (2018). The large-scale structure of the halo of the andromeda galaxy. ii. hierarchical structure in the pan-andromeda archaeological survey. The Astrophysical Journal, 868. Monath, N., Zaheer, M., Silva, D., Mc Callum, A., & Ahmed, A. (2019). Gradient-based hierarchical clustering using continuous representations of trees in hyperbolic space. In Proceedings of the 25th acm sigkdd international conference on knowledge discovery & data mining (pp. 714 722). Moshkovitz, M., Yang, Y., & Chaudhuri, K. (2021). Connecting interpretability and robustness in decision trees through separation. In M. Meila & T. Zhang (Eds.), Proceedings of the 38th international conference on machine learning, ICML 2021, 18-24 july 2021, virtual event (Vol. 139, pp. 7839 7849). PMLR. Retrieved from http://proceedings.mlr.press/v139/ moshkovitz21a.html Murtagh, F., & Contreras, P. (2012). Algorithms for hierarchical clustering: an overview. Wiley Interdisciplinary Reviews: Data Mining and Knowledge Discovery, 2(1), 86 97. Murtagh, F., & Legendre, P. (2014). Ward s hierarchical agglomerative clustering method: Which algorithms implement ward s criterion? Journal of Classification, 31, 274-295. Nasiri, A., & Bepler, T. (2022). Unsupervised object representation learning using translation and rotation group equivariant VAE. In A. H. Oh, A. Agarwal, D. Belgrave, & K. Cho (Eds.), Advances in neural information processing systems. Retrieved from https://openreview.net/ forum?id=qmm__j Mj Ml L Neal, R. M. 
(2003). Density Modeling and Clustering using Dirichlet Diffusion Trees. In J. M. Bernardo et al. (Eds.), Bayesian Statistics 7 (p. 619-629). Oxford University Press. Nistér, D., & Stewénius, H. (2006). Scalable recognition with a vocabulary tree. 2006 IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR 06), 2, 2161-2168. Pace, A., Chan, A. J., & van der Schaar, M. (2022). POETREE: interpretable policy learning with adaptive decision trees. In The tenth international conference on learning representations, ICLR 2022, virtual event, april 25-29, 2022. Open Review.net. Retrieved from https://openreview .net/forum?id=AJs I-yma Kn_ Ram, P., & Gray, A. G. (2011). Density estimation trees. In Knowledge discovery and data mining. Ranganath, R., Tran, D., & Blei, D. M. (2015). Hierarchical variational models. Ar Xiv, abs/1511.02386. Rezende, D. J., Mohamed, S., & Wierstra, D. (2014). Stochastic backpropagation and approximate inference in deep generative models. In Proceedings of the 31st international conference on machine learning (Vol. 32, pp. 1278 1286). PMLR. Rombach, R., Blattmann, A., Lorenz, D., Esser, P., & Ommer, B. (2021). High-resolution image synthesis with latent diffusion models. 2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 10674-10685. Rota Bulo, S., & Kontschieder, P. (2014). Neural decision forests for semantic image labelling. In Proceedings of the ieee conference on computer vision and pattern recognition (pp. 81 88). Shin, S.-J., Song, K., & Moon, I.-C. (2019). Hierarchically clustered representation learning. In Aaai conference on artificial intelligence. Sneath, P. H. (1957). The application of computers to taxonomy. Microbiology, 17(1), 201 226. Sneath, P. H., & Sokal, R. R. (1962). Numerical taxonomy. Nature, 193, 855 860. Sohn, K. (2016). Improved deep metric learning with multi-class n-pair loss objective. In D. D. Lee, M. Sugiyama, U. von Luxburg, I. Guyon, & R. Garnett (Eds.), Advances in neural information processing systems 29: Annual conference on neural information processing systems 2016, december 5-10, 2016, barcelona, spain (pp. 1849 1857). Retrieved from https://proceedings.neurips .cc/paper/2016/hash/6b180037abbebea991d8b1232f8a8ca9-Abstract.html Sønderby, C. K., Raiko, T., Maaløe, L., Sønderby, S. K., & Winther, O. (2016). Ladder variational autoencoders. Advances in neural information processing systems, 29. Souza, V. F., Cicalese, F., Laber, E. S., & Molinaro, M. (2022). Decision trees with short explainable rules. In Neurips. Retrieved from http://papers.nips.cc/paper_files/paper/2022/ hash/500637d931d4feb99d5cce84af1f53ba-Abstract-Conference.html Steinbach, M. S., Karypis, G., & Kumar, V. (2000). A comparison of document clustering techniques.. Suárez, A., & Lutsko, J. F. (1999). Globally optimal fuzzy decision trees for classification and regression. IEEE Trans. Pattern Anal. Mach. Intell., 21(12), 1297 1311. Retrieved from https:// doi.org/10.1109/34.817409 doi: 10.1109/34.817409 Tanno, R., Arulkumaran, K., Alexander, D., Criminisi, A., & Nori, A. (2019). Adaptive neural trees. In K. Chaudhuri & R. Salakhutdinov (Eds.), Proceedings of the 36th international conference on machine learning (Vol. 97, pp. 6166 6175). PMLR. Retrieved from https://proceedings .mlr.press/v97/tanno19a.html Vahdat, A., & Kautz, J. (2020a). Nvae: A deep hierarchical variational autoencoder. In Proceedings of the 34th international conference on neural information processing systems. 
Red Hook, NY, USA: Curran Associates Inc. Vahdat, A., & Kautz, J. (2020b). Nvae: A deep hierarchical variational autoencoder. In H. Larochelle, M. Ranzato, R. Hadsell, M. F. Balcan, & H. Lin (Eds.), Advances in neural information processing systems (Vol. 33, pp. 19667 19679). Curran Associates, Inc. Retrieved from https://proceedings.neurips.cc/paper/2020/file/ e3b21256183cf7c2c7a66be163579d37-Paper.pdf van den Oord, A., Li, Y., & Vinyals, O. (2018). Representation learning with contrastive predictive coding. Co RR, abs/1807.03748. Retrieved from http://arxiv.org/abs/1807.03748 Vikram, S., Hoffman, M. D., & Johnson, M. J. (2018). The loracs prior for vaes: Letting the trees speak for the data. Ar Xiv, abs/1810.06891. Wan, A., Dunlap, L., Ho, D., Yin, J., Lee, S., Petryk, S., ... Gonzalez, J. E. (2021). NBDT: neuralbacked decision tree. In 9th international conference on learning representations, ICLR 2021, virtual event, austria, may 3-7, 2021. Open Review.net. Retrieved from https://openreview .net/forum?id=m CLVe Eppl NE Ward, J. H. (1963). Hierarchical grouping to optimize an objective function. Journal of the American Statistical Association, 58, 236-244. Webb, S., Goli nski, A., Zinkov, R., Narayanaswamy, S., Rainforth, T., Teh, Y. W., & Wood, F. (2017). Faithful inversion of generative models for effective amortized inference. In Neural information processing systems. Williams, C. (1999). A mcmc approach to hierarchical mixture modelling. Advances in Neural Information Processing Systems, 12. Wu, Z., Xiong, Y., Yu, S. X., & Lin, D. (2018). Unsupervised feature learning via non-parametric instance discrimination. In 2018 IEEE conference on computer vision and pattern recognition, CVPR 2018, salt lake city, ut, usa, june 18-22, 2018 (pp. 3733 3742). Computer Vision Foundation / IEEE Computer Society. Retrieved from http://openaccess.thecvf.com/content _cvpr_2018/html/Wu_Unsupervised_Feature_Learning_CVPR_2018_paper.html doi: 10.1109/CVPR.2018.00393 Xiao, H., Rasul, K., & Vollgraf, R. (2017). Fashion-mnist: a novel image dataset for benchmarking machine learning algorithms. Co RR, abs/1708.07747. Retrieved from http://arxiv.org/abs/ 1708.07747 Xiao, T. Z., & Bamler, R. (2023). Trading information between latents in hierarchical variational autoencoders. In The eleventh international conference on learning representations. Retrieved from https://openreview.net/forum?id=e Wt Mdr6y Cm L You, C., Robinson, D. P., & Vidal, R. (2015). Scalable sparse subspace clustering by orthogonal matching pursuit. 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 3918-3927. Zharmagambetov, A., & Carreira-Perpiñán, M. Á. (2022). Semi-supervised learning with decision trees: Graph laplacian tree alternating optimization. In Neurips. Retrieved from http://papers.nips.cc/paper_files/paper/2022/hash/ 104f7b25495a0e40e65fb7c7eee37ed9-Abstract-Conference.html Zhou, Z.-H., & Feng, J. (2017). Deep forest: Towards an alternative to deep neural networks. In Proceedings of the twenty-sixth international joint conference on artificial intelligence, IJCAI17 (pp. 3553 3559). Retrieved from https://doi.org/10.24963/ijcai.2017/497 doi: 10.24963/ijcai.2017/497 A Evidence Lower Bound In this section, we provide a closer look at the loss function of Tree VAE. We focus on the derivations of the Kullback-Leibler divergence term of the Evidence Lower Bound and provide an interpretable factorization. 
Furthermore, we address the computational complexity, thus offering an in-depth understanding of the loss function, its practical implications, and the trade-offs involved in its computation. A.1 ELBO Derivations In this section, we derive the KL loss (17) of the ELBO (12), which is the Kullback Leibler divergence (KL) between the prior and the variational posterior of Tree VAE. Additionally, we give details about the underlying distributional assumptions for computing the reconstruction loss. Let us define Pl the decision path from root 0 to leaf l, L is the number of leaves, which is equal to the number of paths in T , z Pl = {zi | i Pl} the set of latent variables selected by the path Pl, the parent node of the node i as pa(i), p(cpa(i) i | zpa(i)) the probability of going from pa(i) to i. For example, if we consider the path in Fig. 2 (right) we will observe c0 = 0, c1 = 1, and c4 = 0, where ci = 0 means the model selects the left child of node i. The KL loss can be expanded using Eq. 1/3: KL (q (z Pl, Pl | x) p (z Pl, Pl)) (21) = KL q(z0 | x) Y i Pl\{0} q(cpa(i) i | x)q(zi | zpa(i)) i Pl\{0} p(cpa(i) i | zpa(i))p(zi | zpa(i)) (22) z Pl q(z0 | x) Y i Pl\{0} q(cpa(i) i | x)q(zi | zpa(i)) q(z0 | x) Q j Pl\{0} q(cpa(j) j | x)q(zj | zpa(j)) k Pl\{0} p(cpa(k) k | zpa(k))p(zk | zpa(k)) z Pl q(z0 | x) Y i Pl\{0} q(cpa(i) i | x)q(zi | zpa(i)) log q(z0 | x) z Pl q(z0 | x) Y i Pl\{0} q(cpa(i) i | x)q(zi | zpa(i)) log j Pl\{0}q(cpa(j) j | x) Q k Pl\{0} p(cpa(k) k | zpa(k)) z Pl q(z0 | x) Y i Pl\{0} q(cpa(i) i | x)q(zi | zpa(i)) log j Pl\{0} q(zj | zpa(j)) Q k Pl\{0} p(zk | zpa(k)) In the following, we will simplify each of the three terms 24, 25, and 26 separately. A.1.1 KL Root The term (24) corresponds to the KL of the root node. We can integrate out all the latent variables zi for i = 0 and all decisions ci. The first term can be then written as follows: z0 q(z0 | x)q(c0 i | z0) log q(z0 | x) z0 q(z0 | x) i {1,2} q(c0 i | z0) log q(z0 | x) z0 q(z0 | x) [q(c0 = 0 | z0) + q(c0 = 1 | z0)] log q(z0 | x) z0 q(z0 | x) log q(z0 | x) = KL (q(z0 | x) p(z0)) , (30) where q(c0 = 0 | z0) + q(c0 = 1 | z0) = 1 and KL (q(z0 | x) p(z0)) is the KL between two Gaussians, which can be computed analytically. A.1.2 KL Decisions The second term (25) corresponds to the KL of the decisions. We can pull out the product from the log, yielding KLdecisions = X z Pl q(z0 | x) Y i Pl\{0} q(cpa(i) i | x)q(zi | zpa(i)) q(cpa(j) j | x) p(cpa(j) j | zpa(j)) j Pl\{0} q(z0 | x) Y i Pl\{0} q(cpa(i) i | x)q(zi | zpa(i)) log q(cpa(j) j | x) p(cpa(j) j | zpa(j)) Let us define as Pl j all paths that go through node j, as P j (denoted as Pj in the main text for brevity) the unique path that ends in the node j, and as P>j all the possible paths that start from the node j and continue to a leaf l L. Similarly, let us define as z j all the latent embeddings that are contained in the path from the root to node j and as z>j all the latent embeddings of the nodes i > j that can be reached from node j. To factorize the above equation, we first change from a pathwise view to a nodewise view. 
Instead of summing over all possible leaves in the tree (P l L) and then over each contained node (P j Pl\{0}), we sum over all nodes (P j V\{0}) and then over each path that leads through the selected node (P j Pl\{0} q(z0 | x) Y i Pl\{0} q(cpa(i) i | x)q(zi | zpa(i)) log q(cpa(j) j | x) p(cpa(j) j | zpa(j)) z Pl q(z0 | x) Y i Pl\{0} q(cpa(i) i | x)q(zi | zpa(i)) log q(cpa(j) j | x) p(cpa(j) j | zpa(j)) The above can be proved with the following Lemma, where we rewrite P l L 1[j Pl]. Lemma A.1. Given a binary tree T as defined in Section 2.1, composed of a set of nodes V = {0, . . . , V } and leaves L V, where Pl is the decision path from root 0 to leaf l, and z Pl = {zi | i Pl} the set of latent variables selected by the path Pl. Then it holds j Pl\{0} f(j, l, z Pl) = X z Pl 1[j Pl]f(j, l, z Pl), (34) Proof. The proof is as follows: z Pl 1[j Pl]f(j, l, z Pl) = X z Pl f(j, l, z Pl) X i Pl\{0} 1[i = j] (35) i Pl\{0} f(j, l, z Pl)1[i = j] (36) i Pl\{0} f(i, l, z Pl)1[i = j] (37) i Pl\{0} f(i, l, z Pl) X j V\{0} 1[i = j] (38) i Pl\{0} f(i, l, z Pl) (39) j Pl\{0} f(j, l, z Pl). (40) Having proven the equality, we can continue with the KL of the decisions as follows: KLdecisions = X z Pl q(z0 | x) Y i Pl\{0} q(cpa(i) i | x)q(zi | zpa(i)) log q(cpa(j) j | x) p(cpa(j) j | zpa(j)) z Pl [q(z0 | x) Y i P j\{0} q(cpa(i) i | x)q(zi | zpa(i)) log q(cpa(j) j | x) p(cpa(j) j | zpa(j)) k P>j q(cpa(k) k | x)q(zk | zpa(k))] (42) z j,z>j [q(z0 | x) Y i P j\{0} q(cpa(i) i | x)q(zi | zpa(i)) log q(cpa(j) j | x) p(cpa(j) j | zpa(j)) k P>j q(cpa(k) k | x)q(zk | zpa(k))] (43) z j q(z0 | x) Y i P j\{0} q(cpa(i) i | x)q(zi | zpa(i)) log q(cpa(j) j | x) p(cpa(j) j | zpa(j)) k P>j q(cpa(k) k | x)q(zk | zpa(k)) i (44) From Eq. 41 to Eq. 42, we split the inner product into the nodes of the paths Pl j that are before and after the node j. From Eq. 42 to Eq. 43, we observe that the sum over all paths going through j can be reduced to the sum over all paths starting from j, because there is only one path to j, which is specified in the product that comes after. From Eq. 43 to Eq. 44, we observe that the sum over paths starting from j and integral over z>j concern only the terms of the second line. Observe that the term on the second line of Eq. 44 integrates out to 1 and we get KLdecisions = X z j q(z0 | x) Y i P j\{0} q(cpa(i) i | x)q(zi | zpa(i)) log q(cpa(j) j | x) p(cpa(j) j | zpa(j)) zj q(z0 | x) Y i P j\{0} q(cpa(i) i | x)q(zi | zpa(i)) log q(cpa(j) j | x) p(cpa(j) j | zpa(j)) zj q(z0 | x) Y i Pl\{0} q(cpa(i) i | x)q(zi | zpa(i)) log q(zj | zpa(j)) p(zj | zpa(j)) z j q(z0 | x) Y i P j\{0} q(cpa(i) i | x)q(zi | zpa(i)) log q(zj | zpa(j)) p(zj | zpa(j)) k P>j q(cpa(k) k | x)q(zk | zpa(k)) zj q(z0 | x) Y i P j\{0} q(cpa(i) i | x)q(zi | zpa(i)) log q(zj | zpa(j)) p(zj | zpa(j)) zj P(pa(j); z, c)q(cpa(j) j | x)q(zj | zpa(j)) log q(zj | zpa(j)) p(zj | zpa(j)) z