# BanditPAM++: Faster k-medoids Clustering

Mo Tiwari (Stanford University, motiwari@stanford.edu), Ryan Kang* (Stanford University, txryank@stanford.edu), Donghyun Lee* (University College London, donghyun.lee.21@ucl.ac.uk), Sebastian Thrun (Stanford University, thrun@stanford.edu), Chris Piech (Stanford University, piech@cs.stanford.edu), Ilan Shomorony (University of Illinois at Urbana-Champaign, ilans@illinois.edu), Martin Jinye Zhang (Carnegie Mellon University, martinzh@andrew.cmu.edu)

Clustering is a fundamental task in data science with wide-ranging applications. In k-medoids clustering, cluster centers must be actual datapoints and arbitrary distance metrics may be used; these features allow for greater interpretability of the cluster centers and the clustering of exotic objects in k-medoids clustering, respectively. k-medoids clustering has recently grown in popularity due to the discovery of more efficient k-medoids algorithms. In particular, recent research has proposed BanditPAM, a randomized k-medoids algorithm with state-of-the-art complexity and clustering accuracy. In this paper, we present BanditPAM++, which accelerates BanditPAM via two algorithmic improvements, and is $O(k)$ faster than BanditPAM in complexity and substantially faster than BanditPAM in wall-clock runtime. First, we demonstrate that BanditPAM has a special structure that allows the reuse of clustering information within each iteration. Second, we demonstrate that BanditPAM has additional structure that permits the reuse of information across different iterations. These observations inspire our proposed algorithm, BanditPAM++, which returns the same clustering solutions as BanditPAM but often several times faster. For example, on the CIFAR10 dataset, BanditPAM++ returns the same results as BanditPAM but runs over 10× faster.
Finally, we provide a high-performance C++ implementation of BanditPAM++, callable from Python and R, that may be of interest to practitioners at https://github.com/motiwari/BanditPAM. Auxiliary code to reproduce all of our experiments via a one-line script is available at https://github.com/ThrunGroup/BanditPAM_plusplus_experiments/.

37th Conference on Neural Information Processing Systems (NeurIPS 2023).

## 1 Introduction

Clustering is a critical and ubiquitous task in data science and machine learning. Clustering aims to separate a dataset $X$ of $n$ datapoints into $k$ disjoint sets that form a partition of the original dataset. Intuitively, datapoints within a cluster are similar and datapoints in different clusters are dissimilar. Clustering problems and algorithms have found numerous applications in textual data [20], social network analysis [18], biology [28], and education [28].

A common objective function used in clustering is Equation 1:

$$\sum_{j=1}^{n} \min_{c \in C} d(c, x_j). \tag{1}$$

Under this loss function, the goal becomes to minimize the distance, defined by the distance function $d$, from each datapoint to its nearest cluster center $c$ among the set of cluster centers $C$. Note that this formulation is general and does not require the datapoints to be vectors or assume a specific functional form of the distance function $d$. Specific choices of $C$, dataset $X$, and distance function $d$ give rise to different clustering problems.

Perhaps one of the most commonly used clustering methods is k-means clustering [17, 16]. In k-means clustering, each datapoint is a vector in $\mathbb{R}^p$ and the distance function $d$ is usually taken to be squared $L_2$ distance; there are no constraints on $C$ other than that it is a subset of $\mathbb{R}^p$. Mathematically, the k-means objective function is

$$\sum_{j=1}^{n} \min_{c \in C} \lVert x_j - c \rVert_2^2 \tag{2}$$

subject to $|C| = k$. The most common algorithm for the k-means problem is Lloyd iteration [16], which has been improved by other algorithms such as k-means++ [2].
These algorithms are widely used in practice due to their simplicity and computational efficiency.

Despite widespread use in practice, k-means clustering suffers from several limitations. Perhaps its most significant restriction is the choice of $d$ as the squared $L_2$ distance. This choice of $d$ is for computational efficiency, as the mean of many points can be efficiently computed under squared $L_2$ distance, but prevents clustering with other distance metrics that may be preferred in other contexts [22, 8, 5]. For example, k-means is difficult to use on textual data that necessitates string edit distance [20], social network analyses using graph metrics [18], or sparse datasets (such as those found in recommendation systems [15]) that lend themselves to other distance functions. While k-means algorithms have been adapted to specific other metrics, e.g., cosine distance [23], these methods are bespoke to the metric and not readily generalizable.

Another limitation of k-means clustering is that the set of cluster centers $C$ may often be uninterpretable, as each cluster center is (generally) a linear combination of datapoints. This limitation can be especially problematic when dealing with structured data, such as parse trees in context-free grammars, where the mean of trees is not necessarily well-defined, or images in computer vision, where the mean image can appear as random noise [28, 15].

In contrast with k-means clustering, k-medoids clustering [10, 11] requires the cluster centers to be actual datapoints, i.e., requires $C \subset X$. More formally, the objective is to find a set of medoids $M \subset X$ (versus $\mathbb{R}^p$ in k-means) that minimizes

$$\sum_{j=1}^{n} \min_{m \in M} d(m, x_j) \tag{3}$$

subject to $|M| = k$. Note that there is no restriction on the distance function $d$.

k-medoids clustering has several advantages over k-means. Crucially, the requirement that each cluster center is a datapoint leads to greater interpretability of the cluster centers because each cluster center can be inspected.
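
To make Equation (3) concrete, the following minimal sketch evaluates the k-medoids objective for a given set of medoid indices under an arbitrary dissimilarity function. The function name `kmedoids_loss`, the two toy dissimilarities, and the toy inputs are illustrative assumptions; they are not taken from the BanditPAM++ implementation.

```python
from typing import Callable, Sequence, TypeVar

T = TypeVar("T")


def kmedoids_loss(
    data: Sequence[T],
    medoid_indices: Sequence[int],
    dist: Callable[[T, T], float],
) -> float:
    """Sum, over all points, of the dissimilarity to the nearest medoid."""
    return sum(min(dist(data[m], x) for m in medoid_indices) for x in data)


def l1(a: Sequence[float], b: Sequence[float]) -> float:
    """L1 distance between two vectors of equal length."""
    return sum(abs(u - v) for u, v in zip(a, b))


def hamming(a: str, b: str) -> float:
    """Number of differing positions between two equal-length strings."""
    return float(sum(ch_a != ch_b for ch_a, ch_b in zip(a, b)))


points = [(0.0, 0.0), (0.0, 1.0), (5.0, 5.0), (6.0, 5.0)]
words = ["cat", "car", "dog", "dig"]

vector_loss = kmedoids_loss(points, [0, 2], l1)      # medoids: points 0 and 2
string_loss = kmedoids_loss(words, [0, 2], hamming)  # medoids: "cat" and "dog"
```

Because the medoids are indices into the dataset, the same function applies unchanged to vectors and to strings; only the dissimilarity passed in differs.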
Furthermore, k-medoids supports arbitrary dissimilarity measures; the distance function $d$ in Equation (3) need not be a proper metric, i.e., may be negative, asymmetric, or violate the triangle inequality. Because k-medoids supports arbitrary dissimilarity measures, it can also be used to cluster exotic objects that are not vectors in $\mathbb{R}^p$, such as trees and strings [28], without embedding them in $\mathbb{R}^p$.

The k-medoids clustering problem in Equation (3) is a combinatorial optimization problem that is NP-hard in general [25]; as such, algorithms for k-medoids clustering are restricted to heuristic solutions. A popular early heuristic solution for the k-medoids problem was the Partitioning Around Medoids (PAM) algorithm [11]; however, PAM is quadratic in dataset size $n$ in each iteration, which is prohibitively expensive on large datasets.

Improving the computational efficiency of these heuristic solutions is an active area of research. Recently, [28] proposed BanditPAM, the first subquadratic algorithm for the k-medoids problem that matched prior state-of-the-art solutions in clustering quality. In this work, we propose BanditPAM++, which improves the computational efficiency of BanditPAM while maintaining the same results. We anticipate these computational improvements will be important in the era of big data, when k-medoids clustering is used on huge datasets.

Contributions: We propose a new algorithm, BanditPAM++, that is significantly more computationally efficient than PAM and BanditPAM but returns the same clustering results with high probability. BanditPAM++ is $O(k)$ faster than BanditPAM in complexity and substantially faster than BanditPAM in wall-clock runtime. Consequently, BanditPAM++ is faster than prior state-of-the-art k-medoids algorithms while maintaining the same clustering quality. BanditPAM++ is based on two observations about the structure of BanditPAM and the k-medoids problem, described in Section 4.
The first observation leads to a technique that we call Virtual Arms (VA). The second observation leads to a technique that we refer to as Permutation-Invariant Caching (PIC). We combine these techniques in BanditPAM++ and prove (in Section 5) and experimentally validate (in Section 6) that BanditPAM++ returns the same solution to the k-medoids problem as PAM and BanditPAM with high probability, but is more computationally efficient. In some instances, BanditPAM++ is over 10× faster than BanditPAM. Additionally, we provide a highly optimized implementation of BanditPAM++ in C++ that is callable from Python and R and may be of interest to practitioners.

## 2 Related Work

As discussed in Section 1, global optimization of the k-medoids problem (Equation (3)) is NP-hard in general [25]. Recent work attempts global optimization and achieves an optimality gap of 0.1% on one million datapoints, but is restricted to $L_2$ distance and takes several hours to run on commodity hardware [24].

Because of the difficulty of global optimization of the k-medoids problem, many heuristic algorithms have been developed for the k-medoids problem that scale polynomially with the dataset size and number of clusters. The complexity of these algorithms is measured by their sample complexity, i.e., the number of pairwise distance computations that are computed; these computations have been observed to dominate runtime costs and, as such, sample complexity translates to wall-clock time via an approximately constant factor [28] (this is also consistent with our experiments in Section 6 and Appendix 3).

Among the heuristic solutions for the k-medoids problem, the algorithm with the best clustering loss is Partitioning Around Medoids (PAM) [10, 11], which consists of two phases: the BUILD phase and the SWAP phase.
However, the BUILD phase and each SWAP iteration of PAM perform $O(kn^2)$ distance computations, which can be impractical for large datasets or when distance computations are expensive. We provide greater details about the PAM algorithm in Section 3 because it is an important baseline against which we assess the clustering quality of new algorithms.

Though PAM achieves the best clustering loss among heuristic algorithms, the era of huge data has necessitated the development of faster k-medoids algorithms in recent years. These algorithms have typically been divided into two categories: those that agree with PAM and recover the same solution to the k-medoids problem but scale quadratically in $n$, and those that sacrifice clustering quality for runtime improvements. In the former category, [25] proposed a deterministic algorithm called FastPAM1, which maintains the same output as PAM but reduces the computational complexity of each SWAP iteration from $O(kn^2)$ to $O(n^2)$. However, this algorithm still scales quadratically in $n$ in every iteration, which is prohibitively expensive on large datasets.

Faster heuristic algorithms have been proposed but these usually sacrifice clustering quality; such algorithms include CLARA [11], CLARANS [21], and FastPAM [25]. While these algorithms scale subquadratically in $n$, they return substantially worse solutions than PAM [28]. Other algorithms with better sample complexity, such as optimizations for Euclidean space and those based on tabu search heuristics [6], also return worse solutions. Finally, [1] attempts to minimize the number of unique pairwise distances or adaptively estimate these distances or coordinate-wise distances in specific settings [14, 3], but all these approaches sacrifice clustering quality for runtime.

Recently, [28] proposed BanditPAM, a state-of-the-art k-medoids algorithm that arrives at the same solution as PAM with high probability in $O(kn \log n)$ time.
BanditPAM borrows techniques from the multi-armed bandit literature to sample pairwise distance computations rather than compute all $O(n^2)$. In this work, we show that BanditPAM can be made more efficient by reusing distance computations both within iterations and across iterations.

We note that the use of adaptive sampling techniques and multi-armed bandits to accelerate algorithms has also had recent successes in other work, e.g., to accelerate the training of Random Forests [27], solve the Maximum Inner Product Search problem [27], and more [26].

## 3 Preliminaries and Background

Notation: We consider a dataset $X$ of size $n$ (that may contain vectors in $\mathbb{R}^p$ or other objects). Our goal is to find a solution to the k-medoids problem, Equation (3). We are also given a dissimilarity function $d$ that measures the dissimilarity between two objects in $X$. Note that we do not assume a specific functional form of $d$. We use $[n]$ to denote the set $\{1, \ldots, n\}$, and $a \wedge b$ (respectively, $a \vee b$) to denote the minimum (respectively, maximum) of $a$ and $b$.

Partitioning Around Medoids (PAM): The original Partitioning Around Medoids (PAM) algorithm [10, 11] consists of two main phases: BUILD and SWAP. In the BUILD phase, PAM iteratively initializes each medoid in a greedy, one-by-one fashion: in each iteration, it selects the next medoid that would reduce the k-medoids clustering loss (Equation (3)) the most, given the prior choices of medoids. More precisely, given the current set of $l$ medoids $M_l = \{m_1, \ldots, m_l\}$, the next point to add as a medoid is:

$$m^* = \arg\min_{x \in X \setminus M_l} \sum_{j=1}^{n} \left[ d(x, x_j) \wedge \min_{m' \in M_l} d(m', x_j) \right] \tag{4}$$

The output of the BUILD step is an initial set of the $k$ medoids, around which a local search is performed by the SWAP phase. The SWAP phase iteratively examines all $k(n - k)$ medoid-nonmedoid pairs and performs the swap that would lower the total loss the most.
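
The greedy BUILD search and the exhaustive SWAP search just described can be sketched directly on a precomputed distance matrix. This is a naive illustration of PAM's quadratic cost, not the BanditPAM++ implementation; the function names, the choice to score each candidate by the total loss that would result from adding or swapping it in, and the one-dimensional toy data are assumptions made for the example.

```python
from typing import List, Sequence, Tuple

INF = float("inf")


def total_loss(dist: Sequence[Sequence[float]], medoids: Sequence[int]) -> float:
    """k-medoids loss for medoids given as row indices of the distance matrix."""
    return sum(min(dist[m][j] for m in medoids) for j in range(len(dist)))


def pam_build(dist: Sequence[Sequence[float]], k: int) -> List[int]:
    """Greedy BUILD: repeatedly add the point that yields the lowest loss."""
    n = len(dist)
    medoids: List[int] = []
    nearest = [INF] * n  # distance from each point to its closest chosen medoid
    for _ in range(k):
        best_x, best_loss = -1, INF
        for x in range(n):
            if x in medoids:
                continue
            loss = sum(min(dist[x][j], nearest[j]) for j in range(n))
            if loss < best_loss:
                best_x, best_loss = x, loss
        medoids.append(best_x)
        nearest = [min(dist[best_x][j], nearest[j]) for j in range(n)]
    return medoids


def pam_swap_once(
    dist: Sequence[Sequence[float]], medoids: Sequence[int]
) -> Tuple[List[int], float]:
    """Exhaustive SWAP: try all k(n - k) pairs and apply the best improving swap."""
    n = len(dist)
    best_loss, best_pair = total_loss(dist, medoids), None
    for m in medoids:
        others = [o for o in medoids if o != m]
        # distance from each point to its closest medoid once m is removed
        rest = [min((dist[o][j] for o in others), default=INF) for j in range(n)]
        for x in range(n):
            if x in medoids:
                continue
            loss = sum(min(dist[x][j], rest[j]) for j in range(n))
            if loss < best_loss:
                best_loss, best_pair = loss, (m, x)
    if best_pair is None:  # no swap lowers the loss
        return list(medoids), best_loss
    m, x = best_pair
    return [x if o == m else o for o in medoids], best_loss


values = [0.0, 1.0, 3.0, 9.0, 10.0, 12.0, 14.0]
dist = [[abs(a - b) for b in values] for a in values]
built = pam_build(dist, k=2)
swapped, swapped_loss = pam_swap_once(dist, built)
```

Each candidate evaluation sums over all $n$ points, which is the source of the quadratic scaling that sampling-based methods aim to avoid.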
More precisely, with $M$ the current set of $k$ medoids, PAM finds the best medoid-nonmedoid pair to swap:

$$(m^*, x^*) = \arg\min_{(m, x) \in M \times (X \setminus M)} \sum_{j=1}^{n} \left[ d(x, x_j) \wedge \min_{m' \in M \setminus \{m\}} d(m', x_j) \right] \tag{5}$$

PAM requires $O(kn^2)$ distance computations for the $k$ greedy searches in the BUILD step and $O(kn^2)$ distance computations for each SWAP iteration [28]. The quadratic complexity of PAM makes it prohibitively expensive to run on large datasets. Nonetheless, we describe the PAM algorithm because it has been observed to have the best clustering loss among heuristic solutions to the k-medoids problem. More recent algorithms, such as BanditPAM [28], achieve the same clustering loss but have a significantly improved complexity of $O(kn \log n)$ in each step. Our proposed algorithm, BanditPAM++, improves upon the computational complexity of BanditPAM by a factor of $O(k)$.

Sequential Multi-Armed Bandits: BanditPAM [28] improves the computational complexity of PAM by converting each step of PAM to a multi-armed bandit problem. A multi-armed bandit problem (MAB) is defined as a collection of random variables $\{R_1, \ldots, R_n\}$, called actions or arms. We are commonly interested in the best-arm identification problem, which is to identify the arm with the highest mean, i.e., $\arg\max_i \mathbb{E}[R_i]$, with a given probability of error $\delta$. Many algorithms for this problem exist, each of which makes distributional assumptions about the random variables $\{R_1, \ldots, R_n\}$; popular ones include the upper confidence bound (UCB) algorithm and successive elimination. For an overview of common algorithms, we refer the reader to [9].

We define a sequential multi-armed bandit problem to be an ordered sequence of multi-armed bandit problems $Q = \{B_1, \ldots, B_T\}$ where each individual multi-armed bandit problem $B_t = \{R_1^t, \ldots, R_n^t\}$ has the same number of arms $n$, with respective, timestep-dependent means $\mu_1^t, \ldots, \mu_n^t$. At each timestep $t$, our goal is to determine (and take) the best action $a_t = \arg\max_i \mathbb{E}[R_i^t]$.
Crucially, our choices of $a_t$ will affect the rewards at future timesteps, i.e., the $R_i^{t'}$ for $t' > t$. Our definition of a sequential multi-armed bandit problem is similar to non-stationary multi-armed bandit problems, with the added restriction that the only non-stationarity in the problem comes from our previous actions.

We now make a few assumptions for tractability. We assume that each $R_i^t$ is observed by sampling an element from a set $\mathcal{S}$ with $S$ possible values. We refer to the values of $\mathcal{S}$ as the reference points, where each possible reference point is sampled with equal probability and determines the observed reward. With some abuse of notation, we write $R_i^t(x_s)$ for the reward observed from arm $R_i^t$ when the latent variable from $\mathcal{S}$ is observed to be $x_s$. We refer to a sequential multi-armed bandit as permutation-invariant, or as a SPIMAB (for Sequential, Permutation-Invariant Multi-Armed Bandit), if the following conditions hold:

1. For every arm $i$ and timestep $t$, $R_i^t = f(D_i, \{a_0, a_1, \ldots, a_{t-1}\})$ for some known function $f$ and some random variable $D_i$ with mean $\mu_i := \mathbb{E}[D_i] = \frac{1}{S} \sum_{s=1}^{S} D_i(x_s)$,
2. There exists a common set of reference points, $\mathcal{S}$, shared amongst each $D_i$,
3. It is possible to sample each $D_i$ in $O(1)$ time by drawing from the points in $\mathcal{S}$ without replacement, and
4. $f$ is computable in $O(1)$ time given its inputs.

Intuitively, the conditions above require that at each timestep, each random variable $R_i^t$ is expressible as a known function of another random variable $D_i$ and the prior actions taken in the sequential multi-armed bandit problem. Crucially, $D_i$ does not depend on the timestep; $R_i^t$ is only permitted to depend on the timestep through the agent's previously taken actions $\{a_0, a_1, \ldots, a_{t-1}\}$. The SPIMAB conditions imply that if $\mathbb{E}[D_i] = \mu_i$ is known for each $i$, then $\mu_i^t := \mathbb{E}[R_i^t]$ is also computable in $O(1)$ time for each $i$ and $t$, i.e., for each arm and timestep.
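
As an illustration of best-arm identification over a shared set of reference points, the following sketch runs a simple successive-elimination loop in which all surviving arms are evaluated on the same reference points, visited in one fixed random order without replacement. The function name, the form of the confidence radius, and the toy rewards are assumptions chosen for illustration; this is a generic bandit routine, not the exact procedure used by BanditPAM++.

```python
import math
import random
from typing import List, Sequence, Tuple


def successive_elimination(
    rewards: Sequence[Sequence[float]],
    sigma: float,
    delta: float,
    seed: int = 0,
) -> Tuple[int, int]:
    """Return (arm with the highest mean, number of reference points used).

    rewards[i][s] is the reward of arm i at reference point s.
    """
    n_arms, n_refs = len(rewards), len(rewards[0])
    order = list(range(n_refs))
    random.Random(seed).shuffle(order)  # shared order, sampled without replacement
    alive = set(range(n_arms))
    means: List[float] = [0.0] * n_arms
    used = 0
    for t, s in enumerate(order):
        if len(alive) == 1:
            break
        for i in alive:
            means[i] = (t * means[i] + rewards[i][s]) / (t + 1)  # running mean
        used = t + 1
        radius = sigma * math.sqrt(math.log(1.0 / delta) / used)
        best_lower = max(means[i] - radius for i in alive)
        # keep only arms whose upper confidence bound reaches the best lower bound
        alive = {i for i in alive if means[i] + radius >= best_lower}
    return max(alive, key=lambda i: means[i]), used


base = [0.1 * s for s in range(50)]
bonus = [0.0, 0.5, 2.0]  # arm 2 has the highest reward at every reference point
rewards = [[b + extra for b in base] for extra in bonus]
best_arm, refs_used = successive_elimination(rewards, sigma=1.0, delta=1e-3)
```

Because every surviving arm sees the same reference points, the per-arm samples can be stored once and revisited in the same order, which is the property the permutation-invariance conditions are designed to expose.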
BanditPAM: BanditPAM [28] reduces the scaling with $n$ of each step of the PAM algorithm by reformulating each step as a best-arm identification problem. In PAM, each of the $k$ BUILD steps has complexity $O(n^2)$ and each SWAP iteration has complexity $O(kn^2)$. In contrast, the complexity of BanditPAM is $O(n \log n)$ for each of the $k$ BUILD steps and $O(kn \log n)$ for each SWAP iteration. Fundamentally, BanditPAM achieves this reduction in complexity by sampling distance computations instead of using all $O(n^2)$ pairwise distances in each iteration.

We note that all $k$ BUILD steps of BanditPAM (respectively, PAM) together have the same complexity as each SWAP iteration of BanditPAM (respectively, PAM). Since the number of SWAP iterations is usually $O(k)$ ([28]; see also our experiments in Appendix 3), most of BanditPAM's runtime is spent in the SWAP iterations; this suggests improvements to BanditPAM should focus on expediting its SWAP phase.

## 4 BanditPAM++: Algorithmic Improvements to BanditPAM

In this section, we discuss two improvements to the BanditPAM algorithm. We first show how each SWAP iteration of BanditPAM can be improved via a technique we call Virtual Arms (VA). With this improvement, the modified algorithm can be cast as a SPIMAB. The conversion to a SPIMAB permits a second improvement via a technique we call the Permutation-Invariant Cache (PIC). Whereas the VA technique improves only the SWAP phase, the PIC technique improves both the BUILD and SWAP phases. The VA technique improves the complexity of each SWAP iteration by a factor of $O(k)$, whereas the PIC improves the wall-clock runtime of both the BUILD and SWAP phases.

### 4.1 Virtual Arms (VA)

As discussed in Section 3, most of the runtime of BanditPAM is spent in the SWAP iterations.
When evaluating a medoid-nonmedoid pair $(m, x_i)$ to potentially swap, BanditPAM estimates the quantity:

$$L_{m,x_i} = \frac{1}{n} \sum_{s=1}^{n} l_{m,x_i}(x_s), \tag{6}$$

for each medoid $m$ and candidate nonmedoid $x_i$, where

$$l_{m,x_i}(x_s) = \left( d(x_i, x_s) - \min_{m' \in M \setminus \{m\}} d(m', x_s) \right) \wedge 0 \tag{7}$$

is the change in clustering loss (Equation (3)) induced on point $x_s$ for swapping medoid $m$ with nonmedoid $x_i$ in the set of medoids $M$. Crucially, we will find that for a given $x_s$, each $l_{m,x_i}(x_s)$ for $m = 1, \ldots, k$, except possibly one, is equal. We state this observation formally in Theorem 1.

| General SPIMAB term | BanditPAM++ BUILD step | BanditPAM++ SWAP step |
|---|---|---|
| Arms, $\{R_i^t\}_{i=1}^{n}$ | Candidate points for medoids | Points to swap in as medoids |
| Reference points, $\mathcal{S}$ | Points of dataset $X$ | Points of dataset $X$ |
| Timestep, $t$ | $t$-th medoid to be added | $(t - k)$-th swap to be performed |
| $D_i(x_s)$ | Distance between $x_i$ and $x_s$ | Distance between $x_i$ and $x_s$ |
| $f(D_i, \{a_0, a_1, \ldots, a_{t-1}\})$ | Equation 8 | Equation 8 |

Table 1: BanditPAM++'s two phases can each be cast in the SPIMAB framework.

Theorem 1. Let $l_{m,x_i}(x_s)$ be the change in clustering loss induced on point $x_s$ by swapping medoid $m$ with nonmedoid $x_i$, given in Equation (7), with $x_s$ and $x_i$ fixed. Then the values $l_{m,x_i}(x_s)$ for $m = 1, \ldots, k$ are equal, except possibly where $m$ is the medoid for reference point $x_s$.

Theorem 1 is proven in Appendix 2. Crucially, Theorem 1 tells us that when estimating $L_{m,x_i}$ in Equation (6) for fixed $x_i$ and various values of $m$, we can reuse a significant number of the summands across different indices $m$ (across $k$ medoids). We note that Theorem 1 has been observed in alternate forms, e.g., as the FastPAM1 trick, in prior work [28]. However, to the best of our knowledge, we are the first to provide a formal statement and proof of Theorem 1 and demonstrate its use in an adaptive sampling scheme inspired by multi-armed bandits.

Motivated by Theorem 1, we present the SWAP step of our algorithm, BanditPAM++, in Algorithm 1.
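
Theorem 1 can be checked numerically on a single reference point. The sketch below computes the $k$ per-medoid values of Equation (7) in two ways: naively, with a separate minimum for each removed medoid, and using only the distances from $x_s$ to its closest and second-closest medoids. The function names and the toy distances are hypothetical and chosen for illustration; the sketch assumes at least two medoids.

```python
from typing import List, Sequence


def loss_changes_naive(d_candidate: float, medoid_dists: Sequence[float]) -> List[float]:
    """One value per removed medoid m, each with its own minimum over the others."""
    k = len(medoid_dists)
    changes = []
    for m in range(k):
        rest = min(medoid_dists[j] for j in range(k) if j != m)
        changes.append(min(d_candidate - rest, 0.0))
    return changes


def loss_changes_virtual(d_candidate: float, medoid_dists: Sequence[float]) -> List[float]:
    """Same values from only the closest (d1) and second-closest (d2) distances."""
    k = len(medoid_dists)
    closest = min(range(k), key=lambda j: medoid_dists[j])
    d1 = medoid_dists[closest]
    d2 = min(d for j, d in enumerate(medoid_dists) if j != closest)
    shared = min(d_candidate - d1, 0.0)  # identical for the k - 1 other medoids
    changes = [shared] * k
    changes[closest] = min(d_candidate - d2, 0.0)  # the single exceptional medoid
    return changes


medoid_dists = [4.0, 1.0, 6.0]  # distances from x_s to each of k = 3 medoids
for d_candidate in (0.5, 2.0, 5.0):
    assert loss_changes_naive(d_candidate, medoid_dists) == loss_changes_virtual(
        d_candidate, medoid_dists
    )

example = loss_changes_virtual(0.5, medoid_dists)
```

In the second function, one distance to the candidate and two cached medoid distances determine all $k$ values, which is the saving the Virtual Arms technique relies on.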
BanditPAM++ uses the VA technique to improve the complexity of each SWAP iteration by a factor of $O(k)$. We call the technique virtual arms because it uses only a constant number of distance computations to update each of the $k$ virtual arms for each of the real arms, where a real arm corresponds to a datapoint.

### 4.2 Permutation-Invariant Caching (PIC)

The original BanditPAM algorithm considers each of the $k(n - k)$ medoid-nonmedoid pairs as arms. With the VA technique described in Section 4.1, BanditPAM++ instead considers each of the $n$ datapoints (including existing medoids) as arms in each SWAP iteration. Crucially, this implies that the BUILD phase and each SWAP iteration of BanditPAM++ consider the same set of arms. It is this observation, induced by the VA technique, that allows us to cast BanditPAM++ as a SPIMAB and implement a second improvement. We call this second improvement the Permutation-Invariant Cache (PIC).

We formalize the reduction of BanditPAM++ to a SPIMAB in Table 1. In the SPIMAB formulation of BanditPAM++, the set of reference points $\mathcal{S}$ is the same as the set of datapoints $X$. Each $D_i$ is a random variable representing the distance from point $x_i$ to one of the sampled reference points and can be sampled in $O(1)$ time without replacement. Each $\mu_i = \mathbb{E}[D_i]$ corresponds to the average distance from point $x_i$ to all the points in the dataset $X$. Each arm $R_i^t$ corresponds to the point that we would add to the set of medoids (for $t \le k$) or swap in to the set of medoids (for $t > k$), of which there are $n$ at each possible timestep. Similarly, the actions $\{a_0, \ldots, a_p, \ldots, a_t\}$ correspond to points added to the set of medoids (for $t \le k$) or swaps performed (for $t > k$).

Equation (8) provides the functional form of $f$ for each of the BUILD and SWAP steps:

$$f(D_i(x_s), A) = \left( d(x_i, x_s) - \min_{m' \in M} d(m', x_s) \right) \wedge 0, \tag{8}$$

where $M$ is a set of medoids.
For the BUILD step, $A$ is a sequence of $t$ actions that results in a set of medoids of size $t \le k$, and, for the SWAP step, $A$ is a set of actions that results in a set of medoids $M$ of size $k$.

The observation that BanditPAM++ is a SPIMAB allows us to develop an intelligent cache design, which we call a permutation-invariant cache (PIC). We may choose a permutation $\pi$ of the reference points $\mathcal{S} = X$ and sample distance computations to these reference points in the order of the permutation. Since we only need to sample some of the reference points, and not all of them, we do not need to compute all $O(n^2)$ pairwise distances. Crucially, we can also reuse distance computations across different steps of BanditPAM++ to save on computational cost and runtime.

Algorithm 1 BanditPAM++ SWAP Step ($f_j(D_j, \{a_1, \ldots, a_t\})$, $\delta$, $\sigma_x$, permutation $\pi$ of $[n]$)

1: $S_{\text{solution}} \leftarrow [n]$ ▷ Set of potential solutions to MAB

2: $t' \leftarrow 0$ ▷ Number of reference points evaluated

3: For all $(i, j) \in [n] \times [k]$, set $\hat{\mu}_{i,j} \leftarrow 0$, $C_{i,j} \leftarrow \infty$ ▷ Initial means and CIs for all swaps

4: while $t' < n$ and $|S_{\text{solution}}| > 1$ do

5: $\quad$ $s \leftarrow \pi(t')$ ▷ Uses PIC

6: $\quad$ for all $i \in S_{\text{solution}}$ do

7: $\quad\quad$ Let $c(s)$ and $c^{(2)}(s)$ be the indices of $x_s$'s closest and second closest medoids ▷ Cached

8: $\quad\quad$ Compute distance to $x_s$'s closest medoid $d_1 := d(m_{c(s)}, x_s)$ ▷ Cached

9: $\quad\quad$ Compute distance to $x_s$'s second closest medoid $d_2 := d(m_{c^{(2)}(s)}, x_s)$ ▷ Cached

10: $\quad\quad$ Compute $d_i := d(x_i, x_s)$ ▷ Reusing $x_s$'s across calls leads to more cache hits

11: $\quad\quad$ $\hat{\mu}_{i,c(s)} \leftarrow \dfrac{t' \hat{\mu}_{i,c(s)} - d_1 + \min(d_2, d_i)}{t' + 1}$ ▷ Update running mean for $x_s$'s medoid

12: $\quad\quad$ $C_{i,c(s)} \leftarrow \sigma_i \sqrt{\dfrac{\log(1/\delta)}{t' + 1}}$ ▷ Update confidence interval for $x_s$'s medoid

13: $\quad\quad$ for all $j \in \{1, \ldots, k\} \setminus \{c(s)\}$ do

14: $\quad\quad\quad$ $\hat{\mu}_{i,j} \leftarrow \dfrac{t' \hat{\mu}_{i,j} + f(D_i(x_s), \{a_1, \ldots, a_k\})}{t' + 1}$ ▷ Update running means; does not depend on $j$

15: $\quad\quad\quad$ $C_{i,j} \leftarrow \sigma_i \sqrt{\dfrac{\log(1/\delta)}{t' + 1}}$ ▷ Update confidence intervals; does not depend on $j$

16: $\quad$ $S_{\text{solution}} \leftarrow \{ i : \exists\, j \text{ s.t. } \hat{\mu}_{i,j} - C_{i,j} \le \min_{i,j} (\hat{\mu}_{i,j} + C_{i,j}) \}$ ▷ Filter suboptimal arms

17: $\quad$ $t' \leftarrow t' + 1$

18: if $|S_{\text{solution}}| = 1$ then

19: $\quad$ return $i^* \in S_{\text{solution}}$ and $j^* = \arg\min_j \hat{\mu}_{i^*,j}$

20: else

21: $\quad$ Compute $\mu_{i,j}$ exactly for all $i \in S_{\text{solution}}$ ▷ At most $3n$ distance computations

22: $\quad$ return $(i^*, j^*) = \arg\min_{(i,j) : i \in S_{\text{solution}}} \mu_{i,j}$

The full BanditPAM++ algorithm is given in Algorithm 1. Crucially, for each candidate point $x_i$ to swap into the set of medoids on Line 6, we only perform 3 distance computations (not $k$) to update all $k$ arms, each of which has a mean and confidence interval (CI), on Lines 13-15. This is permitted by Theorem 1 and the VA technique, which says that $k - 1$ virtual arms for a fixed $i$ will get the same update. The PIC technique allows us to choose a permutation of reference points (the $x_s$'s) and reuse those $x_s$'s across the BUILD and SWAP steps; as such, many values of $d(x_i, x_s)$ can be cached. We emphasize that BanditPAM++ uses the same BUILD step as the original BanditPAM algorithm, but with the PIC. The PIC is also used during the SWAP step of BanditPAM++, as is the VA technique.

We prove that the full BanditPAM++ algorithm returns the same results as BanditPAM and PAM in Section 5 and demonstrate the empirical benefits of both the PIC and VA techniques in Section 6.

## 5 Analysis of the Algorithm

In this section, we demonstrate that, with high probability, BanditPAM++ returns the same answer to the k-medoids clustering problem as PAM and BanditPAM while improving the SWAP complexity of BanditPAM by $O(k)$ and substantially decreasing its runtime. Since the BUILD step of BanditPAM is the same as the BUILD step of BanditPAM++, it is sufficient to show that each SWAP step of BanditPAM++ returns the same swap as the corresponding step of BanditPAM (and PAM). All of the following theorems are proven in Appendix 2.

First, we demonstrate that PIC does not affect the results of BanditPAM++ in Theorem 2:

Theorem 2. Let $X = \{x_1, \ldots, x_S\}$ be the reference points of $D_i$, and let $\pi$ be a random permutation of $\{1, \ldots, S\}$. Then for any $c \le S$, $\sum_{q=1}^{c} D_i(x_{\pi(p_q)})$ has the same distribution as $\sum_{q=1}^{c} D_i(x_{p_q})$, where each $p_q$ is drawn uniformly without replacement from $\{1, \ldots, S\}$.

Intuitively, Theorem 2 says that instead of randomly sampling new reference points at each iteration of BanditPAM++, we may choose a fixed permutation $\pi$ in advance and sample in permutation order at each step of the algorithm. This allows us to reuse computation across different steps of the algorithm.

We now show that BanditPAM++ returns the same result as BanditPAM (and PAM) in every SWAP iteration and has the same complexity in $n$ as BanditPAM. First, we consider a single call to Algorithm 1. Let $\mu_i := \min_{j \in [k]} \mu_{i,j}$ and let $i^* := \arg\min_{i \in [n]} \mu_i$ be the optimal point to swap in to the set of medoids, so that the medoid to swap out is $j^* := \arg\min_{j \in [k]} \mu_{i^*,j}$. For another candidate point $i \in [n]$ with $i \ne i^*$, let $\Delta_i := \mu_i - \mu_{i^*}$, and for $i = i^*$, let $\Delta_i := \min^{(2)}_j \mu_{i,j} - \min_j \mu_{i,j}$, where $\min^{(2)}_j$ denotes the second smallest value over the indices $j$. To state the following results, we will assume that, for a fixed candidate point $i$ and a randomly sampled reference point $x_s$, the random variable $f(D_i(x_s), A)$ is $\sigma_i$-sub-Gaussian for some known parameter $\sigma_i$ (which, in practice, can be estimated from the data [28]):

Theorem 3. For $\delta = 1/kn^3$, with probability at least $1 - \frac{2}{n}$, Algorithm 1 returns the optimal swap to perform using a total of $M$ distance computations, where

$$\mathbb{E}[M] \le 6n + \sum_{i \in [n]} \min\left[ \frac{12}{\Delta_i^2} \left( \sigma_i + \sigma_{i^*} \right)^2 \log kn + B,\; 3n \right].$$

Intuitively, Theorem 3 states that with high probability, each SWAP iteration of BanditPAM++ returns the same result as BanditPAM and PAM. Since the BUILD step of BanditPAM++ is the same as the BUILD step of BanditPAM, this implies that BanditPAM++ follows the exact same optimization trajectories as BanditPAM and PAM over the entire course of the algorithm with high probability.
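
Returning to Algorithm 1, the following Python sketch mirrors its main loop on a precomputed distance matrix: reference points are visited in a fixed permutation, each candidate's $k$ running means are updated from three distances, and candidates are filtered by confidence bounds. It is a simplified reading of the pseudocode rather than the authors' C++ implementation. The function name `swap_step`, the single shared `sigma` for all candidates, the matrix lookups standing in for distance computations, and the toy data are all assumptions; the sketch requires at least two medoids.

```python
import math
import random
from typing import List, Sequence, Tuple


def swap_step(
    dist: Sequence[Sequence[float]],
    medoids: Sequence[int],
    sigma: float = 1.0,
    delta: float = 1e-3,
    seed: int = 0,
) -> Tuple[int, int]:
    """Return (i, j), meaning: swap point i in for medoids[j]."""
    n, k = len(dist), len(medoids)
    order = list(range(n))
    random.Random(seed).shuffle(order)  # fixed permutation of reference points
    alive = set(range(n))  # candidate points still under consideration
    mu: List[List[float]] = [[0.0] * k for _ in range(n)]  # running means
    t = 0
    while t < n and len(alive) > 1:
        s = order[t]
        # closest and second-closest medoid of the reference point x_s
        ranked = sorted(range(k), key=lambda j: dist[medoids[j]][s])
        c = ranked[0]
        d1, d2 = dist[medoids[c]][s], dist[medoids[ranked[1]]][s]
        for i in alive:
            d_i = dist[i][s]
            own = -d1 + min(d2, d_i)     # update for x_s's own medoid
            shared = min(d_i - d1, 0.0)  # same update for the other k - 1 arms
            for j in range(k):
                value = own if j == c else shared
                mu[i][j] = (t * mu[i][j] + value) / (t + 1)
        radius = sigma * math.sqrt(math.log(1.0 / delta) / (t + 1))
        best_upper = min(min(mu[i]) for i in alive) + radius
        # drop candidates whose best lower bound exceeds the best upper bound
        alive = {i for i in alive if min(mu[i]) - radius <= best_upper}
        t += 1
    # If all n reference points were visited, the surviving means are exact.
    i_star = min(alive, key=lambda i: min(mu[i]))
    j_star = min(range(k), key=lambda j: mu[i_star][j])
    return i_star, j_star


values = [0.0, 1.0, 2.0, 10.0, 11.0, 12.0, 13.0, 14.0]
dist = [[abs(a - b) for b in values] for a in values]
medoids = [1, 3]  # the points with values 1.0 and 10.0
swap = swap_step(dist, medoids, sigma=2.0, delta=1e-3)
```

The inner loop over `j` is kept to stay close to Lines 13-15; since `shared` does not depend on `j`, an implementation could store it once per candidate instead of writing it $k - 1$ times.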
We formalize this observation in Theorem 4:

Theorem 4. If BanditPAM++ is run on a dataset $X$ with $\delta = 1/kn^3$, then it returns the same set of $k$ medoids as PAM with probability $1 - o(1)$. Furthermore, the total number of distance computations $M_{\text{total}}$ required satisfies

$$\mathbb{E}[M_{\text{total}}] = O(n \log kn).$$

Note on assumptions: For Theorem 3, we assumed that the data is generated in a way such that the observations $f(D_i(x_s), A)$ follow a sub-Gaussian distribution. Furthermore, for Theorem 4, we assume that the $\Delta_i$'s are not all close to 0, i.e., that we are not in the degenerate arm setting where many of the swaps are equally optimal, and assume that the $\sigma_i$'s are bounded (we formalize these assumptions in Appendix 2). These assumptions have been found to hold in many real-world datasets [28]; see Section 7 and Appendices 1.1 and 2 for more formal discussions.

Additionally, we assume that both BanditPAM and BanditPAM++ place a hard constraint $T$ on the maximum number of SWAP iterations that are allowed. While the limit on the maximum number of swap steps $T$ may seem restrictive, it is not uncommon to place a maximum number of iterations on iterative algorithms. Furthermore, $T$ has been observed empirically to be $O(k)$ [25], consistent with our experiments in Section 6 and Appendix 3.

We note that statements similar to Theorems 3 and 4 can be proven for other values of $\delta$. We provide additional experiments to understand the effects of the hyperparameters $T$ and $\delta$ in Appendix 3.

Complexity in $k$: The original BanditPAM algorithm scales as $O(k\,c(k)\,n \log kn)$, where $c(k)$ is a problem-dependent function of $k$. Intuitively, $c(k)$ governs the hardness of the problem as a function of $k$; as more medoids are added, the average distance from each point to its closest medoid will decrease and the arm gaps (the $\Delta_i$'s) will decrease, increasing the sample complexity in Theorem 3. With the VA technique, BanditPAM++ removes the explicit factor of $k$: each SWAP iteration has complexity $O(c(k)\,n \log kn)$.
The implicit dependence on $k$ may still enter through the term $c(k)$ for a fixed dataset, which we observe in our experiments in Section 6.

## 6 Empirical Results

Setup: BanditPAM++ consists of two improvements upon the original BanditPAM algorithm: the VA and PIC techniques. We measure the gains of each technique by presenting an ablation study in which we compare the original BanditPAM algorithm (BP), BanditPAM with only the VA technique (BP+VA), BanditPAM with only the PIC (BP+PIC), and BanditPAM with both the VA and PIC techniques (BP++, the final BanditPAM++ algorithm).

First, we demonstrate that all algorithms achieve the same clustering solution and loss as BanditPAM across a variety of datasets and dataset sizes. In particular, this implies that BanditPAM++ matches prior state-of-the-art in clustering quality. Next, we investigate the scaling of all four algorithms in both $n$ and $k$ across a variety of datasets and metrics. We present our results in both sample-complexity and wall-clock runtime. BanditPAM++ outperforms BanditPAM by up to 10×. Furthermore, our results demonstrate that each of the VA and PIC techniques improves the runtime of the original BanditPAM algorithm.

| Dataset size ($n$) | 10,000 | 15,000 | 20,000 | 25,000 | 30,000 |
|---|---|---|---|---|---|
| MNIST ($L_2$, $k = 10$) | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 |
| CIFAR10 ($L_1$, $k = 10$) | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 |
| 20 Newsgroups (cosine, $k = 5$) | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 |

Table 2: Clustering loss of BanditPAM++, normalized to clustering loss of BanditPAM, across a variety of datasets, metrics, and dataset sizes. In all scenarios, BanditPAM++ matches the loss of BanditPAM (in fact, returns the exact same solution).

For an experiment on a dataset of size $n$, we sampled $n$ datapoints from the original dataset with replacement. In all experiments using the PIC technique, we allowed the algorithm to store up to 1,000 distance computations per point.
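
A cache with a per-point storage budget, as in the setup above, can be sketched as follows. Distances from each point to the reference points are stored in permutation order, so any later step that walks the same permutation reuses the stored prefix. The class name, its interface, and the counter of computed distances are hypothetical and introduced only for this example; they do not describe the data structures of the released implementation.

```python
import random
from typing import Callable, Dict, List, Sequence, TypeVar

T = TypeVar("T")


class PermutationInvariantCache:
    """Caches d(x_i, x_pi(q)) for the first `capacity` reference points per point."""

    def __init__(
        self,
        data: Sequence[T],
        dist: Callable[[T, T], float],
        capacity: int = 1000,
        seed: int = 0,
    ) -> None:
        self.data = data
        self.dist = dist
        self.capacity = capacity
        self.order = list(range(len(data)))
        random.Random(seed).shuffle(self.order)  # fixed permutation pi
        self.rows: Dict[int, List[float]] = {}   # i -> distances in pi order
        self.computed = 0                         # number of real distance calls

    def distance(self, i: int, q: int) -> float:
        """Distance from point i to the q-th reference point in permutation order."""
        row = self.rows.setdefault(i, [])
        if q < len(row):
            return row[q]  # cache hit
        value = self.dist(self.data[i], self.data[self.order[q]])
        self.computed += 1
        if q == len(row) and q < self.capacity:
            row.append(value)  # extend the cached prefix
        return value


data = [float(v) for v in range(20)]
cache = PermutationInvariantCache(data, lambda a, b: abs(a - b))

for i in range(5):          # first step: 5 arms, 8 reference points each
    for q in range(8):
        cache.distance(i, q)
first_pass = cache.computed  # 40 new distance computations

for i in range(5):          # later step: same arms, 10 reference points each
    for q in range(10):
        cache.distance(i, q)
second_pass = cache.computed - first_pass  # only 10 new ones
```

Storing a prefix in permutation order keeps lookups to a single list index, at the cost of only caching reference points that are visited in sequence.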
For the wall-clock runtime and sample complexity metrics, we divide the result of each experiment by the number of swap iterations plus 1, where the additional 1 accounts for the complexity of the BUILD step.

Datasets: We conduct experiments on several public, real-world datasets to evaluate Bandit PAM++'s performance: the MNIST dataset, the CIFAR10 dataset, and the 20 Newsgroups dataset. The MNIST dataset [13] contains 70,000 black-and-white images of handwritten digits. The CIFAR10 dataset [12] comprises 60,000 images, where each image consists of 32 × 32 pixels and each pixel has 3 color channels. The 20 Newsgroups dataset [19] consists of approximately 18,000 posts on 20 topics, split into two subsets: train and test. We used a fixed subsample of 10,000 training posts and embedded them into 385-dimensional vectors using a sentence transformer from Hugging Face [7]. We use the L2, L1, and cosine distances for the MNIST, CIFAR10, and 20 Newsgroups datasets, respectively.

6.1 Clustering/loss quality

First, we assess the solution quality of all four algorithms across various datasets, metrics, and dataset sizes. Table 2 shows the relative losses of Bandit PAM++ with respect to the loss of Bandit PAM; the results for BP+PIC and BP+VA are identical and omitted for clarity. All four algorithms return identical solutions; this demonstrates that neither the VA nor the PIC technique affects solution quality. In particular, this implies that Bandit PAM++ matches the prior state-of-the-art algorithms, Bandit PAM and PAM, in clustering quality.

6.2 Scaling with k

Figure 1 compares the wall-clock runtime scaling with k of BP, BP+PIC, BP+VA, and BP++ on the same three datasets. Across all data subset sizes, metrics, and values of k, BP++ outperforms each of BP+VA and BP+PIC, both of which in turn outperform BP. As k increases, the performance gap between algorithms using the VA technique and the other algorithms increases.
For example, on the CIFAR10 dataset with k = 15, Bandit PAM++ is over 10× faster than Bandit PAM. This provides empirical evidence for our claims in Section 5 that the VA technique improves the scaling of the Bandit PAM algorithm with k.

We provide similar experiments that demonstrate the scaling with n of Bandit PAM++ and each of the baseline algorithms in Appendix 3. The results are qualitatively similar to those shown here; in particular, Bandit PAM++ outperforms BP+PIC and BP+VA, which both outperform the original Bandit PAM algorithm.

[Figure 1 panels: MNIST, L2, n = 20000; CIFAR10, L1, n = 20000; 20 Newsgroups, Cosine, n = 10000. Each panel plots time per step (s) against the number of medoids (k) for BP++, BP+PIC, BP+VA, and BP.]

Figure 1: Average wall-clock runtime versus k for various datasets, metrics, and subsample sizes n. BP++ outperforms BP+PIC and BP+VA, both of which outperform BP. Negligible error bars are omitted for clarity.

7 Conclusions and Limitations

We proposed Bandit PAM++, an improvement upon Bandit PAM that produces state-of-the-art results for the k-medoids problem. Bandit PAM++ improves upon Bandit PAM using the Virtual Arms technique, which improves the complexity of each SWAP iteration by a factor of O(k), and the Permutation-Invariant Cache, which allows the reuse of computation across different phases of the algorithm. We prove that Bandit PAM++ returns the same results as Bandit PAM (and therefore PAM) with high probability. Furthermore, our experimental evidence demonstrates the superiority of Bandit PAM++ over baseline algorithms; across a variety of datasets, Bandit PAM++ is up to 10× faster than prior state-of-the-art while returning the same results.
While the assumptions of Bandit PAM and Bandit PAM++ are likely to hold in many practical scenarios, it is important to acknowledge that these assumptions can impose limitations on our approach. Specifically, when numerous arm gaps are narrow, Bandit PAM++ might fall back to naïve brute-force computation. Similarly, if the distributional assumptions on arm returns are violated, the complexity of Bandit PAM++ may be no better than that of PAM. We discuss these settings in greater detail in Appendix 1.1.

Acknowledgements

We would like to thank the anonymous Reviewers, PCs, and ACs for their helpful feedback on our paper. Mo Tiwari was supported by a Stanford Interdisciplinary Graduate Fellowship (SIGF) and a Stanford Data Science Scholarship. The work of Ilan Shomorony was supported in part by the National Science Foundation (NSF) under grant CCF-2046991. Martin Zhang is supported by NIH R01 MH115676.

References

[1] Amin Aghaee, Mehrdad Ghadiri, and Mahdieh Soleymani Baghshah. Active distance-based clustering using k-medoids. Pacific-Asia Conference on Knowledge Discovery and Data Mining, 9651:253–264, 2016.
[2] David Arthur and Sergei Vassilvitskii. k-means++: The advantages of careful seeding. In Proceedings of the Eighteenth Annual ACM-SIAM Symposium on Discrete Algorithms, pages 1027–1035, 2007.
[3] Vivek Bagaria, Tavor Z. Baharav, Govinda M. Kamath, and David Tse. Bandit-based Monte Carlo optimization for nearest neighbors. In Advances in Neural Information Processing Systems, pages 3650–3659, 2019.
[4] Vivek Bagaria, Govinda M. Kamath, Vasilis Ntranos, Martin J. Zhang, and David Tse. Medoids in almost-linear time via multi-armed bandits. In International Conference on Artificial Intelligence and Statistics, pages 500–509, 2018.
[5] Paul S. Bradley, Olvi L. Mangasarian, and W. N. Street. Clustering via concave minimization. In Advances in Neural Information Processing Systems, pages 368–374, 1997.
[6] Vladimir Estivill-Castro and Michael E. Houle. Robust distance-based clustering with applications to spatial data mining. Algorithmica, 30(2):216–242, 2001.
[7] Hugging Face. all-MiniLM-L6-v2 model card. https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2. Accessed: 2023-10-28.
[8] Anil K. Jain and Richard C. Dubes. Algorithms for clustering data. Prentice-Hall, 1988.
[9] Kevin Jamieson and Robert Nowak. Best-arm identification algorithms for multi-armed bandits in the fixed confidence setting. In Annual Conference on Information Sciences and Systems, pages 1–6, 2014.
[10] Leonard Kaufman and Peter J. Rousseeuw. Clustering by means of medoids. Statistical Data Analysis based on the L1 Norm and Related Methods, pages 405–416, 1987.
[11] Leonard Kaufman and Peter J. Rousseeuw. Partitioning around medoids (program PAM). Finding groups in data: an introduction to cluster analysis, pages 68–125, 1990.
[12] Alex Krizhevsky and Geoffrey Hinton. Learning multiple layers of features from tiny images. Technical report, Citeseer, 2009.
[13] Yann LeCun, Léon Bottou, Yoshua Bengio, and Patrick Haffner. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.
[14] Daniel LeJeune, Richard G. Baraniuk, and Reinhard Heckel. Adaptive estimation for approximate k-nearest-neighbor computations. In International Conference on Artificial Intelligence and Statistics, 2019.
[15] Jure Leskovec, Anand Rajaraman, and Jeffrey D. Ullman. Mining of massive data sets. Cambridge University Press, 2020.
[16] Stuart Lloyd. Least squares quantization in PCM. IEEE Transactions on Information Theory, 28(2):129–137, 1982.
[17] James MacQueen. Some methods for classification and analysis of multivariate observations. In Berkeley Symposium on Mathematical Statistics and Probability, volume 1, pages 281–297, 1967.
[18] Nina Mishra, Robert Schreiber, Isabelle Stanton, and Robert E. Tarjan. Clustering social networks. In International Workshop on Algorithms and Models for the Web-Graph, pages 56–67. Springer, 2007.
[19] Tom Mitchell. Twenty Newsgroups. UCI Machine Learning Repository, 1999. DOI: https://doi.org/10.24432/C5C323.
[20] Gonzalo Navarro. A guided tour to approximate string matching. ACM Computing Surveys, 33(1):31–88, 2001.
[21] Raymond T. Ng and Jiawei Han. CLARANS: A method for clustering objects for spatial data mining. IEEE Transactions on Knowledge and Data Engineering, 14(5):1003–1016, 2002.
[22] Michael L. Overton. A quadratically convergent method for minimizing a sum of Euclidean norms. Mathematical Programming, 27(1):34–63, 1983.
[23] Rameshwar Pratap, Anup Deshmukh, Pratheeksha Nair, and Tarun Dutt. A faster sampling algorithm for spherical k-means. In Asian Conference on Machine Learning, pages 343–358. PMLR, 2018.
[24] Jiayang Ren, Kaixun Hua, and Yankai Cao. Global optimal k-medoids clustering of one million samples. Advances in Neural Information Processing Systems, 35:982–994, 2022.
[25] Erich Schubert and Peter J. Rousseeuw. Faster k-medoids clustering: improving the PAM, CLARA, and CLARANS algorithms. In International Conference on Similarity Search and Applications, pages 171–187. Springer, 2019.
[26] Mo Tiwari. Accelerating machine learning algorithms with adaptive sampling. arXiv preprint arXiv:2309.14221, 2023.
[27] Mo Tiwari, Ryan Kang, Jaeyong Lee, Chris Piech, Ilan Shomorony, Sebastian Thrun, and Martin J. Zhang. MABSplit: Faster forest training using multi-armed bandits. Advances in Neural Information Processing Systems, 35:1223–1237, 2022.
[28] Mo Tiwari, Martin Jinye Zhang, James Mayclin, Sebastian Thrun, Chris Piech, and Ilan Shomorony. BanditPAM: Almost linear time k-medoids clustering via multi-armed bandits. In Advances in Neural Information Processing Systems, 2020.