# Differentiable Random Partition Models

Thomas M. Sutter, Alain Ryser, Joram Liebeskind, Julia E. Vogt
Department of Computer Science, ETH Zurich
Equal contribution. Correspondence to {thomas.sutter,alain.ryser}@inf.ethz.ch
37th Conference on Neural Information Processing Systems (NeurIPS 2023).

Abstract

Partitioning a set of elements into an unknown number of mutually exclusive subsets is essential in many machine learning problems. However, assigning elements, such as samples in a dataset or neurons in a network layer, to an unknown and discrete number of subsets is inherently non-differentiable, prohibiting end-to-end gradient-based optimization of parameters. We overcome this limitation by proposing a novel two-step method for inferring partitions, which allows its usage in variational inference tasks. This new approach enables reparameterized gradients with respect to the parameters of the new random partition model. Our method works by first inferring the number of elements per subset and, second, by filling these subsets in a learned order. We highlight the versatility of our general-purpose approach on three different challenging experiments: variational clustering, inference of shared and independent generative factors under weak supervision, and multitask learning.

1 Introduction

Partitioning a set of elements into subsets is a classical mathematical problem that attracted much interest over the last few decades (Rota, 1964; Graham et al., 1989). A partition over a given set is a collection of non-overlapping subsets such that their union results in the original set. In machine learning (ML), partitioning a set of elements into different subsets is essential for many applications, such as clustering (Bishop and Svensen, 2004) or classification (De la Cruz-Mesía et al., 2007). Random partition models (RPM, Hartigan, 1990) define a probability distribution over the space of partitions. RPMs can explicitly leverage the relationship between elements of a set, as they do not necessarily assume i.i.d. set elements. On the other hand, most existing RPMs are intractable for large datasets (MacQueen, 1967; Plackett, 1975; Pitman, 1996) and lack a reparameterization scheme, prohibiting their direct use in gradient-based optimization frameworks.

In this work, we propose the differentiable random partition model (DRPM), a fully differentiable relaxation for RPMs that allows reparametrizable sampling. The DRPM follows a two-stage procedure: first, we model the number of elements per subset, and second, we learn an ordering of the elements, with which we fill the subsets. The DRPM enables the integration of partition models into state-of-the-art ML frameworks and learning RPMs from data using stochastic optimization.

We evaluate our approach in three experiments, demonstrating the proposed DRPM's versatility and advantages. First, we apply the DRPM to a variational clustering task, highlighting how the reparametrizable sampling of partitions allows us to learn a novel kind of Variational Autoencoder (VAE, Kingma and Welling, 2014). By leveraging potential dependencies between samples in a dataset, DRPM-based clustering overcomes the simplified i.i.d. assumption of previous works, which used categorical priors (Jiang et al., 2016). In our second experiment, we demonstrate how to retrieve sets of shared and independent generative factors of paired images using the proposed DRPM.
Figure 1: Illustration of the proposed DRPM method. We first sample a permutation matrix π and a set of subset sizes n separately in two stages. We then use n and π to generate the assignment matrix Y, the matrix representation of a partition ρ.

In contrast to previous works (Bouchacourt et al., 2018; Hosoya, 2018; Locatello et al., 2020), which rely on strong assumptions or heuristics, the DRPM enables end-to-end inference of generative factors. Finally, we perform multitask learning (MTL) by using the DRPM as a building block in a deterministic pipeline. We show how the DRPM learns to assign subsets of network neurons to specific tasks. The DRPM can infer the subset size per task based on its difficulty, overcoming the tedious work of finding optimal loss weights (Kurin et al., 2022; Xin et al., 2022).

To summarize, we introduce the DRPM, a novel differentiable and reparametrizable relaxation of RPMs. In extensive experiments, we demonstrate the versatility of the proposed method by applying the DRPM to clustering, inference of generative factors, and multitask learning.

2 Related Work

Random Partition Models  Previous works on RPMs include product partition models (Hartigan, 1990), species sampling models (Pitman, 1996), and model-based clustering approaches (Bishop and Svensen, 2004). Further, Lee and Sang (2022) investigate the balancedness of subset sizes of RPMs. All of these approaches require tedious manual adjustment, are non-differentiable, and are therefore unsuitable for modern ML pipelines. A fundamental RPM application is clustering, where the goal is to partition a given dataset into different subsets, the clusters. In contrast to many existing approaches (Yang et al., 2019; Sarfraz et al., 2019; Cai et al., 2022), we consider cluster assignments as random variables, allowing us to treat clustering from a variational perspective. Previous works in variational clustering (Jiang et al., 2016; Dilokthanakul et al., 2016; Manduchi et al., 2021) implicitly define RPMs to perform clustering. They compute partitions in a variational fashion by making i.i.d. assumptions about the samples in the dataset and imposing soft assignments of the clusters to data points during training. A problem related to set partitioning is the earth mover's distance problem (EMD, Monge, 1781; Rubner et al., 2000). However, the EMD aims to assign a set's elements to different subsets based on a cost function and given subset sizes. Iterative solutions to the problem exist (Sinkhorn, 1964), and various methods have recently been proposed, e.g., for document ranking (Adams and Zemel, 2011) or permutation learning (Santa Cruz et al., 2017; Mena et al., 2018; Cuturi et al., 2019).

Differentiable and Reparameterizable Discrete Distributions  Following the proposition of the Gumbel-Softmax trick (GST, Jang et al., 2016; Maddison et al., 2017), interest in continuous relaxations of discrete distributions and non-differentiable algorithms rose. The GST enabled the reparameterization of categorical distributions and their integration into gradient-based optimization pipelines. Based on the same trick, Sutter et al. (2023) propose a differentiable formulation for the multivariate hypergeometric distribution.
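Since the GST is the building block behind the reparameterizations discussed here, the following is a minimal, self-contained sketch of the trick using PyTorch's built-in `gumbel_softmax`; the logits, temperature, and toy loss are illustrative choices and not taken from any of the cited works.

```python
import torch
import torch.nn.functional as F

# Unnormalized log-probabilities (logits) of a 5-way categorical distribution.
logits = torch.tensor([1.0, 0.5, -0.2, 2.0, 0.1], requires_grad=True)

# Gumbel-Softmax / Concrete relaxation: a differentiable, approximately one-hot
# sample. Lower temperatures tau yield samples closer to discrete one-hot vectors.
soft_sample = F.gumbel_softmax(logits, tau=0.5, hard=False)

# Straight-through variant: discrete one-hot in the forward pass,
# gradients flow through the soft relaxation in the backward pass.
hard_sample = F.gumbel_softmax(logits, tau=0.5, hard=True)

# Because sampling is reparameterized, gradients w.r.t. the logits exist.
loss = (soft_sample * torch.arange(5.0)).sum()
loss.backward()
print(logits.grad)
```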
Multiple works on differentiable sorting procedures and permutation matrices have been proposed, e.g., Linderman et al. (2018); Prillo and Eisenschlos (2020); Petersen et al. (2021). Further, Grover et al. (2019) described the distribution over permutation matrices p(π) for a permutation matrix π using the Plackett-Luce distribution (PL, Luce, 1959; Plackett, 1975). Prillo and Eisenschlos (2020) proposed a computationally simpler variant of Grover et al. (2019). More examples of differentiable relaxations include the top-k element selection procedure (Xie and Ermon, 2019), black-box combinatorial solvers (Pogančić et al., 2019), implicit likelihood estimation (Niepert et al., 2021), and k-subset sampling (Ahmed et al., 2022).

3 Preliminaries

Set Partitions  A partition $\rho = (S_1, \dots, S_K)$ of a set $[n] = \{1, \dots, n\}$ with $n$ elements is a collection of $K$ subsets $S_k \subseteq [n]$, where $K$ is a priori unknown (Mansour and Schork, 2016). For a partition $\rho$ to be valid, it must hold that
$$S_1 \cup \dots \cup S_K = [n] \quad \text{and} \quad \forall k \neq l: S_k \cap S_l = \emptyset. \tag{1}$$
In other words, every element $i \in [n]$ has to be assigned to precisely one subset $S_k$. We denote the size of the $k$-th subset $S_k$ as $n_k = |S_k|$. Alternatively, we can describe a partition $\rho$ through an assignment matrix $Y = [y_1, \dots, y_K]^T \in \{0,1\}^{K \times n}$. Every row $y_k \in \{0,1\}^{1 \times n}$ is a multi-hot vector, where $y_{ki} = 1$ assigns element $i$ to subset $S_k$.

Within the scope of our work, we view a partition of a set of $n$ elements as a special case of the urn model. Here, the urn contains marbles of $n$ different colors, where each color corresponds to a subset in the partition. For each color, there are $n$ marbles corresponding to the potential elements of that color/subset. To derive a partition, we sample $n$ marbles without replacement from the urn and register the order in which we draw the colors. The color of the $i$-th marble then determines the subset to which element $i$ corresponds. Furthermore, we can constrain the partition to only $K$ subsets by taking an urn with only $K$ different colors.

Probability distribution over subset sizes  The multivariate non-central hypergeometric distribution (MVHG) describes sampling without replacement and allows skewing the importance of groups with an additional importance parameter $\omega$ (Fisher, 1935; Wallenius, 1963; Chesson, 1976). The MVHG is an urn model and is described by the number of different groups $K \in \mathbb{N}$, the number of elements in the urn of every group $m = [m_1, \dots, m_K] \in \mathbb{N}^K$, the total number of elements in the urn $\sum_{k=1}^{K} m_k \in \mathbb{N}$, the number of samples to draw from the urn $n \in \mathbb{N}_0$, and the importance factor for every group $\omega = [\omega_1, \dots, \omega_K] \in \mathbb{R}^K_{\geq 0}$ (Johnson, 1987). Then, the probability of sampling $n = [n_1, \dots, n_K]$, where $n_k$ describes the number of elements drawn from group $k$, is
$$p(n; \omega, m) = \frac{1}{P_0} \prod_{k=1}^{K} \binom{m_k}{n_k} \omega_k^{n_k}, \tag{2}$$
where $P_0$ is a normalization constant. Hence, the MVHG $p(n; \omega, m)$ allows us to model dependencies between different elements of a set since drawing one element from the urn influences the probability of drawing one of the remaining elements, creating interdependence between them. For the rest of the paper, we assume $\forall m_k \in m: m_k = n$. We thus use the shorthand $p(n; \omega)$ to denote the density of the MVHG. We refer to Appendix A.1 for more details.
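To make the urn intuition behind the MVHG concrete, here is a small, non-differentiable sketch of sequential sampling without replacement with importance weights ω (a Wallenius-style sampler). The function name and example values are illustrative; this is not the reparameterized MVHG relaxation of Sutter et al. (2023) that the DRPM builds on.

```python
import numpy as np

def sample_subset_sizes(m, omega, n_draws, rng=None):
    """Draw n_draws marbles without replacement from an urn with K colors.

    m[k]     -- number of marbles of color k initially in the urn
    omega[k] -- importance weight of color k (larger => drawn more eagerly)
    Returns counts, where counts[k] is how many marbles of color k were drawn.
    """
    rng = np.random.default_rng() if rng is None else rng
    remaining = np.asarray(m, dtype=float).copy()
    counts = np.zeros(len(m), dtype=int)
    for _ in range(n_draws):
        # Each remaining marble of color k is drawn with probability
        # proportional to omega[k]; aggregate the weights per color.
        weights = remaining * np.asarray(omega, dtype=float)
        probs = weights / weights.sum()
        k = rng.choice(len(m), p=probs)
        counts[k] += 1
        remaining[k] -= 1
    return counts

# Example: K = 3 groups with m_k = n marbles each, skewed by omega.
n = 6
print(sample_subset_sizes(m=[n, n, n], omega=[1.0, 2.0, 0.5], n_draws=n))
```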
Probability distribution over Permutation Matrices  Let $p(\pi)$ denote a distribution over permutation matrices $\pi \in \{0,1\}^{n \times n}$. A permutation matrix $\pi$ is doubly stochastic (Marcus, 1960), meaning that its row and column vectors sum to 1. This property allows us to use $\pi$ to describe an order over a set of $n$ elements, where $\pi_{ij} = 1$ means that element $j$ is ranked at position $i$ in the imposed order. In this work, we assume $p(\pi)$ to be parameterized by scores $s \in \mathbb{R}^n_+$, where each score $s_i$ corresponds to an element $i$. The order given by sorting $s$ in decreasing order corresponds to the most likely permutation in $p(\pi; s)$. Sampling from $p(\pi; s)$ can be achieved by resampling the scores as $\tilde{s}_i = \beta \log s_i + g_i$, where $g_i \sim \mathrm{Gumbel}(0, \beta)$ for a fixed scale $\beta$, and sorting them in decreasing order. Hence, resampling scores $s$ enables the resampling of permutation matrices $\pi$. The probability over orderings $p(\pi; s)$ is then given by (Thurstone, 1927; Luce, 1959; Plackett, 1975; Yellott, 1977)
$$p(\pi; s) = p\big((\pi s)_1 \geq \dots \geq (\pi s)_n\big) = \frac{(\pi s)_1}{Z} \cdot \frac{(\pi s)_2}{Z - (\pi s)_1} \cdots \frac{(\pi s)_n}{Z - \sum_{j=1}^{n-1} (\pi s)_j}, \tag{3}$$
where $\pi$ is a permutation matrix and $Z = \sum_{i=1}^{n} s_i$. The resulting distribution is a Plackett-Luce (PL) distribution (Luce, 1959; Plackett, 1975) if and only if the scores $s$ are perturbed with noise drawn from Gumbel distributions with identical scales (Yellott, 1977). For more details, we refer to Appendix A.2.

4 A Two-stage Approach to Random Partition Models

We propose the DRPM $p(Y; \omega, s)$, a differentiable and reparameterizable two-stage random partition model (RPM). The proposed formulation separately infers the number of elements per subset, $n \in \mathbb{N}_0^K$ with $\sum_{k=1}^{K} n_k = n$, and the assignment of elements to the subsets $S_k$ by inducing an order on the $n$ elements and filling $S_1, \dots, S_K$ sequentially in this order. To model the order of the elements, we use a permutation matrix $\pi = [\pi_1, \dots, \pi_n]^T \in \{0,1\}^{n \times n}$, from which we infer $Y$ by sequentially summing up rows according to $n$. Note that the doubly stochastic property of all permutation matrices $\pi$ ensures that the columns of $Y$ remain one-hot vectors, assigning every element $i$ to precisely one of the $K$ subsets. At the same time, the $k$-th row of $Y$ corresponds to an $n_k$-hot vector $y_k$ and therefore serves as a subset selection vector, i.e.,
$$y_k = \sum_{i=\nu_k + 1}^{\nu_k + n_k} \pi_i, \quad \text{where } \nu_k = \sum_{j=1}^{k-1} n_j, \tag{4}$$
such that $Y = [y_1, \dots, y_K]^T$. Additionally, Figure 1 provides an illustrative example. Note that $K$ defines the maximum number of possible subsets, and not the effective number of non-empty subsets, because we allow $S_k$ to be the empty set (Mansour and Schork, 2016).

We base the following Proposition 4.1 on the MVHG distribution $p(n; \omega)$ for the subset sizes $n$ and the PL distribution $p(\pi; s)$ for assigning the elements to subsets. However, the proposed two-stage approach to RPMs is not restricted to these two classes of probability distributions.

Proposition 4.1 (Two-stage Random Partition Model). Given a probability distribution over subset sizes $p(n; \omega)$ with $n \in \mathbb{N}_0^K$ and distribution parameters $\omega \in \mathbb{R}_+^K$, and a PL probability distribution over random orderings $p(\pi; s)$ with $\pi \in \{0,1\}^{n \times n}$ and distribution parameters $s \in \mathbb{R}_+^n$, the probability mass function $p(Y; \omega, s)$ of the two-stage RPM is given by
$$p(Y; \omega, s) = p(y_1, \dots, y_K; \omega, s) = p(n; \omega) \sum_{\pi \in \Pi_Y} p(\pi; s), \tag{5}$$
where $\Pi_Y = \{\pi : y_k = \sum_{i=\nu_k+1}^{\nu_k+n_k} \pi_i,\ \forall k = 1, \dots, K\}$, and $y_k$ and $\nu_k$ are as in Equation (4).

In the following, we outline the proof of Proposition 4.1 and refer to Appendix B for a formal derivation. We calculate $p(Y; \omega, s)$ as a probability of subsets $p(y_1, \dots, y_K; \omega, s)$, which we compute sequentially over subsets, i.e., $p(y_1, \dots, y_K; \omega, s) = p(y_1; \omega, s) \cdots p(y_K \mid y$
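As a complement to Equations (4) and (5), the following non-differentiable sketch mirrors the two-stage construction: draw an ordering from p(π; s) by Gumbel-perturbing the log-scores and sorting, then build the assignment matrix Y by summing consecutive rows of π according to the subset sizes n. All names are illustrative, and the hard argsort merely stands in for the differentiable relaxations the DRPM actually uses.

```python
import numpy as np

def sample_partition(n_sizes, scores, beta=1.0, rng=None):
    """Hard (non-relaxed) two-stage partition sample.

    n_sizes -- subset sizes n = (n_1, ..., n_K), e.g. drawn from an MVHG
    scores  -- positive PL scores s, one per element
    Returns the assignment matrix Y in {0,1}^{K x n}.
    """
    rng = np.random.default_rng() if rng is None else rng
    s = np.asarray(scores, dtype=float)
    n = s.shape[0]

    # Perturb log-scores with Gumbel noise and sort in decreasing order,
    # which samples an ordering from the Plackett-Luce distribution p(pi; s).
    perturbed = beta * np.log(s) + rng.gumbel(loc=0.0, scale=beta, size=n)
    order = np.argsort(-perturbed)

    # Permutation matrix pi: pi[i, j] = 1 iff element j is ranked at position i.
    pi = np.zeros((n, n))
    pi[np.arange(n), order] = 1.0

    # Assignment matrix Y: subset k collects the next n_k rows of pi (Eq. 4).
    Y = np.zeros((len(n_sizes), n))
    nu = 0
    for k, n_k in enumerate(n_sizes):
        Y[k] = pi[nu:nu + n_k].sum(axis=0)
        nu += n_k
    return Y

# Example: n = 6 elements split into subsets of sizes (2, 3, 1).
Y = sample_partition(n_sizes=[2, 3, 1], scores=[1.0, 0.3, 2.0, 0.5, 0.8, 1.5])
print(Y)
assert (Y.sum(axis=0) == 1).all()  # every element lands in exactly one subset
```

Because the rows of π are distinct one-hot vectors and the sizes in n sum to the number of elements, every column of Y is one-hot, matching the validity constraints of Equation (1).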