# Differentiable Clustering with Perturbed Spanning Forests

Lawrence Stewart (ENS & INRIA, Paris, France), Francis Bach (ENS & INRIA, Paris, France), Felipe Llinares-López (Google DeepMind, Paris, France), Quentin Berthet (Google DeepMind, Paris, France)

37th Conference on Neural Information Processing Systems (NeurIPS 2023). All correspondence should be addressed to lawrence.stewart@ens.fr. Code base: https://github.com/LawrenceMMStewart/DiffClust_NeurIPS2023

We introduce a differentiable clustering method based on stochastic perturbations of minimum-weight spanning forests. This allows us to include clustering in end-to-end trainable pipelines, with efficient gradients. We show that our method performs well even in difficult settings, such as data sets with high noise and challenging geometries. We also formulate an ad hoc loss to efficiently learn from partial clustering data using this operation. We demonstrate its performance on several data sets for supervised and semi-supervised tasks.

1 Introduction

Clustering is one of the most classical tasks in data processing, and one of the fundamental methods in unsupervised learning (Hastie et al., 2009). In most formulations, the problem consists in partitioning a collection of n elements into k clusters, in a manner that optimizes some criterion, such as intra-cluster proximity, or some resemblance criterion, such as a pairwise similarity matrix. This procedure is naturally related to other tasks in machine learning, either by using the induced classes in supervised problems, or by evaluating or looking for well-clustered representations (Caron et al., 2018; Xie et al., 2016). Its performance and flexibility on a wide range of natural datasets, which make it a good downstream or preprocessing task, also make it a very important candidate for learning representations in a supervised fashion. Yet, it is a fundamentally combinatorial problem, representing a discrete decision, much like many classical algorithmic methods (e.g., sorting, taking nearest neighbours, dynamic programming).

For these reasons, it is extremely challenging to learn through clustering. As a function, the solution of a clustering problem is piecewise constant with respect to its inputs (such as a similarity or distance matrix), and its gradient is therefore zero almost everywhere. This operation is thus naturally ill-suited to gradient-based approaches for minimizing an objective, which are at the center of optimization procedures for machine learning. It does not have the convenient properties of commonly used operations in end-to-end differentiable systems, such as smoothness and differentiability.

Another challenge of using clustering as part of a learning pipeline is perhaps its ambiguity: even for a given notion of distance or similarity between elements, there are several valid definitions of clustering, with criteria adapted to different uses. Popular methods include k-means, whose formulation is NP-hard in general even in simple settings (Drineas et al., 2004), and which relies on heuristics that depend heavily on their initialization (Arthur and Vassilvitskii, 2007; Bubeck et al., 2012). Several of them rely on proximity to a centroid, or prototype, i.e., a vector representing each cluster. These fail on challenging geometries, e.g., interlaced clusters with no linear separation.

We propose a new method for differentiable clustering that efficiently addresses these difficulties. It is a principled, deterministic operation: it is based on minimum-weight spanning forests, a variant of minimum spanning trees.
We chose this primitive at the heart of our method because it can be represented as a linear program (LP), which is particularly well-adapted to our smoothing technique. However, we use a greedy algorithm to solve it exactly and efficiently (rather than solving it as an LP or relying on an uncertain heuristic). We observe that this method, often referred to as single linkage clustering (Gower and Ross, 1969), is effective as a clustering method in several challenging settings. Further, we are able to create a differentiable version of this operation, by introducing stochastic perturbations on the similarity matrix (the cost of the LP). This proxy has several convenient properties: it approximates the original function, it is smooth (both aspects being controlled by a temperature parameter), and both the perturbed optimizers and their derivatives can be efficiently estimated with Monte-Carlo estimators (see, e.g., Hazan et al., 2016; Berthet et al., 2020, and references therein). This allows us to include this operation in end-to-end differentiable machine learning pipelines, and we show that this method is both efficient and performs well at capturing the clustered aspect of natural data, in several tasks.

Our work is part of an effort to include unconventional operations in model training loops based on gradient computations. These include discrete operators such as optimal transport, dynamic time-warping and other dynamic programs (Cuturi, 2013; Cuturi and Blondel, 2017; Mensch and Blondel, 2018; Blondel et al., 2020b; Vlastelica et al., 2019; Paulus et al., 2020; Sander et al., 2023; Shvetsova et al., 2023), to ease their inclusion in end-to-end differentiable pipelines that can be trained with first-order methods in fields such as computer vision, audio processing, biology, and physical simulators (Cordonnier et al., 2021; Kumar et al., 2021; Carr et al., 2021; Le Lidec et al., 2021; Baid et al., 2023; Llinares-López et al., 2023), as well as other optimization algorithms (Dubois-Taine et al., 2022). Other related methods, including some based on convex relaxations and on optimal transport, aim to minimize an objective involving a neural network in order to perform clustering, with centroids (Caron et al., 2020; Genevay et al., 2019; Xie et al., 2016). Most of them involve optimization in order to cluster, without explicitly optimizing through the clustering operation, which allows learning with some supervision. Another line of research focuses on using the matrix-tree theorem to compute marginals of the Gibbs distribution over spanning trees (Koo et al., 2007; Zmigrod et al., 2021). It is related to our perturbation smoothing technique (which yields a different distribution), and can also be used to learn using tree supervision. The main difference is that it does not address the clustering problem, and that optimizing in this setting involves sophisticated linear algebra computations based on determinants, which is not as efficient and convenient to implement as our method based on sampling and greedy algorithms.
There is an important conceptual difference with the methodology described by Berthet et al. (2020), and other recent works based on it that use Fenchel-Young losses (Blondel et al., 2020a): while one of the core discrete operations on which our clustering procedure is based is an LP, the cluster connectivity matrix (which, importantly, has the same structure as any ground-truth label information) is not. This weak supervision is a natural obstacle that we have to overcome: on the one hand, it is reasonable to expect clustering data to come in this form, i.e., only stating whether some elements belong to the same cluster or not, which is weaker than ground-truth spanning forest information would be. On the other hand, linear problems on cluster connectivity matrices, such as MAXCUT, are notoriously NP-hard (Karp, 1972), so we cannot use perturbed linear oracles over these sets (at least exactly). In order to handle these two situations together, we design a partial Fenchel-Young loss, inspired by the literature on weak supervision, that allows us to use a loss designed for spanning forest structured prediction, even though we have the less informative cluster connectivity information.

In our experiments, we show that our method enables us to learn through a clustering operation, i.e., that we can find representations of the data for which clustering of the data matches some ground-truth clustering data. We apply this to both supervised settings, including some illustrations on challenging synthetic data, and semi-supervised settings on real data, focusing on settings with a low number of labeled instances and unlabeled classes.

Main Contributions. In this work, we introduce an efficient and principled technique for differentiable clustering. In summary, we make the following contributions:

- Our method is based on using spanning forests, and a differentiable proxy, obtained by adding stochastic perturbations on the edge costs.
- Our method allows us to learn through clustering: one can train a model to learn clusterable representations of the data in an online fashion. The model generating the representations is informed by gradients that are transmitted through our clustering operation.
- We derive a partial Fenchel-Young loss, which allows us to incorporate weak cluster information, and can be used in any weakly supervised structured prediction problem.
- We show that it is a powerful clustering technique, allowing us to learn meaningful clustered representations. It does not require a model to learn linearly separable representations.
- Our operation can be incorporated efficiently into any gradient-based ML pipeline. We demonstrate this in the context of supervised and semi-supervised cluster learning tasks.

Notations. For a positive integer n, we denote by $[n]$ the set $\{1, \ldots, n\}$. We consider all graphs to be undirected. For the complete graph $K_n$ with n vertices over nodes $[n]$, we denote by $\mathcal{T}$ the set of spanning trees on $K_n$, i.e., subgraphs with no cycles and one connected component. For any positive integer $k \le n$, we also denote by $\mathcal{C}_k$ the set of k-spanning forests of $K_n$, i.e., subgraphs with no cycles and k connected components. With a slight abuse of notation, we also refer to $\mathcal{T}$ and $\mathcal{C}_k$ as the sets of adjacency matrices of these graphs. For two nodes i and j in a general graph, we write $i \sim j$ if the two nodes are in the same connected component. We denote by $\mathcal{S}_n$ the set of $n \times n$ symmetric matrices.

2 Differentiable clustering

We provide a differentiable operator for clustering elements based on a similarity matrix.
Our method is explicit: it relies on a linear programming primitive, not on a heuristic to solve another problem (e.g., k-means). It is label-blind: our solution is represented as a connectivity matrix, and is not affected by a permutation of the clusters. Finally, it is geometrically flexible: it is based on single linkage, and does not rely on proximity to a centroid, or on linear separability, to include several elements in the same cluster (as illustrated in Section 4).

In order to cluster n elements in k clusters, we define a clustering operator as a function $M^*_k(\cdot)$ taking an $n \times n$ symmetric similarity matrix as input (e.g., negative pairwise squared distances) and outputting a cluster connectivity matrix of the same size (see Definition 1). We also introduce its differentiable proxy $M^*_{k,\varepsilon}(\cdot)$. We use them as a method to learn through clustering. As an example, in the supervised learning setting, we consider n elements described by features $X = (X_1, \ldots, X_n)$ in some feature space $\mathcal{X}$, and ground-truth clustering data given as an $n \times n$ matrix $M_\Omega$ (with either complete or partial information, see Definition 4). A parameterized model produces a similarity matrix $\Sigma = \Sigma_w(X)$, e.g., based on pairwise squared distances between embeddings $\Phi_w(X_i)$ for some model $\Phi_w$. Minimizing in $w$ a loss so that $M^*_k(\Sigma_w(X)) \approx M_\Omega$ allows us to train a model based on the clustering information (see Figure 1). We describe in this section and the following one the tools that we introduce to achieve this goal, and show in Section 4 experimental results on real data.

Figure 1: Our pipeline, in the semi-supervised learning setting: data points are embedded by a parameterized model, which produces a similarity matrix. Partial clustering information may be available, in the form of must-link or must-not-link constraints. Our clustering and differentiable clustering operators are used, respectively, for prediction and gradient computations.

2.1 Clustering with k-spanning forests

In this work, the core operation is to cluster n elements, using a similarity matrix $\Sigma \in \mathcal{S}_n$. Informally, pairs of elements $(i, j)$ with high similarity $\Sigma_{ij}$ should be more likely to be in the same cluster than those with low similarity. The clustering is represented in the following manner.

Definition 1 (Cluster connectivity matrix). Let $\pi : [n] \to [k]$ be a clustering function, assigning elements to one of k clusters. We represent it with an $n \times n$ binary matrix $M$ (the cluster connectivity matrix): $M_{ij} = 1$ if and only if $\pi(i) = \pi(j)$. We denote by $\mathcal{B}_k \subset \{0, 1\}^{n \times n}$ the set of binary cluster connectivity matrices with k clusters.

Using this definition, we define an operation $M^*_k(\cdot)$ mapping a similarity matrix $\Sigma$ to a clustering, in the form of such a membership matrix. Up to a permutation (i.e., the naming of the clusters), $M$ allows us to recover $\pi$ entirely. It is based on a maximum spanning forest primitive $A^*_k(\cdot)$, and both are defined below. We recall that a k-forest on a graph is a subgraph with no cycles consisting of k connected components, potentially single nodes (see Figure 2). The cluster connectivity matrix is sometimes referred to as the cluster coincidence matrix in other literature, and the two terms can be used interchangeably.

Definition 2 (Maximum k-spanning forest). Let $\Sigma$ be an $n \times n$ similarity matrix. We denote by $A^*_k(\Sigma)$ the adjacency matrix of the k-spanning forest with maximum similarity, defined as
$$A^*_k(\Sigma) = \operatorname{argmax}_{A \in \mathcal{C}_k} \langle A, \Sigma \rangle\,.$$
This defines a mapping $A^*_k : \mathcal{S}_n \to \mathcal{C}_k$. We denote by $F_k$ the value of this maximum, i.e., $F_k(\Sigma) = \max_{A \in \mathcal{C}_k} \langle A, \Sigma \rangle$.
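To make Definition 1 concrete before moving on, here is a minimal sketch in plain Python/NumPy (the function name is ours, not from the released code base) that builds the cluster connectivity matrix from a cluster assignment:

```python
import numpy as np

def connectivity_matrix(pi):
    """Cluster connectivity matrix of Definition 1.

    pi: array of shape (n,), pi[i] is the cluster index of element i.
    Returns an (n, n) binary matrix M with M[i, j] = 1 iff pi[i] == pi[j].
    """
    pi = np.asarray(pi)
    return (pi[:, None] == pi[None, :]).astype(int)

# Example: 4 elements, 2 clusters; the result has block structure up to permutation.
print(connectivity_matrix([0, 1, 0, 1]))
```

Note that the matrix is invariant under relabeling of the clusters, which is the "label-blind" property mentioned above.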
Definition 3 (Spanning forest clustering). Let $A$ be the adjacency matrix of a k-spanning forest. We denote by $M(A)$ the connectivity matrix of the k connected components of $A$, i.e., $M_{ij} = 1$ if and only if $i \sim j$. Given an $n \times n$ similarity matrix $\Sigma$, we denote by $M^*_k(\Sigma) \in \mathcal{B}_k$ the clustering induced by the maximum k-spanning forest, defined by $M^*_k(\Sigma) = M(A^*_k(\Sigma))$. This defines a mapping $M^*_k : \mathcal{S}_n \to \mathcal{B}_k$, our clustering operator.

Remarks. The solution of the linear program is unique almost everywhere for similarity matrices (relevant for us, as we use stochastic perturbations in learning). Both these operators can be computed using Kruskal's algorithm for finding a minimum spanning tree in a graph (Kruskal, 1956). This algorithm builds the tree by iteratively adding edges in a greedy fashion, maintaining non-cyclicity. On a connected graph on n nodes, after $n - 1$ edge additions, a spanning tree (i.e., one that covers all nodes) is obtained. The greedy algorithm is optimal for this problem, which can be proved by showing that forests can be seen as the independent sets of a matroid (Cormen et al., 2022). Further, stopping the algorithm after $n - k$ edge additions yields a forest consisting of k trees (possibly singletons), together spanning the n nodes, which is optimal for the problem in Definition 2. As in several other clustering methods, we specify in ours the number of desired clusters k, but a consequence of our algorithmic choice is that one can compute the clustering operator for all k without much computational overhead, and that this number can easily be tuned by validation. Further, a common manner of running this algorithm is to keep track of the constructed connected components, so both $A^*_k(\Sigma)$ and $M^*_k(\Sigma)$ are actually obtained by this algorithm. The mapping $M$ is of course many-to-one, and yields a partition of $\mathcal{C}_k$ into equivalence classes of k-forests that yield the same clusters. This point is actually at the center of our operators for clustering with constraints in Definition 5 and of the loss we design in Section 3.2, both below. We note that, as for the maximizer of a linear program, when the solution is unique we have $\nabla_\Sigma F_k(\Sigma) = A^*_k(\Sigma)$.
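As a plain reference for this operator, here is a sketch of Kruskal's algorithm stopped at k components, using a union-find structure. This is for exposition only: the implementation used in the paper is the matrix-based, autodiff-compatible variant described in Section 9, and the function names here are ours.

```python
import numpy as np

def max_spanning_forest(sigma, k):
    """A*_k(Sigma): adjacency matrix of the maximum k-spanning forest (Definition 2).

    sigma: (n, n) symmetric similarity matrix. Returns (A, labels), where A is the
    forest adjacency matrix and labels[i] identifies the connected component of node i.
    """
    n = sigma.shape[0]
    parent = list(range(n))

    def find(i):
        while parent[i] != i:
            parent[i] = parent[parent[i]]   # path halving
            i = parent[i]
        return i

    # Greedy single linkage: process candidate edges by decreasing similarity.
    edges = [(i, j) for i in range(n) for j in range(i + 1, n)]
    edges.sort(key=lambda e: sigma[e], reverse=True)

    A = np.eye(n, dtype=int)   # diagonal kept at 1, following the appendix convention
    components = n
    for i, j in edges:
        if components == k:
            break
        ri, rj = find(i), find(j)
        if ri != rj:           # adding (i, j) does not create a cycle
            parent[ri] = rj
            A[i, j] = A[j, i] = 1
            components -= 1
    labels = np.array([find(i) for i in range(n)])
    return A, labels

def spanning_forest_clustering(sigma, k):
    """M*_k(Sigma): connectivity matrix of the k components (Definition 3)."""
    _, labels = max_spanning_forest(sigma, k)
    return (labels[:, None] == labels[None, :]).astype(int)
```

With a similarity matrix given by negative pairwise squared distances, `spanning_forest_clustering` recovers single-linkage clusters.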
As described above, our main goal is to enable the inclusion of these operations in end-to-end differentiable pipelines (see Figure 1). We add two important ingredients to our framework in order to do so, and to efficiently learn through clustering. The first one is the inclusion of constraints, i.e., enforced values coming either from partial or complete clustering information (see Definition 5 below), and the second one is the creation of a differentiable version of these operators, using stochastic perturbations (see Section 2.2).

Figure 2: Method illustration, for k = 2: (left) similarity matrix based on pairwise squared distances, with partial cluster connectivity information; (center) clustering using spanning forests without partial clustering constraints; (right) constrained clustering with partial constraints.

Clustering with constraints. Given partial information, we also consider constrained versions of these two problems. We represent the enforced connectivity information as a matrix $M_\Omega$, defined as follows.

Definition 4 (Partial cluster coincidence matrix). Let $M$ be a cluster connectivity matrix, and $\Omega$ a subset of $[n] \times [n]$ representing the set of observations. We denote by $M_\Omega$ the $n \times n$ matrix with
$$M_{\Omega,ij} = M_{ij} \ \text{ if } (i, j) \in \Omega\,, \qquad M_{\Omega,ij} = * \ \text{ otherwise}\,.$$

Remarks. The symbol $*$ in this definition is only used as a placeholder, an indicator of "no information", and does not have a mathematical use. A common setting is the one where only a subset $S \subseteq [n]$ of the data has label information. In this case, for any $i, j \in S$, $M_{\Omega,ij} = 1$ if i and j share the same label, and $M_{\Omega,ij} = 0$ otherwise, i.e., elements in the same class are clustered together and separated from elements in other classes.

Definition 5 (Constrained maximum k-spanning forest). Let $\Sigma$ be an $n \times n$ similarity matrix. We denote by $A^*_k(\Sigma; M_\Omega)$ the adjacency matrix of the k-spanning forest with maximum similarity, constrained to satisfy the connectivity constraints in $M_\Omega$. It is defined as
$$A^*_k(\Sigma; M_\Omega) = \operatorname{argmax}_{A \in \mathcal{C}_k(M_\Omega)} \langle A, \Sigma \rangle\,,$$
where, for any partial clustering matrix $M_\Omega$, $\mathcal{C}_k(M_\Omega)$ is the set of k-spanning forests whose clusters agree with $M_\Omega$, i.e.,
$$\mathcal{C}_k(M_\Omega) = \{A \in \mathcal{C}_k : M(A)_{ij} = M_{\Omega,ij} \ \ \forall (i, j) \in \Omega\}\,.$$
For any partial connectivity matrix $M_\Omega$, this defines another mapping $A^*_k(\cdot\,; M_\Omega) : \mathcal{S}_n \to \mathcal{C}_k$. We denote by $F_k(\Sigma; M_\Omega)$ the value of this maximum, i.e., $F_k(\Sigma; M_\Omega) = \max_{A \in \mathcal{C}_k(M_\Omega)} \langle A, \Sigma \rangle$. Similarly, we define $M^*_k(\Sigma; M_\Omega) = M(A^*_k(\Sigma; M_\Omega))$.

Remarks. We consider these operations in order to infer clusterings and spanning forests that are consistent with observed information. This is particularly important when designing a partial loss to learn from observed clustering information. Note that, depending on what the set of observations $\Omega$ is, these constraints can be more or less subtle to satisfy. For certain sets of constraints $\Omega$, when the matroid structure is preserved, we can obtain $A^*_k(\Sigma; M_\Omega)$ by the usual greedy algorithm, by additionally checking that any new edge does not violate the constraints defined by $\Omega$. This is for example the case when $\Omega$ corresponds to exactly observing a fully clustered subset of points.

2.2 Perturbed clustering

The clustering operations defined above are efficient and perform well (see Figure 2), but by their nature as discrete operators they have a major drawback: they are piecewise constant and as such cannot be conveniently included in end-to-end differentiable pipelines, such as those used to train models like neural networks. To overcome this obstacle, we use a proxy for our operators, by introducing a perturbed version (Abernethy et al., 2016; Berthet et al., 2020; Paulus et al., 2020; Struminsky et al., 2021), obtained by taking the expectation of solutions with stochastic additive noise on the input. In this definition and the following ones, we consider $Z \sim \mu$ from a distribution with positive, differentiable density over $\mathcal{S}_n$, and $\varepsilon > 0$.

Definition 6. We define the perturbed maximum spanning forest as the expected maximum spanning forest under stochastic perturbation of the inputs. Formally, for a similarity matrix $\Sigma$, we have
$$A^*_{k,\varepsilon}(\Sigma) = \mathbb{E}[A^*_k(\Sigma + \varepsilon Z)] = \mathbb{E}\Big[\operatorname{argmax}_{A \in \mathcal{C}_k} \langle A, \Sigma + \varepsilon Z \rangle\Big]\,, \qquad F_{k,\varepsilon}(\Sigma) = \mathbb{E}[F_k(\Sigma + \varepsilon Z)]\,.$$
We define analogously $A^*_{k,\varepsilon}(\Sigma; M_\Omega) = \mathbb{E}[A^*_k(\Sigma + \varepsilon Z; M_\Omega)]$ and $F_{k,\varepsilon}(\Sigma; M_\Omega) = \mathbb{E}[F_k(\Sigma + \varepsilon Z; M_\Omega)]$, as well as the clusterings $M^*_{k,\varepsilon}(\Sigma) = \mathbb{E}[M^*_k(\Sigma + \varepsilon Z)]$ and $M^*_{k,\varepsilon}(\Sigma; M_\Omega) = \mathbb{E}[M^*_k(\Sigma + \varepsilon Z; M_\Omega)]$.

We note that this defines operations $A^*_{k,\varepsilon}(\cdot)$ and $A^*_{k,\varepsilon}(\cdot\,; M_\Omega)$ mapping $\Sigma \in \mathcal{S}_n$ to $\operatorname{cvx}(\mathcal{C}_k)$, the convex hull of $\mathcal{C}_k$. These operators have several advantageous features:

- They are differentiable, and both their values and their derivatives can be estimated using Monte-Carlo methods, by averaging copies of $A^*_k(\Sigma + \varepsilon Z^{(i)})$ (a minimal sketch of this estimator is given after this list).
- These operators are the ones used to compute the gradient of the loss that we design to learn from clustering (see Definition 8 and Proposition 1). This is particularly convenient, as it does not require implementing a different algorithm to compute the differentiable version. Moreover, the use of parallelization in modern computing hardware makes the computational overhead almost nonexistent.
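A minimal sketch of this Monte-Carlo estimation, reusing the `max_spanning_forest` helper sketched in Section 2.1. Gaussian perturbations are an illustrative choice here; the framework only assumes a positive, differentiable density.

```python
import numpy as np

def perturbed_forest(sigma, k, eps=0.1, num_samples=100, seed=0):
    """Monte-Carlo estimate of A*_{k,eps}(Sigma) from Definition 6."""
    rng = np.random.default_rng(seed)
    n = sigma.shape[0]
    total = np.zeros((n, n))
    for _ in range(num_samples):
        z = rng.normal(size=(n, n))
        z = (z + z.T) / 2.0              # symmetric noise, since Sigma is symmetric
        a, _ = max_spanning_forest(sigma + eps * z, k)
        total += a
    return total / num_samples           # an element of the convex hull of C_k
```

In practice the B independent copies are trivially parallelizable, which is what keeps the overhead small on modern hardware.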
These methods are part of a large literature on using perturbations in optimizers such as LP solutions (Papandreou and Yuille, 2011; Hazan et al., 2013; Gane et al., 2014; Hazan et al., 2016), including so-called Gumbel max-tricks (Gumbel, 1954; Maddison et al., 2016; Jang et al., 2017; Huijben et al., 2022; Blondel et al., 2022).

Since $M^*_{k,\varepsilon}(\cdot)$ is a differentiable operator from $\mathcal{S}_n$ to $\operatorname{cvx}(\mathcal{B}_k)$, it is possible to use any loss function on $\operatorname{cvx}(\mathcal{B}_k)$ to design a loss based on $\Sigma$ and some ground-truth clustering information $M_\Omega$, such as $L(\Sigma; M_\Omega) = \|M^*_{k,\varepsilon}(\Sigma) - M_\Omega\|_2^2$. This flexibility is one of the advantages of our method. In the following section, we introduce a loss tailored to be efficient to compute and performant in several learning tasks, which we call a partial Fenchel-Young loss.

3 Learning with differentiable clustering

3.1 Fenchel-Young losses

In structured prediction, a common modus operandi is to minimize a loss between some structured ground-truth response or label $y \in \mathcal{Y}$ and a score $\theta \in \mathbb{R}^d$ (the latter often itself the output of a parameterized model). As an example, in logistic regression $y \in \{0, 1\}$ and $\theta = \langle x, \beta \rangle \in \mathbb{R}$. The framework of Fenchel-Young losses allows us to tackle much more complex structures, such as cluster connectivity information.

Definition 7 (Fenchel-Young loss, Blondel et al. (2020a)). Let $F$ be a convex function on $\mathbb{R}^d$, and $F^*$ its Fenchel dual on a convex set $\mathcal{C} \subseteq \mathbb{R}^d$. The Fenchel-Young loss between $\theta \in \mathbb{R}^d$ and $y \in \operatorname{int}(\mathcal{C})$ is
$$L_{FY}(\theta; y) = F(\theta) - \langle \theta, y \rangle + F^*(y)\,.$$

The Fenchel-Young (FY) loss satisfies several properties that make it useful for learning. In particular, it is nonnegative, convex in $\theta$, handles noisy labels well, and its gradient with respect to the score can be efficiently computed, with $\nabla_\theta L_{FY}(\theta; y) = \nabla_\theta F(\theta) - y$ (see, e.g., Blondel et al., 2020a). In the case of linear programs over a polytope $\mathcal{C}$, taking $F$ to be the so-called support function, defined as $F(\theta) = \max_{y \in \mathcal{C}} \langle y, \theta \rangle$ (consistent with our notation so far), the dual function $F^*$ is the indicator function of $\mathcal{C}$ (Rockafellar, 1997), and the Fenchel-Young loss is then given by
$$L_{FY}(\theta; y) = F(\theta) - \langle \theta, y \rangle \ \text{ if } y \in \mathcal{C}\,, \qquad +\infty \ \text{ otherwise}\,.$$
In the case of a perturbed maximum for $F$ (see, e.g., Berthet et al., 2020), we have
$$F_\varepsilon(\theta) = \mathbb{E}\big[\max_{y \in \mathcal{C}} \langle y, \theta + \varepsilon Z \rangle\big]\,, \qquad y^*_\varepsilon(\theta) = \mathbb{E}\big[\operatorname{argmax}_{y \in \mathcal{C}} \langle y, \theta + \varepsilon Z \rangle\big]\,.$$
In this setting (under mild regularity assumptions on the noise distribution and $\mathcal{C}$), we have that $\varepsilon \Omega = F^*_\varepsilon$ is a strictly convex Legendre function on $\mathcal{C}$ and $y^*_\varepsilon(\theta) = \nabla_\theta F_\varepsilon(\theta)$. This is actually equivalent to taking the max and argmax with an $\varepsilon \Omega(y)$ regularization (see Berthet et al., 2020, Propositions 2.1 and 2.2). In this case, the FY loss is given by $L_{FY,\varepsilon}(\theta; y) = F_\varepsilon(\theta) - \langle \theta, y \rangle + \varepsilon \Omega(y)$ and its gradient by $\nabla_\theta L_{FY,\varepsilon}(\theta; y) = y^*_\varepsilon(\theta) - y$. One can also define it as $L_{FY,\varepsilon}(\theta; y) = \mathbb{E}[L_{FY}(\theta + \varepsilon Z; y)]$, and it has the same gradient. In the perturbation case, this gradient can be easily obtained by Monte-Carlo estimates of the perturbed optimizer, taking $\frac{1}{B} \sum_{i=1}^{B} y^*(\theta + \varepsilon Z_i) - y$.
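A sketch of this gradient estimate, written generically for any linear argmax oracle and then specialized to our clustering case in the usage comment (function names are ours; the symmetrization of the noise corresponds to the case $\theta = \Sigma \in \mathcal{S}_n$):

```python
import numpy as np

def perturbed_fy_gradient(theta, y, argmax_oracle, eps=0.1, num_samples=100, seed=0):
    """Monte-Carlo estimate of grad_theta L_{FY,eps}(theta; y) = y*_eps(theta) - y.

    argmax_oracle(theta) must return the LP solution argmax_{y in C} <y, theta>.
    """
    rng = np.random.default_rng(seed)
    grad = -np.asarray(y, dtype=float)
    for _ in range(num_samples):
        z = rng.normal(size=theta.shape)
        z = (z + z.T) / 2.0  # keep the perturbed input symmetric (here theta is Sigma)
        grad += argmax_oracle(theta + eps * z) / num_samples
    return grad

# Usage in our setting: theta = similarity matrix, C = C_k, oracle = Kruskal solver.
# oracle = lambda s: max_spanning_forest(s, k=3)[0]
# grad = perturbed_fy_gradient(sigma, a_true, oracle)
```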
3.2 Partial Fenchel-Young losses

As noted above, this loss is widely applicable in structured prediction. As presented here, it requires label data $y$ of the same kind as the optimizer $y^*$. In our setting, the linear program is over spanning k-forests rather than connectivity matrices. This is no accident: linear programs over cut matrices (when k = 2) are already NP-hard (Karp, 1972). If one has access to richer data (such as ground-truth spanning forest information), one can ignore the operator $M$ and focus only on $A^*_{k,\varepsilon}$, $F_{k,\varepsilon}$, and the associated FY loss. However, in most cases, clustering data can reasonably only be expected to be available as a connectivity matrix. It is therefore necessary to alter the Fenchel-Young loss, and we introduce the following loss, which allows us to work with partial information $p \in \mathcal{P}$ representing partial information about the unknown ground truth $y$. Using this kind of inf-loss is common when dealing with partial label information (see, e.g., Cabannes et al., 2020).

Definition 8 (Partial Fenchel-Young loss). Let $F$ be a convex function, $L_{FY}$ the associated Fenchel-Young loss, and, for every $p \in \mathcal{P}$, a convex constraint subset $\mathcal{C}(p) \subseteq \mathcal{C}$. We define
$$\tilde{L}_{FY}(\theta; p) = \min_{y \in \mathcal{C}(p)} L_{FY}(\theta; y)\,.$$

This allows us to learn from incomplete information about $y$. When we do not know its value, but only a subset $\mathcal{C}(p) \subseteq \mathcal{C}$ to which it might belong, we can minimize the infimum of the FY losses that are consistent with the partial label information.

Proposition 1. When $F$ is the support function of a compact convex set $\mathcal{C}$, the partial Fenchel-Young loss (see Definition 8) satisfies:

1. The loss is a difference of convex functions in $\theta$, given explicitly by
$$\tilde{L}_{FY}(\theta; p) = F(\theta) - F(\theta; p)\,, \quad \text{where } F(\theta; p) = \max_{y \in \mathcal{C}(p)} \langle y, \theta \rangle\,.$$
2. The gradient with respect to $\theta$ is given by
$$\nabla_\theta \tilde{L}_{FY}(\theta; p) = y^*(\theta) - y^*(\theta; p)\,, \quad \text{where } y^*(\theta; p) = \operatorname{argmax}_{y \in \mathcal{C}(p)} \langle y, \theta \rangle\,.$$
3. The perturbed partial Fenchel-Young loss, given by $\tilde{L}_{FY,\varepsilon}(\theta; p) = \mathbb{E}[\tilde{L}_{FY}(\theta + \varepsilon Z; p)]$, satisfies
$$\nabla_\theta \tilde{L}_{FY,\varepsilon}(\theta; p) = y^*_\varepsilon(\theta) - y^*_\varepsilon(\theta; p)\,, \quad \text{where } y^*_\varepsilon(\theta; p) = \mathbb{E}\big[\operatorname{argmax}_{y \in \mathcal{C}(p)} \langle y, \theta + \varepsilon Z \rangle\big]\,.$$

Another possibility would be to define the partial loss as $\min_{y \in \mathcal{C}(p)} L_{FY,\varepsilon}(\theta; y)$, that is, the infimum of smoothed losses instead of the smoothed infimum loss $\tilde{L}_{FY,\varepsilon}(\theta; p)$ defined above. However, there is no direct method for minimizing the smoothed loss with respect to $y \in \mathcal{C}(p)$. Note that we have a relationship between the two losses:

Proposition 2. Letting $L_{FY,\varepsilon}(\theta; y) = \mathbb{E}[L_{FY}(\theta + \varepsilon Z; y)]$ and $\tilde{L}_{FY,\varepsilon}$ as in Definition 8, we have
$$\tilde{L}_{FY,\varepsilon}(\theta; p) \le \min_{y \in \mathcal{C}(p)} L_{FY,\varepsilon}(\theta; y)\,.$$

The proofs of the above two propositions are detailed in the Appendix.

3.3 Applications to differentiable clustering

We apply this framework, as detailed in the following section and in Section 4, to clustering. This is done naturally by transposing notations and taking $\mathcal{C} = \mathcal{C}_k$, $\theta = \Sigma$, $y = A^*_k$, $p = M_\Omega$, $\mathcal{C}(p) = \mathcal{C}_k(M_\Omega)$, and $y^*(\theta; p) = A^*_k(\Sigma; M_\Omega)$. In this setting, the perturbed partial Fenchel-Young loss satisfies
$$\nabla_\Sigma \tilde{L}_{FY,\varepsilon}(\Sigma; M_\Omega) = A^*_{k,\varepsilon}(\Sigma) - A^*_{k,\varepsilon}(\Sigma; M_\Omega)\,.$$
We learn representations of the data that fit with clustering information (either complete or partial). As described above, we consider settings with n elements described by their features $X = (X_1, \ldots, X_n)$ in some feature space $\mathcal{X}$, and $M_\Omega$ some clustering information. Our pipeline to learn representations includes the following steps (see Figure 1):

i) Embed each $X_i$ in $\mathbb{R}^d$ with a parameterized model, $v_i = \Phi_w(X_i) \in \mathbb{R}^d$, with weights $w \in \mathbb{R}^p$.
ii) Construct a similarity matrix from these embeddings, e.g., $\Sigma_{w,ij} = -\|\Phi_w(X_i) - \Phi_w(X_j)\|_2^2$.
iii) Stochastic optimization of the expected loss $\tilde{L}_{FY,\varepsilon}(\Sigma_{w,b}, M_{\Omega,b})$, using mini-batches of $X$.

Details. To be more specific on each of these phases:

i) We embed each $X_i$ individually with a model $\Phi_w$, using in our application neural networks and a linear model. This allows us to use learning through clustering as a way to learn representations, and to apply this model to other elements for which we have no clustering information, or to use it in other downstream tasks.
ii) We focus on cases where the similarity matrix is given by the negative squared distances between these embeddings. This creates a connection between a model acting individually on each element and a pairwise similarity matrix that can be used as an input for our differentiable clustering operator. This mapping, from $w$ to $\Sigma$, has derivatives that can therefore be automatically computed by backpropagation, as it contains only conventional operations (at least when $\Phi_w$ is itself a conventional model, such as commonly used neural networks).

iii) We use our proposed smoothed partial Fenchel-Young loss (Section 3.2) as the objective to minimize between the partial information $M_\Omega$ and $\Sigma$. The full-batch version would be to minimize $\tilde{L}_{FY,\varepsilon}(\Sigma_w, M_\Omega)$ as a function of the parameters $w \in \mathbb{R}^p$ of the model. We focus instead on a mini-batch formulation, for two reasons: first, stochastic optimization with mini-batches is a commonly used and efficient method for generalization in machine learning; second, it allows us to handle larger-scale datasets. As a consequence, the stochastic gradients of the loss are given, for a mini-batch $b$, by
$$\nabla_w \tilde{L}_{FY,\varepsilon}(\Sigma_{w,b}; M_{\Omega,b}) = \partial_w \Sigma_{w,b}\, \nabla_\Sigma \tilde{L}_{FY,\varepsilon}(\Sigma_{w,b}; M_{\Omega,b}) = \partial_w \Sigma_{w,b}\, \big[A^*_{k,\varepsilon}(\Sigma_{w,b}) - A^*_{k,\varepsilon}(\Sigma_{w,b}; M_{\Omega,b})\big]\,.$$
The simplicity of these gradients is due to our particular choice of smoothed partial Fenchel-Young loss. They can be efficiently estimated automatically, as described in Section 2.2, which results in a doubly stochastic scheme for the loss optimization.
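A schematic sketch of one such stochastic update, with a linear embedding model for concreteness. The helpers `perturbed_forest` and `similarity_matrix` are the sketches introduced earlier; `perturbed_constrained_forest` is an assumed analogue of `perturbed_forest` built on a constrained solver (one is sketched in Section 9.1). The explicit chain rule through $\Sigma_w$ is only for illustration; in practice an autodiff framework handles $\partial_w \Sigma_w$.

```python
import numpy as np

def similarity_matrix(V):
    """Sigma_ij = -||v_i - v_j||^2 from embeddings V of shape (n, d)."""
    sq = np.sum(V ** 2, axis=1)
    return -(sq[:, None] + sq[None, :] - 2 * V @ V.T)

def sgd_step(W, X, M_omega, k, lr=0.01, eps=0.1, num_samples=100, seed=0):
    """One update of a linear embedding Phi_W(x) = W x, W of shape (d, p) (illustrative).

    The gradient of the perturbed partial FY loss w.r.t. Sigma is
    A*_{k,eps}(Sigma) - A*_{k,eps}(Sigma; M_Omega); it is chained by hand through
    Sigma_ij = -||W x_i - W x_j||^2.
    """
    V = X @ W.T                                   # embeddings, shape (n, d)
    sigma = similarity_matrix(V)
    g_sigma = (perturbed_forest(sigma, k, eps, num_samples, seed)
               - perturbed_constrained_forest(sigma, M_omega, k, eps, num_samples, seed))
    grad_W = np.zeros_like(W)
    n = X.shape[0]
    for i in range(n):
        for j in range(n):
            # dSigma_ij/dW = -2 (v_i - v_j)(x_i - x_j)^T
            grad_W += g_sigma[i, j] * (-2.0) * np.outer(V[i] - V[j], X[i] - X[j])
    return W - lr * grad_W
```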
4 Experiments

We apply our framework for learning through clustering in both a supervised and a semi-supervised setting, as illustrated in Figure 1. Formally, for large training datasets of size n, we either have access to a full cluster connectivity matrix $M_\Omega$ or a partial one (typically built by using partial label information, see below). We use this clustering information $M_\Omega$, from which mini-batches can be extracted, as supervision.

Figure 3: (left) Illustration of clustering methods on toy data sets. (right) Using cluster information to learn a linear de-noising (bottom) of a noised signal (top).

Figure 4: Semi-supervised learning: performance of a CNN trained on partially labeled MNIST data (top row), and a ResNet trained on partially labeled CIFAR-10 data (bottom row). Our trained model (full line) is evaluated on clustering (left column) and compared to a model entirely trained on the classification task (dashed line). Both models are evaluated on a downstream classification problem, i.e., transfer learning via a linear probe (right column).

We minimize our partial Fenchel-Young loss with respect to the weights of an embedding model, and evaluate these embeddings in two main ways on a test dataset: first, by evaluating the clustering accuracy (i.e., the proportion of correct coefficients in the predicted cluster connectivity matrix), and second, by training a shallow model on a classification task (using clusters as classes) on a holdout set and evaluating it on a test set.

4.1 Supervised learning

We first apply our method to synthetic datasets, purely to provide an illustration of both our internal clustering algorithm and of our learning procedure. In Figure 2, we show how the clustering operator that we use, based on spanning forests (i.e., single linkage) with Kruskal's algorithm, is efficient on some standard synthetic examples, even when they are not linearly separable (compared to k-means, mean-shift, and Gaussian mixture models). We also show that our method allows us to perform supervised learning based on cluster information, in a linear setting. For $X_{\text{signal}}$ consisting of n = 60 points in two dimensions, lying in four well-separated clusters (see Figure 2), we construct $X$ by appending two noise dimensions, such that clustering based on pairwise squared distances between the $X_i$ mixes the original clusters. We learn a linear de-noising transformation $X\theta$, with $\theta \in \mathbb{R}^{4 \times 2}$, through clustering, by minimizing our perturbed partial FY loss with SGD, using the known labels as supervision.

We also show that our method is able to cluster virtually all of some classical datasets. We train a CNN (LeNet-5, LeCun et al., 1998) on mini-batches of size 64 using the partial Fenchel-Young loss to learn a clustering, with a batch-wise clustering precision of 0.99 for MNIST and 0.96 on Fashion-MNIST. Using the same approach, we trained a ResNet (He et al., 2016) on CIFAR-10 (with some minor modifications to kernel size and stride), achieving a batch-wise clustering test precision of 0.93. The t-SNE visualizations of the embeddings for the validation set are displayed in Figure 5. The experimental setups (as well as a visualization of the learnt clusters for MNIST) are detailed in the Appendix.

Figure 5: (left) t-SNE of learnt embeddings for the CNN trained on MNIST with kw = 3 withheld classes (highlighted). (right) t-SNE of learnt embeddings for the ResNet trained on CIFAR-10.

4.2 Semi-supervised learning

We show that our method is particularly useful in settings where labelled examples are scarce, even in the particularly challenging case of having no labelled examples for some classes. To this end, we conduct a series of experiments on the MNIST (LeCun et al., 2010) and CIFAR-10 (Krizhevsky et al., 2009) datasets, where we independently vary both the total number of labelled examples $n_\ell$ and the number of withheld classes for which no labelled examples are present in the training set, $k_w \in \{0, 3, 6\}$. For MNIST we consider data sets with $n_\ell \in \{100, 250, 500, 1000, 2000, 5000\}$ labelled examples, and for CIFAR-10 we consider $n_\ell \in \{1000, 2500, 5000, 7500, 10000\}$. We train the same embedding models using our partial Fenchel-Young loss with batches of size 64. We use ε = 0.1 and B = 100 for the estimated loss gradients, and optimize the weights using Adam (Kingma and Ba, 2015). In the left column of Figure 4, we report the test clustering error (evaluated on mini-batches of the same size) for each of the models and data sets, corresponding to the choices of $n_\ell$ and $k_w$.
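The clustering error reported here is the batch-wise zero-one error between connectivity matrices (see also Section 11.1). A minimal sketch, where averaging over all coefficients is our illustrative convention:

```python
import numpy as np

def clustering_error(m_pred, m_true):
    """Proportion of coefficients on which the predicted and ground-truth
    cluster connectivity matrices disagree."""
    m_pred, m_true = np.asarray(m_pred), np.asarray(m_true)
    return float(np.mean(m_pred != m_true))

# Example: error of the spanning-forest clustering of a batch against its labels.
# error = clustering_error(spanning_forest_clustering(sigma, k), connectivity_matrix(labels))
```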
We compare the performance of each model to a baseline, consisting of the exact same architecture (plus a linear head mapping the embedding to logits), trained to minimize the cross-entropy loss. In addition, we evaluate all models on a downstream (transfer-learning) classification task, by learning a linear probe on top of the frozen model, trained on hold-out data with all classes present. The results are depicted in the right-hand column of Figure 4. Further information regarding the experimental details can be found in the Appendix.

We observe that learning through clustering allows us to find a representation where class semantics are easily recoverable from the local topology. On CIFAR-10, our proposed approach strikingly achieves a lower clustering error in the most challenging setting ($n_\ell = 1000$ labelled examples and $k_w = 6$ withheld classes) than the classification-based baseline in the most lenient setting ($n_\ell = 10000$ labelled examples, all classes observed). Importantly, this advantage is not limited to clustering metrics: learning through clustering can also lead to lower downstream classification error, with the gap being most pronounced when few labelled examples are available. Moreover, besides this pronounced improvement in data efficiency, we found that our method is capable of clustering classes for which no labelled examples are seen during training (see Figure 5, left). Therefore, investigating potential applications of learning through clustering to zero-shot and self-supervised learning is a promising avenue for future work.

Acknowledgements

We thank Jean-Philippe Vert for discussions relating to this work, Simon Legrand for conversations relating to the CLEPS computing cluster (Cluster pour l'Expérimentation et le Prototypage Scientifique), and Matthieu Blondel for providing references regarding learning via Kirchhoff's theorem. We also thank the reviewers and Kyle Heuton for providing helpful corrections for the camera-ready paper. We also acknowledge support from the French government under the management of the Agence Nationale de la Recherche as part of the Investissements d'avenir program, reference ANR-19-P3IA-0001 (PRAIRIE 3IA Institute), as well as from the European Research Council (grant SEQUOIA 724063), and a monetary donation from Google.

References

Abadi, M., Barham, P., Chen, J., Chen, Z., Davis, A., Dean, J., Devin, M., Ghemawat, S., Irving, G., Isard, M., et al. (2016). TensorFlow: a system for large-scale machine learning. In OSDI, volume 16, pages 265-283. Savannah, GA, USA.

Abernethy, J., Lee, C., and Tewari, A. (2016). Perturbation techniques in online learning and optimization. Perturbations, Optimization, and Statistics, 233.

Arthur, D. and Vassilvitskii, S. (2007). K-means++: the advantages of careful seeding. In Proceedings of the Annual Symposium on Discrete Algorithms, pages 1027-1035.

Baid, G., Cook, D. E., Shafin, K., Yun, T., Llinares-López, F., Berthet, Q., Belyaeva, A., Töpfer, A., Wenger, A. M., Rowell, W. J., et al. (2023). DeepConsensus improves the accuracy of sequences with a gap-aware sequence transformer. Nature Biotechnology, 41(2):232-238.

Berthet, Q., Blondel, M., Teboul, O., Cuturi, M., Vert, J.-P., and Bach, F. (2020). Learning with differentiable perturbed optimizers. Advances in Neural Information Processing Systems, 33:9508-9519.

Blondel, M., Berthet, Q., Cuturi, M., Frostig, R., Hoyer, S., Llinares-López, F., Pedregosa, F., and Vert, J.-P. (2022). Efficient and modular implicit differentiation.
Advances in Neural Information Processing Systems, 35:5230-5242.

Blondel, M., Martins, A. F., and Niculae, V. (2020a). Learning with Fenchel-Young losses. The Journal of Machine Learning Research, 21(1):1314-1382.

Blondel, M., Teboul, O., Berthet, Q., and Djolonga, J. (2020b). Fast differentiable sorting and ranking. In International Conference on Machine Learning, pages 950-959.

Bradbury, J., Frostig, R., Hawkins, P., Johnson, M. J., Leary, C., Maclaurin, D., Necula, G., Paszke, A., VanderPlas, J., Wanderman-Milne, S., and Zhang, Q. (2018). JAX: composable transformations of Python+NumPy programs.

Bubeck, S., Meilă, M., and von Luxburg, U. (2012). How the initialization affects the stability of the k-means algorithm. ESAIM: Probability and Statistics, 16:436-452.

Cabannes, V., Rudi, A., and Bach, F. (2020). Structured prediction with partial labelling through the infimum loss. In International Conference on Machine Learning, pages 1230-1239.

Caron, M., Bojanowski, P., Joulin, A., and Douze, M. (2018). Deep clustering for unsupervised learning of visual features. In Proceedings of the European Conference on Computer Vision (ECCV), pages 132-149.

Caron, M., Misra, I., Mairal, J., Goyal, P., Bojanowski, P., and Joulin, A. (2020). Unsupervised learning of visual features by contrasting cluster assignments. Advances in Neural Information Processing Systems, 33:9912-9924.

Carr, A. N., Berthet, Q., Blondel, M., Teboul, O., and Zeghidour, N. (2021). Self-supervised learning of audio representations from permutations with differentiable ranking. IEEE Signal Processing Letters, 28:708-712.

Cordonnier, J.-B., Mahendran, A., Dosovitskiy, A., Weissenborn, D., Uszkoreit, J., and Unterthiner, T. (2021). Differentiable patch selection for image recognition. In Proceedings of the Conference on Computer Vision and Pattern Recognition, pages 2351-2360.

Cormen, T. H., Leiserson, C. E., Rivest, R. L., and Stein, C. (2022). Introduction to Algorithms. MIT Press.

Cuturi, M. (2013). Sinkhorn distances: Lightspeed computation of optimal transport. Advances in Neural Information Processing Systems, 26.

Cuturi, M. and Blondel, M. (2017). Soft-DTW: a differentiable loss function for time-series. In International Conference on Machine Learning, pages 894-903.

Defazio, A., Bach, F., and Lacoste-Julien, S. (2014). SAGA: A fast incremental gradient method with support for non-strongly convex composite objectives. Advances in Neural Information Processing Systems, 27.

Drineas, P., Frieze, A., Kannan, R., Vempala, S., and Vinay, V. (2004). Clustering large graphs via the singular value decomposition. Machine Learning, 56:9-33.

Dubois-Taine, B., Bach, F., Berthet, Q., and Taylor, A. (2022). Fast stochastic composite minimization and an accelerated Frank-Wolfe algorithm under parallelization. In Advances in Neural Information Processing Systems.

Gane, A., Hazan, T., and Jaakkola, T. (2014). Learning with maximum a-posteriori perturbation models. In Proc. of AISTATS, pages 247-256.

Genevay, A., Dulac-Arnold, G., and Vert, J.-P. (2019). Differentiable deep clustering with cluster size constraints. arXiv preprint arXiv:1910.09036.

Gower, J. C. and Ross, G. J. (1969). Minimum spanning trees and single linkage cluster analysis. Journal of the Royal Statistical Society: Series C (Applied Statistics), 18(1):54-64.

Gumbel, E. J. (1954). Statistical Theory of Extreme Values and some Practical Applications: A Series of Lectures. Number 33. US Govt. Print. Office.

Hastie, T., Tibshirani, R., Friedman, J. H., and Friedman, J. H.
(2009). The elements of statistical learning: data mining, inference, and prediction, volume 2. Springer.

Hazan, T., Maji, S., and Jaakkola, T. (2013). On sampling from the Gibbs distribution with random maximum a-posteriori perturbations. In Advances in Neural Information Processing Systems, pages 1268-1276.

Hazan, T., Papandreou, G., and Tarlow, D. (2016). Perturbations, Optimization, and Statistics. MIT Press.

He, K., Zhang, X., Ren, S., and Sun, J. (2016). Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 770-778.

Huijben, I. A., Kool, W., Paulus, M. B., and Van Sloun, R. J. (2022). A review of the Gumbel-max trick and its extensions for discrete stochasticity in machine learning. IEEE Transactions on Pattern Analysis and Machine Intelligence, 45(2):1353-1371.

Jang, E., Gu, S., and Poole, B. (2017). Categorical reparameterization with Gumbel-softmax. In International Conference on Learning Representations.

Karp, R. M. (1972). Reducibility among combinatorial problems. In Proceedings of a symposium on the Complexity of Computer Computations, held March 20-22, 1972, at the IBM Thomas J. Watson Research Center, Yorktown Heights, New York, USA, pages 85-103.

Kingma, D. P. and Ba, J. (2015). Adam: A method for stochastic optimization. In Bengio, Y. and LeCun, Y., editors, 3rd International Conference on Learning Representations, ICLR 2015, San Diego, CA, USA, May 7-9, 2015, Conference Track Proceedings.

Koo, T., Globerson, A., Carreras Pérez, X., and Collins, M. (2007). Structured prediction models via the matrix-tree theorem. In Joint Conference on Empirical Methods in Natural Language Processing and Computational Natural Language Learning, pages 141-150.

Krizhevsky, A., Hinton, G., et al. (2009). Learning multiple layers of features from tiny images.

Kruskal, J. B. (1956). On the shortest spanning subtree of a graph and the traveling salesman problem. Proceedings of the American Mathematical Society, 7(1):48-50.

Kumar, A., Brazil, G., and Liu, X. (2021). GrooMeD-NMS: Grouped mathematically differentiable NMS for monocular 3D object detection. In Proceedings of the Conference on Computer Vision and Pattern Recognition, pages 8973-8983.

Le Lidec, Q., Laptev, I., Schmid, C., and Carpentier, J. (2021). Differentiable rendering with perturbed optimizers. Advances in Neural Information Processing Systems, 34:20398-20409.

LeCun, Y., Bottou, L., Bengio, Y., and Haffner, P. (1998). Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278-2324.

LeCun, Y., Cortes, C., and Burges, C. (2010). MNIST handwritten digit database. ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist, 2.

Llinares-López, F., Berthet, Q., Blondel, M., Teboul, O., and Vert, J.-P. (2023). Deep embedding and alignment of protein sequences. Nature Methods, 20(1):104-111.

Maddison, C. J., Mnih, A., and Teh, Y. W. (2016). The concrete distribution: A continuous relaxation of discrete random variables. arXiv preprint arXiv:1611.00712.

Mensch, A. and Blondel, M. (2018). Differentiable dynamic programming for structured prediction and attention. In Proc. of ICML.

Papandreou, G. and Yuille, A. L. (2011). Perturb-and-MAP random fields: Using discrete optimization to learn and sample from energy models. In Proc. of ICCV, pages 193-200.

Paszke, A., Gross, S., Massa, F., Lerer, A., Bradbury, J., Chanan, G., Killeen, T., Lin, Z., Gimelshein, N., Antiga, L., et al. (2019).
PyTorch: An imperative style, high-performance deep learning library. Advances in Neural Information Processing Systems, 32.

Paulus, M., Choi, D., Tarlow, D., Krause, A., and Maddison, C. J. (2020). Gradient estimation with stochastic softmax tricks. Advances in Neural Information Processing Systems, 33:5691-5704.

Rockafellar, R. T. (1997). Convex Analysis, volume 11. Princeton University Press.

Sander, M. E., Puigcerver, J., Djolonga, J., Peyré, G., and Blondel, M. (2023). Fast, differentiable and sparse top-k: a convex analysis perspective. arXiv preprint arXiv:2302.01425.

Shvetsova, N., Petersen, F., Kukleva, A., Schiele, B., and Kuehne, H. (2023). Learning by sorting: Self-supervised learning with group ordering constraints. In Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV), pages 16453-16463.

Struminsky, K., Gadetsky, A., Rakitin, D., Karpushkin, D., and Vetrov, D. P. (2021). Leveraging recursive Gumbel-max trick for approximate inference in combinatorial spaces. Advances in Neural Information Processing Systems, 34:10999-11011.

Vlastelica, M., Paulus, A., Musil, V., Martius, G., and Rolínek, M. (2019). Differentiation of blackbox combinatorial solvers. arXiv preprint arXiv:1912.02175.

Xie, J., Girshick, R., and Farhadi, A. (2016). Unsupervised deep embedding for clustering analysis. In International Conference on Machine Learning, pages 478-487.

Yang, B., Fu, X., Sidiropoulos, N. D., and Hong, M. (2017). Towards k-means-friendly spaces: Simultaneous deep learning and clustering. In International Conference on Machine Learning, pages 3861-3870. PMLR.

Zmigrod, R., Vieira, T., and Cotterell, R. (2021). Efficient computation of expectations under spanning tree distributions. Transactions of the Association for Computational Linguistics, 9:675-690.

5 Broader impact

This submission focuses on foundational and exploratory work, with application to general machine learning techniques. We propose methods to extend the range of operations that can be used in end-to-end differentiable systems, by allowing clustering to be included as a differentiable operation in machine learning models. We do not foresee societal consequences that are specifically related to these methods, beyond those that are associated with the field in general.

6 Reproducibility and Licensing

Our experiments use data sets that are already open-sourced and cited in the references. All the code written for this project and used to implement our experiments is available at https://github.com/LawrenceMMStewart/DiffClust_NeurIPS2023, and distributed under an Apache 2.0 license.

7 Limitations

At present, our implementation of Kruskal's algorithm is incompatible with processing very large batch sizes at train time. However, as our method can use mini-batches, it is suitable for training deep learning models on large data sets. For example, we use a batch size of 64 (similar to that of standard classification models) for all experiments detailed in our submission. Other methods that focus on cluster assignment, rather than pairwise clustering, often have an $n \times k$ parametrization, whose efficiency will depend on the comparison between n (batch size) and k (number of clusters). At inference time this is not an issue, since gradients need not be back-propagated; hence, any implementation of Kruskal's algorithm can be used, such as the union-find implementation.
Faster GPU-compatible implementations of Kruskal's algorithm for training are certainly possible; improvements could be made using a binary heap or a GPU-parallelized merge sort. In many learning experiments the batch sizes typically take values in {32, 64, 128, 256, 512}, so our current implementation is sufficient; however, for some specific applications (such as bio-informatics experiments requiring tens of thousands of points in a batch), it would be necessary to improve upon the current implementation. This is mainly an engineering task relating to data structures and GPU programming, and is beyond the scope of this paper.

Finally, our clustering approach solves a linear program whose argmax solution is a matroid problem with a greedy single-linkage criterion. While we saw that this method is effective for differentiable clustering in both the supervised and semi-supervised settings, there may well be other linkage criteria which also take the form of LPs but which are more robust to outliers and noise in the training dataset. This line of work is outside the scope of the paper, but is closely related and could help improve our differentiable spanning forests framework.

8 Proofs of Technical Results

Proof of Proposition 1. We show these properties successively.

1) By Definition 8, we have $\tilde{L}_{FY}(\theta; p) = \min_{y \in \mathcal{C}(p)} L_{FY}(\theta; y)$. Expanding this, using the expression of the FY loss for $y \in \mathcal{C}$, we get
$$\tilde{L}_{FY}(\theta; p) = \min_{y \in \mathcal{C}(p)} F(\theta) - \langle y, \theta \rangle = F(\theta) - \max_{y \in \mathcal{C}(p)} \langle y, \theta \rangle\,.$$
As required, this implies, following the definitions given,
$$\tilde{L}_{FY}(\theta; p) = F(\theta) - F(\theta; p)\,, \quad \text{where } F(\theta; p) = \max_{y \in \mathcal{C}(p)} \langle y, \theta \rangle\,.$$

2) By linearity of derivatives and 1) above, we have
$$\nabla_\theta \tilde{L}_{FY}(\theta; p) = \nabla_\theta F(\theta) - \nabla_\theta F(\theta; p) = y^*(\theta) - y^*(\theta; p)\,, \quad \text{where } y^*(\theta; p) = \operatorname{argmax}_{y \in \mathcal{C}(p)} \langle y, \theta \rangle\,,$$
since the argmax of a constrained linear optimization problem is the gradient of its value. We note that this property holds almost everywhere (when the argmax is unique), and almost surely for costs with positive, continuous density, which we always assume (e.g., see the following).

3) By linearity of expectation and 1) above, we have
$$\nabla_\theta \tilde{L}_{FY,\varepsilon}(\theta; p) = \nabla_\theta \mathbb{E}[F(\theta + \varepsilon Z) - F(\theta + \varepsilon Z; p)] = \nabla_\theta F_\varepsilon(\theta) - \nabla_\theta F_\varepsilon(\theta; p) = y^*_\varepsilon(\theta) - y^*_\varepsilon(\theta; p)\,,$$
using the definition $y^*_\varepsilon(\theta; p) = \mathbb{E}[\operatorname{argmax}_{y \in \mathcal{C}(p)} \langle y, \theta + \varepsilon Z \rangle]$.

Proof of Proposition 2. By Jensen's inequality and the definition of the Fenchel-Young loss:
$$\tilde{L}_{FY,\varepsilon}(\theta; p) = \mathbb{E}\big[\min_{y \in \mathcal{C}(p)} L_{FY}(\theta + \varepsilon Z; y)\big] \le \min_{y \in \mathcal{C}(p)} \mathbb{E}[L_{FY}(\theta + \varepsilon Z; y)] = \min_{y \in \mathcal{C}(p)} F_\varepsilon(\theta) - \langle \theta, y \rangle = \min_{y \in \mathcal{C}(p)} L_{FY,\varepsilon}(\theta; y)\,.$$

9 Algorithms for Spanning Forests

As mentioned in Section 2, both $A^*_k(\Sigma)$ and $M^*_k(\Sigma)$ are calculated using Kruskal's algorithm (Kruskal, 1956). Our implementation of Kruskal's algorithm is tailored to our use: we first initialize both $A^*_k(\Sigma)$ and $M^*_k(\Sigma)$ as the identity matrix, and then sort the upper-triangular entries of $\Sigma$. We build the maximum-weight spanning forest in the usual greedy manner, using $A^*_k(\Sigma)$ to keep track of edges in the forest and $M^*_k(\Sigma)$ to check whether an edge can be added without creating a cycle, updating both matrices at each step of the algorithm. Once the forest has k connected components, the algorithm terminates. This is done by keeping track of the number of edges that have been added at any time. We remark that our implementation takes the form of a single loop, with each step of the loop consisting only of matrix multiplications. For this reason, it is fully compatible with auto-differentiation engines, such as JAX (Bradbury et al., 2018), PyTorch (Paszke et al., 2019) and TensorFlow (Abadi et al., 2016), and suitable for GPU/TPU acceleration. Therefore, our implementation differs from that of the standard Kruskal's algorithm, which uses a disjoint union-find data structure (and hence is not compatible with auto-differentiation frameworks).
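A simplified NumPy sketch of this style of implementation (the exact, framework-specific version is in the released code base; here the component-merging step is written with outer products, in the spirit of the matrix updates described above):

```python
import numpy as np

def kruskal_matrix_form(sigma, k):
    """Maximum k-spanning forest A and connectivity matrix M, in the single-loop,
    matrix-update style described above (simplified illustration)."""
    n = sigma.shape[0]
    A = np.eye(n, dtype=int)   # forest adjacency (with diagonal), initialized to identity
    M = np.eye(n, dtype=int)   # connectivity of the current components
    # Sort upper-triangular candidate edges by decreasing similarity.
    iu, ju = np.triu_indices(n, k=1)
    order = np.argsort(-sigma[iu, ju])
    edges_added = 0
    for idx in order:
        if edges_added == n - k:      # a k-spanning forest has n - k edges
            break
        i, j = iu[idx], ju[idx]
        if M[i, j] == 0:              # no cycle: i and j lie in different components
            A[i, j] = A[j, i] = 1
            merged = np.maximum(M[i], M[j])               # indicator of the merged component
            M = np.maximum(M, np.outer(merged, merged))   # connect every pair inside it
            edges_added += 1
    return A, M
```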
9.1 Constrained Spanning Forests

As a heuristic way to solve the constrained problem detailed in Definition 5, we make the modifications below to our implementation of Kruskal's algorithm, under the assumption that $M_\Omega$ represents valid clustering information (i.e., with no contradiction):

1. Regularization (optional): it is possible to bias the optimization problem over spanning forests to encourage or discourage edges between some of the nodes, according to the clustering information. Before performing the sort on the upper triangle of $\Sigma$, we add a large value to all entries $\Sigma_{ij}$ where $(M_\Omega)_{ij} = 1$, and subtract this same value from all entries $\Sigma_{ij}$ where $(M_\Omega)_{ij} = 0$. Entries $\Sigma_{ij}$ where $(M_\Omega)_{ij} = *$ are left unchanged. This biasing ensures that any edge between points that are constrained to be in the same cluster will always be processed before unconstrained edges. Similarly, any edge between points that are constrained not to be in the same cluster will be processed after unconstrained edges. In most cases, i.e., when all clusters are represented in the partial information, such as when $\Omega = [n] \times [n]$ (full information), this is not required to solve the constrained linear program, but we have found this regularization to be helpful in practice.

2. Constraint enforcement: we ensure that adding an edge does not violate the constraint matrix. In other words, when considering adding the edge $(i, j)$ to the existing forest, we check that none of the points in the connected component of i are forbidden from joining any of the points in the connected component of j. This is implemented using further matrix multiplications, and done alongside the existing check for cycles. The exact implementation is detailed in our code base.
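A sketch of these two modifications, written on top of the union-find version from Section 2.1 for readability (illustrative only; here $M_\Omega$ is encoded with 1 for must-link, 0 for must-not-link, and -1 for the placeholder $*$, which is our own encoding choice):

```python
import numpy as np

def constrained_max_spanning_forest(sigma, m_omega, k, bias=1e6):
    """A*_k(Sigma; M_Omega): greedy forest respecting must-link / must-not-link
    constraints. m_omega uses 1 (same cluster), 0 (different clusters), -1 (unknown)."""
    n = sigma.shape[0]
    # 1. Regularization: bias constrained edges so that must-link edges are
    #    processed first and must-not-link edges last.
    biased = sigma + bias * (m_omega == 1) - bias * (m_omega == 0)
    parent = list(range(n))

    def find(i):
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    edges = [(i, j) for i in range(n) for j in range(i + 1, n)]
    edges.sort(key=lambda e: biased[e], reverse=True)
    A, components = np.eye(n, dtype=int), n
    for i, j in edges:
        if components == k:
            break
        ri, rj = find(i), find(j)
        if ri == rj:
            continue                          # would create a cycle
        # 2. Constraint enforcement: no point of i's component may be forbidden
        #    from joining a point of j's component.
        comp_i = [a for a in range(n) if find(a) == ri]
        comp_j = [b for b in range(n) if find(b) == rj]
        if any(m_omega[a, b] == 0 for a in comp_i for b in comp_j):
            continue
        parent[ri] = rj
        A[i, j] = A[j, i] = 1
        components -= 1
    return A
```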
10 Existing Literature on Differentiable Clustering

As discussed in Section 1, there exist many approaches which use clustering during gradient-based learning, but these approaches typically use clustering in an offline fashion in order to assign labels to points. The following methods aim to learn through a clustering step (i.e., gradients back-propagate through clustering):

- Yang et al. (2017) use a bi-level optimization procedure (alternating between optimizing model weights and centroid clusters). They reported attaining 83% label-wise clustering accuracy on MNIST using a fully-connected deep network. Our method differs from this approach as it allows for end-to-end online learning.
- Genevay et al. (2019) cast k-means as an optimal transport problem, and use entropic regularization for smoothing. They reported an 85% accuracy on MNIST and 25% accuracy on CIFAR-10 with a CNN.

11 Additional Experimental Information

Figure 6: (left) $M_{\text{signal}}$; (center) $M^*_4(\Sigma)$; (right) $\theta$ after training.

We provide the details of the synthetic denoising experiment depicted in Figure 2 and described in Section 4.1. We create the signal data $X_{\text{signal}} \in \mathbb{R}^{60 \times 2}$ by sampling i.i.d. from four isotropic Gaussians (15 points coming from each of the Gaussians), each having a standard deviation of 0.2. We randomly sample the means of the four Gaussians; for the example seen in Section 4.1, the sampled means were
$$(0.97627008,\ 4.30378733), \quad (2.05526752,\ 0.89766366), \quad (1.52690401,\ 2.91788226), \quad (1.24825577,\ 7.83546002)\,.$$
Let $\Sigma_{\text{signal}}$ be the pairwise Euclidean similarity matrix corresponding to $X_{\text{signal}}$, and furthermore let $M_{\text{signal}} := M^*_4(\Sigma_{\text{signal}})$ be the equivalence matrix corresponding to the signal ($M_{\text{signal}}$ will be the target equivalence matrix to learn). We append an additional two noise dimensions to $X_{\text{signal}}$ in order to form the training data $X \in \mathbb{R}^{60 \times 4}$, where the noise entries are sampled i.i.d. from a continuous unit uniform distribution. Similarly, letting $\Sigma$ be the pairwise Euclidean similarity matrix corresponding to $X$, we calculate $M^*_4(\Sigma) \neq M_{\text{signal}}$. Both matrices $M_{\text{signal}}$ and $M^*_4(\Sigma)$ are depicted in Figure 6; we remark that adding the noise dimensions leads to most points being assigned to one of two clusters, and two points being isolated alone in their own clusters. We also create a validation set of equal size (in exactly the same manner as the train set), to ensure the model has not over-fitted to the train set.

The goal of the experiment is to learn a linear transformation of the data that recovers $M_{\text{signal}}$, i.e., a denoising, by minimizing the partial loss. There are multiple solutions to this problem, the most obvious being a transformation that removes the last two noise columns from $X$:
$$\theta = \begin{pmatrix} 1 & 0 \\ 0 & 1 \\ 0 & 0 \\ 0 & 0 \end{pmatrix}, \quad \text{for which } X\theta = X_{\text{signal}}\,.$$
For any $\theta \in \mathbb{R}^{4 \times 2}$, we define $\Sigma_\theta$ to be the pairwise similarity matrix corresponding to $X\theta$, and $M^*_4(\Sigma_\theta)$ its corresponding equivalence matrix. The problem can then be summarized as:
$$\min_{\theta \in \mathbb{R}^{4 \times 2}} \tilde{L}_{FY,\varepsilon}(\Sigma_\theta, M_{\text{signal}})\,. \tag{1}$$
We initialized $\theta$ from a standard normal distribution, and minimized the partial loss via stochastic gradient descent, with a learning rate of 0.01 and batch size 32. For the perturbations, we took ε = 0.1 and B = 1000, where ε denotes the noise amplitude in randomized smoothing and B denotes the number of samples in the Monte-Carlo estimate. The validation clustering error converged to zero after 25 gradient batches. We verify that the θ attained from training is indeed learning to remove the noise dimensions (see Figure 6).
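A compact sketch of this experiment, reusing the helpers introduced in earlier sections (`similarity_matrix`, `spanning_forest_clustering`, `sgd_step`, `clustering_error`; `sgd_step` in turn assumes the constrained Monte-Carlo helper). Full batches are used here for brevity, so this is an illustration of the protocol rather than the exact mini-batch setup above.

```python
import numpy as np

rng = np.random.default_rng(0)

# Signal: four isotropic Gaussians in 2D, 15 points each, std 0.2; two uniform noise dims.
means = rng.uniform(0.0, 8.0, size=(4, 2))
x_signal = np.concatenate([m + 0.2 * rng.normal(size=(15, 2)) for m in means])
x = np.concatenate([x_signal, rng.uniform(size=(60, 2))], axis=1)   # training data, 60 x 4
m_signal = spanning_forest_clustering(similarity_matrix(x_signal), k=4)

# Learn a linear map R^4 -> R^2 by SGD on the perturbed partial FY loss,
# reusing the sgd_step sketch from Section 3.3 (W plays the role of theta^T).
W = rng.normal(size=(2, 4))
for step in range(100):
    W = sgd_step(W, x, m_signal, k=4, lr=0.01, eps=0.1, num_samples=1000, seed=step)

# After training, clustering the de-noised data should recover M_signal.
recovered = spanning_forest_clustering(similarity_matrix(x @ W.T), k=4)
print("clustering error:", clustering_error(recovered, m_signal))
```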
11.1 Supervised Differentiable Clustering

Figure 7: (left) Architecture of the CNN: Conv (32 features, 3×3 kernel), ReLU & AvgPool (2×2 window, stride 2×2), Conv (64 features, 3×3 kernel), ReLU & AvgPool (2×2 window, stride 2×2), Dense (d = 256), Dense (d = 256); (middle) t-SNE visualization of train data embeddings; (right) t-SNE visualization of validation data embeddings.

As mentioned in Section 4.1, our method is able to cluster classical data sets such as MNIST and Fashion-MNIST. We trained a CNN with the LeNet-5 architecture (LeCun et al., 1998) using our proposed partial loss as the objective function; the exact details of the CNN architecture are depicted in Figure 7. For this experiment and all experiments described below, we trained on a single Nvidia V100 GPU; training the CNN with our proposed pipeline took less than 15 minutes. The model was trained for 30k gradient steps on mini-batches of size 64. We used the Adam optimizer (Kingma and Ba, 2015) with learning rate η = 3 × 10⁻⁴, momentum parameters (β1, β2) = (0.9, 0.999), and an ℓ2 weight decay of 10⁻⁴. We validated and tested the model using the zero-one error between the true equivalence matrix and the equivalence matrix corresponding to the output of the model, with early stopping after 10k steps (i.e., training was stopped if the validation clustering error did not improve over 10k steps). For efficiency (and parallelization), we computed this clustering error batch-wise with batch size 64. As stated in Section 4.1, we attained a batch-wise clustering precision of 0.99 on MNIST and 0.96 on Fashion-MNIST. The t-SNE visualizations of the embedding space of the model trained on MNIST, for a collection of train and validation data points, are depicted in Figure 7; the model has learnt a distinct cluster for each of the ten classes.

In a similar fashion, we trained a ResNet (He et al., 2016) on the CIFAR-10 data set. The architecture is similar to that of ResNet-50, with minor modifications to the input convolutions for compatibility with the dimensions of CIFAR images, and is detailed in the code base. The training procedure was identical to that of the CNN, except that the model was trained for 75k steps (with early stopping) and used the standard data augmentation methods for CIFAR-10, namely four-pixel padding followed by a random flip and a random crop. As mentioned in Section 4.1, the model achieved a batch-wise clustering test precision of 0.933.

11.2 Semi-Supervised Differentiable Clustering

As mentioned in Section 4.2, our method is particularly useful in settings where labelled examples are scarce, even in the particularly challenging case where some classes have no labelled examples at all. Our approach allows a model to leverage the semantic information of unlabeled examples when predicting a target equivalence matrix MΩ: the prediction for a single point depends on all other points in the batch, which is in general not the case for more common problems such as classification and regression. To demonstrate the performance of our approach, we assess our method on two tasks:

1. Semi-Supervised Clustering: Train a model to learn an embedding of the data that leads to a good clustering error. We compare our methodology to a baseline model trained with the cross-entropy loss, to check that our model leverages information from the unlabeled data and that our partial loss indeed leads to good clustering performance.

2. Downstream Classification: Assess the capacity of the trained model to serve as a backbone in a downstream classification task (transfer learning), where its weights are frozen and a linear layer is trained on top of the backbone.

We describe our data processing for both of these tasks below.

11.2.1 Data Sets

In our semi-supervised learning experiments, we divided the standard MNIST and CIFAR-10 train splits in the following manner. We create a balanced hold-out data set consisting of 1k images (100 images from each of the 10 classes); this hold-out set is used to assess the utility of the frozen clustering model on a downstream classification problem. From the remaining 59k images, we select a labeled train set of nℓ points (detailed in Section 4.2). Our experiments also vary kw ∈ {0, 3, 6}, the number of classes for which all labels are withheld. For example, if kw = 3, then the labels of the images corresponding to digits {0, 1, 2} never appear in the labeled train data. A sketch of this splitting procedure is given below.
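The splitting can be sketched as follows. Two details are assumptions made for illustration rather than a description of the code base: the labeled subset is drawn uniformly among the non-withheld classes, and the withheld classes are taken to be the first kw class indices (the text only fixes the example {0, 1, 2} for kw = 3).

```python
import numpy as np


def semi_supervised_split(y, n_labeled, k_withheld, rng, n_holdout_per_class=100):
    """Index split for the semi-supervised experiments (illustrative sketch).

    y           : (N,) integer class labels of the full train split.
    n_labeled   : number of labeled training points (n_l in the text).
    k_withheld  : number of classes whose labels are entirely withheld (k_w).
    Returns hold-out, labeled and unlabeled index arrays.
    """
    classes = np.unique(y)
    # Balanced hold-out set: 100 images per class (1k images for 10 classes).
    holdout = np.concatenate(
        [rng.choice(np.flatnonzero(y == c), n_holdout_per_class, replace=False)
         for c in classes])
    remaining = np.setdiff1d(np.arange(len(y)), holdout)

    # Withhold all labels for the first k_withheld classes (assumption).
    withheld = classes[:k_withheld]
    eligible = remaining[~np.isin(y[remaining], withheld)]

    # Labeled subset, drawn uniformly here (the sampling scheme is an assumption);
    # everything else in the remaining 59k images is treated as unlabeled.
    labeled = rng.choice(eligible, n_labeled, replace=False)
    unlabeled = np.setdiff1d(remaining, labeled)
    return holdout, labeled, unlabeled
```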
11.2.2 Semi-Supervised Clustering Task

For each choice of nℓ and kw, we train the architectures described in Section 11.1 using the following two approaches:

1. Ours: The model is trained on mini-batches in which half of the batch is labeled data and half is unlabeled data (i.e., a semi-supervised learning regime), minimizing the partial loss.

2. Baseline: The baseline model shares the same architecture as the one described in Section 11.1, but with an additional dense layer of output dimension 10 (the number of classes). We train the model on mini-batches consisting only of labeled points, minimizing the cross-entropy loss; the training regime is therefore fully supervised (classification). The baseline backbone refers to the whole model minus the dense output layer.

Both models were trained with mini-batches of size 64, with points sampled uniformly without replacement. All hyper-parameters and optimizer settings were identical to those detailed in Section 4.1. For MNIST, we repeated training for each model with five different random seeds s ∈ [5] (and with three random seeds s ∈ [3] for CIFAR-10), in order to report population statistics on the clustering error.

11.2.3 Transfer Learning: Downstream Classification

In this task, both models are frozen and their utility as a foundational backbone is assessed on a downstream classification task using the hold-out data set. We train a linear (i.e., dense) layer on top of both models using the cross-entropy loss; we refer to this linear layer as the downstream head. Training this linear head is equivalent to performing multinomial logistic regression on the features of the model. To optimize the weights of the linear head we used the SAGA optimizer (Defazio et al., 2014). The results are depicted in Figure 4: training a CNN backbone with our approach using just 250 labels leads to better downstream classification performance than the baseline trained with 5000 labels. It is worth remarking that the baseline backbone was trained on the same objective function (cross-entropy) and task (classification) as the downstream problem, which is not the case for the backbone corresponding to our approach. This highlights how learning clusterable embeddings and leveraging unlabeled data can be desirable for transfer learning.
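Since the downstream head is just a multinomial logistic regression on frozen features, it can be fitted with any implementation of SAGA. The following is a minimal sketch using scikit-learn's SAGA solver; it is an illustration and not necessarily the exact setup of the code base (regularization strength and iteration budget here are arbitrary choices).

```python
from sklearn.linear_model import LogisticRegression


def fit_downstream_head(Z_train, y_train, Z_val=None, y_val=None):
    """Fit a linear classification head on frozen backbone features.

    Z_train : (n, d) array of embeddings produced by the frozen backbone.
    y_train : (n,) integer class labels of the hold-out images.
    """
    # Multinomial logistic regression trained with the SAGA solver.
    # SAGA converges faster on standardized features; scaling is omitted for brevity.
    head = LogisticRegression(solver="saga", max_iter=1000)
    head.fit(Z_train, y_train)
    if Z_val is not None:
        print("downstream accuracy:", head.score(Z_val, y_val))
    return head
```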