# Neural Clustering Processes

Ari Pakman, Yueqi Wang, Catalin Mitelut, Jin Hyung Lee, Liam Paninski
Columbia University (Yueqi Wang now at Google). Correspondence to: Ari Pakman.

**Abstract.** Probabilistic clustering models (or, equivalently, mixture models) are basic building blocks in countless statistical models and involve latent random variables over discrete spaces. For these models, posterior inference methods can be inaccurate and/or very slow. In this work we introduce deep network architectures trained with labeled samples from any generative model of clustered datasets. At test time, the networks generate approximate posterior samples of cluster labels for any new dataset of arbitrary size. We develop two complementary approaches to this task, requiring either O(N) or O(K) network forward passes per dataset, where N is the dataset size and K the number of clusters. Unlike previous approaches, our methods sample the labels of all the data points from a well-defined posterior, and can learn nonparametric Bayesian posteriors since they do not limit the number of mixture components. As a scientific application, we present a novel approach to neural spike sorting for high-density multielectrode arrays.

## 1. Introduction

Probabilistic clustering models (or, equivalently, mixture models) are a staple of statistical modelling in which a discrete latent variable is introduced for each observation, indicating its mixture component identity. Popular inference methods in these models fall into two main classes. When exploring the full posterior is crucial (e.g., there is irreducible uncertainty about the latent structure, or many separate local optima exist), the method of choice is Markov chain Monte Carlo (MCMC) (Neal, 2000; Jain & Neal, 2004). MCMC is asymptotically accurate but time-consuming, with convergence that is difficult to assess. Models whose likelihood and prior are non-conjugate are particularly challenging, since in these cases the model parameters generally cannot be marginalized and must be kept as part of the state of the Markov chain. Alternatively, variational methods (Blei & Jordan, 2004; Kurihara et al., 2007; Hughes et al., 2015) are typically much faster but do not come with accuracy guarantees. As a third alternative, in recent years there has been steady progress on amortized inference methods, and such is the spirit of this work.

Concretely, we propose novel techniques to perform amortized approximate posterior inference over discrete latent variables in mixture models. We consider two possible product expansions of the mixture posteriors, and in each expansion we use neural networks to express the conditional factors in terms of fixed-dimensional, distributed representations that respect the permutation symmetries imposed by the discrete variables.

A major advantage of our approach, compared to previous approaches to amortized clustering, is its ability to handle an arbitrary number of clusters from a well-defined posterior. This makes the methods a natural choice for nonparametric Bayesian models, such as Dirichlet process mixture models (DPMMs), and their extensions. Moreover, the methods can be applied to both conjugate and non-conjugate models.
The term amortization refers to the process of investing computational resources to train a model that is later used for very fast posterior inference (Gershman & Goodman, 2014). Concretely, in a model with observations x and latent variables z, the amortized approach learns a parametrized function $q_\theta(z|x)$ that approximates $p(z|x)$ for any x; learning the parameters $\theta$ may be challenging, but once $\theta$ is in hand, evaluating $q_\theta(z|x)$ for new data x is fast.

The amortized inference literature can be coarsely divided into two approaches. On one side, the variational autoencoder (VAE) approach (Kingma & Welling, 2013), with roots in the wake-sleep algorithm (Hinton et al., 1995), learns $q_\theta(z|x)$ along with the generative model $p_\phi(x|z)$; here $p(z)$ is usually a known, simple distribution. Our work corresponds to the alternative case: a generative model $p(x, z)$ is postulated, and posterior inference is the main focus of the learning phase. Amortized methods in this case usually involve a degree of specialization to the particular generative model of interest. Examples include methods developed for Bayesian networks (Stuhlmüller et al., 2013), sequential Monte Carlo (Paige & Wood, 2016), probabilistic programming (Ritchie et al., 2016; Le et al., 2016), neural decoding (Parthasarathy et al., 2017) and particle tracking (Sun & Paninski, 2018).

Our work is specialized to the case where the latent variables are discrete and their range is not fixed beforehand. After training a neural architecture using labeled samples from a particular generative model, we can obtain independent, parallelizable, approximate posterior samples of the discrete variables for any new set of observations of arbitrary size, with no need for expensive MCMC steps. These samples can be used (i) to approximate expectations, (ii) as high-quality importance samples, or (iii) as independent Metropolis-Hastings proposals.

In Section 2 we introduce generative mixture models and present two distinct expansions of the posterior distribution. In Sections 3 and 4 we present neural architectures to model the factors of each expansion, along with their objective functions. In Section 5 we present two simple examples to illustrate the methods. In Section 6 we review related works. In Section 7 we discuss quantitative evaluations of the new methods. We close in Section 8 with a neuroscientific application to spike sorting for high-density multielectrode probes. The Supplementary Material (SM) contains details on the architectures, the spike-sorting application, and an extension of these ideas to particle tracking. (An early version of this work appeared in Pakman & Paninski (2018) and Wang et al. (2019); similar methods were applied to amortized permutations in Pakman et al. (2019).)

## 2. Generative Mixture Models

We start by presenting mixture models from the perspective of probabilistic models for clustering (McLachlan & Basford, 1988). The latter introduce random variables $c_i$ denoting the cluster to which the data point $x_i$ is assigned, and assume a generating process of the form

$$
\begin{aligned}
\alpha_1, \alpha_2 &\sim p(\alpha), \\
c_1, \dots, c_N &\sim p(c_1, \dots, c_N \mid \alpha_1), \\
\mu_1, \dots, \mu_K \mid c_{1:N} &\sim p(\mu_1, \dots, \mu_K \mid \alpha_2), \\
x_i &\sim p(x_i \mid \mu_{c_i}), \qquad i = 1, \dots, N.
\end{aligned}
\tag{1}
$$

Here $\alpha_1, \alpha_2$ are hyperparameters. The number of clusters $K$ is a random variable, indicating the number of distinct values among the sampled $c_i$'s, and $\mu_k$ denotes a parameter vector controlling the distribution of the $k$-th cluster (e.g., $\mu_k$ could include both the mean and covariance of a Gaussian mixture component). We assume that the priors $p(c_{1:N}|\alpha_1)$ and $p(\mu_{1:K}|\alpha_2)$ are exchangeable,

$$
p(c_1, \dots, c_N \mid \alpha_1) = p(c_{\sigma_1}, \dots, c_{\sigma_N} \mid \alpha_1),
$$

where $\{\sigma_i\}$ is an arbitrary permutation of the indices, and similarly for $p(\mu_{1:K}|\alpha_2)$.
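As a concrete instance of the generative process (1), and of the 2D Gaussian example used later in Section 5, the following minimal NumPy sketch draws one labeled dataset $(x, c_{1:N})$ using a CRP prior over labels and isotropic Gaussian components. The function names and defaults are our own illustrative choices, not taken from the paper's released code; training data for the methods below consists of many such labeled datasets.

```python
import numpy as np

def sample_crp_labels(N, alpha, rng):
    """Sample cluster labels c_1..c_N from a Chinese Restaurant Process prior."""
    labels, counts = np.zeros(N, dtype=int), []
    for i in range(N):
        probs = np.array(counts + [alpha], dtype=float)
        k = rng.choice(len(probs), p=probs / probs.sum())
        if k == len(counts):          # open a new cluster
            counts.append(0)
        counts[k] += 1
        labels[i] = k
    return labels

def sample_dataset(sigma_mu=10.0, sigma=1.0, seed=0):
    """One labeled dataset (x, c) from the model (1): CRP prior, 2D Gaussian components."""
    rng = np.random.default_rng(seed)
    alpha = rng.exponential(1.0)                      # alpha ~ Exp(1)
    N = int(rng.integers(5, 101))                     # N ~ Uniform[5, 100]
    c = sample_crp_labels(N, alpha, rng)              # c_{1:N} ~ CRP(alpha)
    K = c.max() + 1
    mu = rng.normal(0.0, sigma_mu, size=(K, 2))       # mu_k ~ N(0, sigma_mu^2 I_2)
    x = mu[c] + rng.normal(0.0, sigma, size=(N, 2))   # x_i ~ N(mu_{c_i}, sigma^2 I_2)
    return x, c

if __name__ == "__main__":
    x, c = sample_dataset(seed=42)
    print(x.shape, np.bincount(c))                    # (N, 2) and the cluster sizes
```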
Our interest in this work is in cases where $K$ can take any value $K \leq N$, such as the Chinese Restaurant Process (CRP) or its Pitman-Yor generalization (see Rodriguez & Mueller (2013) for a review). Of course, our methods also work for models with $K < B$ for fixed $B$, such as Mixtures of Finite Mixtures (Miller & Harrison, 2018).

Instead of representing configurations using $N$ labels $c_i$, an alternative is obtained using $K$ sets of indices,

$$
s_k = (s_{k,1}, \dots, s_{k,N_k}), \qquad k = 1, \dots, K,
\tag{2}
$$

where $c_{s_{k,i}} = k$ for all $k, i$. For example, the labels $c_{1:6} = (1, 1, 2, 1, 2, 1)$ are equivalent to $s_1 = (1, 2, 4, 6)$, $s_2 = (3, 5)$. Note that cluster $k$ has size $N_k$ and $N = \sum_{k=1}^{K} N_k$.

Given $N$ data points $x = \{x_i\}$, we would like to draw independent samples from the posterior $p(c|x)$. For this, we consider expanding $p(c|x)$ using either the labels representation,

$$
p(c_{1:N}|x) = p(c_1|x)\, p(c_2|c_1, x) \cdots p(c_N|c_{1:N-1}, x),
\tag{3}
$$

or the indices representation,

$$
p(s_{1:K}|x) = p(s_1|x)\, p(s_2|s_1, x) \cdots p(s_K|s_{1:K-1}, x).
\tag{4}
$$

Note that for a given cluster configuration, $p(c_{1:N}|x) = p(s_{1:K}|x)$. In the next two sections, we present neural architectures to model the factors in each of these expansions.

## 3. Pointwise Sampling

We would like to model all the factors in (3) in a unified way, with a generic factor given by

$$
p(c_n|c_{1:n-1}, x) = \frac{p(c_1, \dots, c_n, x)}{\sum_{c'_n = 1}^{K+1} p(c_1, \dots, c_{n-1}, c'_n, x)}.
\tag{5}
$$

Here we assumed that there are $K$ unique values in $c_{1:n-1}$, and therefore $c_n$ can take $K+1$ values, corresponding to $x_n$ joining any of the $K$ existing clusters, or forming its own new cluster. Since (5) is in general difficult to compute directly, we will approximate these terms with a neural network $q_\theta(c_n|c_{1:n-1}, x)$ that takes $(c_{1:n-1}, x)$ as inputs, extracts features, and combines them nonlinearly to output a probability distribution over $c_n$. Critically, we will design the network to enforce the highly symmetric structure of (5).

To make this symmetric structure more transparent, let us consider the joint distribution of the assignments of the first $n$ data points,

$$
p(c_1, \dots, c_n, x).
\tag{6}
$$

Figure 1. Encoding cluster labels. After assigning labels $c_{1:6}$ to $K = 2$ clusters, each of the three possible $c_7$ labels (for the circled point $x_7$) gives an encoding $G_k$ for the set $x_{1:7}$. The vector $U$ encodes the four gray unlabeled points.

Note that under the model (1), this quantity depends on all the $N$ elements of $x$, not just on $x_{1:n}$. A neural representation of (6) should respect the permutation symmetries imposed on the $x_i$'s by the values of $c_{1:n}$. Therefore, our first task is to build permutation-invariant representations of the observations $x$. The general problem of constructing such invariant encodings was discussed recently in (Zaheer et al., 2017); to adapt this approach to our context, we consider three distinct permutation symmetries:

**Permutations within a cluster:** (6) is invariant under permutations of $x_i$'s in the same cluster. For each of the $K$ clusters sampled so far, we define the encoding

$$
H_k = \sum_{i:\, c_i = k} h(x_i), \qquad h: \mathbb{R}^{d_x} \to \mathbb{R}^{d_h},
\tag{7}
$$

which is clearly invariant under permutations of $x_i$'s in the same cluster. In general $h$ is an encoding function we learn from data, unless $p(x|\mu)$ belongs to an exponential family and the prior $p(c_{1:N})$ is constant, as discussed in SM Section B.

**Permutations between clusters:** (6) is invariant under permutations of the cluster labels. In terms of the within-cluster invariants $H_k$, this can be captured by

$$
G = \sum_{k=1}^{K} g(H_k), \qquad g: \mathbb{R}^{d_h} \to \mathbb{R}^{d_g}.
\tag{8}
$$

**Permutations of the unassigned data points:** (6) is also invariant under permutations of the $N - n$ unassigned data points. This can be captured by

$$
U = \sum_{i=n+1}^{N} u(x_i), \qquad u: \mathbb{R}^{d_x} \to \mathbb{R}^{d_u}.
\tag{9}
$$

Note that $G$ and $U$ provide fixed-dimensional, symmetry-invariant representations of the assigned and unassigned data points, respectively, for any values of $N$ and $K$. Encodings of this form yield arbitrarily accurate approximations of (partially) symmetric functions (Zaheer et al., 2017; Gui et al., 2019).
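To make the role of these encodings concrete, here is a minimal PyTorch sketch of the symmetry-invariant summaries (7)-(9). The MLP sizes and the module name `InvariantEncoders` are our own illustrative choices, not the architectures used in the paper (those are given in SM Section G); the point is only that sums of learned per-point encodings yield fixed-dimensional, permutation-invariant summaries.

```python
import torch
import torch.nn as nn

def mlp(d_in, d_out, d_hidden=128):
    return nn.Sequential(nn.Linear(d_in, d_hidden), nn.ReLU(),
                         nn.Linear(d_hidden, d_out))

class InvariantEncoders(nn.Module):
    """Symmetry-invariant summaries H_k, G, U of eqs. (7)-(9)."""
    def __init__(self, d_x, d_h=128, d_g=256, d_u=128):
        super().__init__()
        self.h = mlp(d_x, d_h)   # per-point encoder, eq. (7)
        self.g = mlp(d_h, d_g)   # per-cluster encoder, eq. (8)
        self.u = mlp(d_x, d_u)   # unassigned-point encoder, eq. (9)

    def forward(self, x, c, n):
        """x: (N, d_x); c: (n,) labels in {0,...,K-1} for the first n (assigned) points."""
        hx = self.h(x[:n])                               # (n, d_h)
        K = int(c.max().item()) + 1
        # H_k = sum of h(x_i) over points currently assigned to cluster k
        H = torch.zeros(K, hx.shape[1]).index_add_(0, c, hx)
        G = self.g(H).sum(dim=0)                         # eq. (8)
        U = self.u(x[n:]).sum(dim=0)                     # eq. (9): remaining N-n points
        return H, G, U

# usage: summaries after assigning the first n = 5 points of a toy dataset
enc = InvariantEncoders(d_x=2)
x = torch.randn(10, 2)
c = torch.tensor([0, 0, 1, 0, 1])
H, G, U = enc(x, c, n=5)
print(H.shape, G.shape, U.shape)
```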
### 3.1. The Variable-input Softmax

After assigning values to $c_{1:n-1}$, each of the $K+1$ possible values for $c_n$ corresponds to $h(x_n)$ appearing in one particular $H_k$ in (7), and yields a separate vector $G_k$ in (8); see Figure 1 for an example. In terms of the $G_k$'s and $U$, we propose to model (5) as

$$
q_\theta(c_n = k \mid c_{1:n-1}, x) = \frac{e^{f(G_k, U)}}{\sum_{k'=1}^{K+1} e^{f(G_{k'}, U)}}, \qquad k = 1, \dots, K+1,
\tag{10}
$$

where we have introduced a new real-valued function $f$. In other words, each value of $c_n$ corresponds to a different channel through which the encoding $h(x_n)$ flows to the logit value $f$. Note that $k = K+1$ corresponds to $c_n$ forming its own new cluster, with $H_{K+1} = h(x_n)$.

Our softmax (10) differs from the usual form in, e.g., classification networks, where a fixed number of categories receive logit values $f$ from the fixed-size final layer of a multi-layer perceptron (MLP). In our case, the discrete identity of each logit is determined by the neural path that the input $h(x_n)$ takes to $G$, thus allowing a flexible number of categories. In eq. (10), $\theta$ denotes the parameters in the functions $h, g, u$ and $f$, which we represent with neural networks.

By storing and updating $G$ and $U$ for successive values of $n$, as shown in Algorithm 1, the computational cost of a full i.i.d. sample of $c_{1:N}$ is $O(NK)$, the same as a single Gibbs sweep; and by parallelizing steps 8-9 in Algorithm 1, the number of network forward passes becomes $O(N)$. We term this approach the Neural Clustering Process (NCP). It is relatively easy to run hundreds of copies of Algorithm 1 in parallel on a GPU, with each copy yielding a different set of samples $c_{1:N}$. (An implementation is available at https://github.com/aripakman/neural_clustering_process.)

**Algorithm 1:** $O(NK)$ Neural Clustering Process

1. $h_i \leftarrow h(x_i)$, $u_i \leftarrow u(x_i)$, $i = 1 \dots N$ {Notation}
2. $U \leftarrow \sum_{i=2}^{N} u_i$, $K \leftarrow 1$ {Initialize unassigned set}
3. $H_1 \leftarrow h_1$, $G \leftarrow g(H_1)$, $c_1 \leftarrow 1$ {First cluster}
4. **for** $n = 2 \dots N$ **do**
5. $U \leftarrow U - u_n$ {Remove $x_n$ from unassigned set}
6. $H_{K+1} \leftarrow 0$ {We define $g(0) = 0$}
7. **for** $k = 1 \dots K+1$ **do**
8. $G_k \leftarrow G + g(H_k + h_n) - g(H_k)$ {Add $x_n$ to cluster $k$}
9. $q_k \leftarrow e^{f(G_k, U)}$
10. **end for**
11. $q_k \leftarrow q_k / \sum_{k'=1}^{K+1} q_{k'}$, $c_n \sim q_k$ {Sample}
12. **if** $c_n = K+1$ **then**
13. $K \leftarrow K+1$
14. **end if**
15. $G \leftarrow G - g(H_{c_n}) + g(H_{c_n} + h_n)$ {Add point $x_n$}
16. $H_{c_n} \leftarrow H_{c_n} + h_n$
17. **end for**
18. Return $c_1 \dots c_N$
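The following unbatched PyTorch sketch implements the sampling loop of Algorithm 1 (0-based labels). The stand-in MLPs, sizes, and the concatenation $[G_k, U]$ fed to the logit network $f$ are our own simplifications of the trained architecture, kept only to illustrate the incremental $O(NK)$ bookkeeping and the variable-input softmax.

```python
import torch
import torch.nn as nn

def mlp(d_in, d_out, d_hid=128):
    return nn.Sequential(nn.Linear(d_in, d_hid), nn.ReLU(), nn.Linear(d_hid, d_out))

@torch.no_grad()
def ncp_sample(x, h, g, u, f):
    """One approximate posterior sample of c_{1:N}, following Algorithm 1.
    x: (N, d_x); h, g, u: point/cluster encoders; f: maps [G_k, U] to a scalar logit."""
    N = x.shape[0]
    hs, us = h(x), u(x)                       # step 1: precompute h_i, u_i
    U = us[1:].sum(dim=0)                     # step 2: all points but x_1 are unassigned
    H = [hs[0].clone()]                       # step 3: H_1 = h_1, first cluster
    G = g(H[0])
    c = [0]
    for n in range(1, N):
        U = U - us[n]                         # step 5: remove x_n from unassigned set
        K = len(H)
        logits = []
        for k in range(K + 1):                # steps 7-10: one candidate G_k per label
            Hk = H[k] if k < K else torch.zeros_like(hs[n])   # H_{K+1} = 0
            gHk = g(Hk) if k < K else torch.zeros_like(G)     # convention g(0) := 0
            Gk = G + g(Hk + hs[n]) - gHk
            logits.append(f(torch.cat([Gk, U])))
        probs = torch.softmax(torch.stack(logits).squeeze(-1), dim=0)   # step 11
        cn = int(torch.multinomial(probs, 1))
        if cn == K:                           # steps 12-14: open a new cluster
            H.append(torch.zeros_like(hs[n]))
        else:
            G = G - g(H[cn])                  # step 15: subtract old g(H_{c_n})
        H[cn] = H[cn] + hs[n]                 # step 16: add x_n to its cluster
        G = G + g(H[cn])                      # step 15: add updated g(H_{c_n})
        c.append(cn)
    return torch.tensor(c)

# illustrative stand-in (untrained) networks: d_x=2, d_h=32, d_g=64, d_u=32
h, g, u, f = mlp(2, 32), mlp(32, 64), mlp(2, 32), mlp(64 + 32, 1)
print(ncp_sample(torch.randn(20, 2), h, g, u, f))
```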
### 3.2. Objective Function

In order to train the neural networks, we use stochastic gradient descent to minimize the expected KL divergence,

$$
E_{p(N)\, p(x)}\, \mathrm{KL}\big(p(c|x) \,\|\, q_\theta(c|x)\big)
= -E_{p(N)\, p(c_{1:N}, x)} \Big[ \sum_{n=2}^{N} \log q_\theta(c_n \mid c_{1:n-1}, x) \Big] + \text{const.}
\tag{11}
$$

Samples from $p(c_{1:N}, x)$ are obtained from the generative model, irrespective of whether the model is conjugate. In cases with unlimited samples (such as the 2D Gaussian example in Section 5 and the spike-sorting application in Section 8), we can potentially train a neural network to approximate $p(c_n|c_{1:n-1}, x)$ arbitrarily accurately. The objective function (11) can be seen as a form of Expectation Propagation (Minka, 2001), as opposed to variational inference, which would instead minimize $\mathrm{KL}(q_\theta(c|x) \,\|\, p(c|x))$. Note that the gradient acts only on the variable-input softmax $q_\theta$, not on $p(c, x)$, so there is no problem of backpropagating through discrete variables (Jang et al., 2016; Maddison et al., 2016).

## 4. Clusterwise Sampling

While the NCP algorithm is good enough for small datasets, $O(N)$ forward calls might be too many for large datasets. We consider now an $O(K)$ alternative, based on modeling the factors in the clusterwise expansion (4),

$$
p(s_{1:K}|x) = p(s_1|x)\, p(s_2|s_1, x) \cdots p(s_K|s_{1:K-1}, x).
\tag{12}
$$

Sampling from $p(s_k|s_{1:k-1}, x)$ can be done in two steps:

1. Sample uniformly an index $d_k$ from the set $I_k = \{1, \dots, N\} \setminus \{s_{1:k-1}\}$ of available indices (those not taken by $s_{1:k-1}$). The point $x_{d_k}$ becomes the first element of cluster $k$.

2. Denote by $a_k = (a_1, \dots, a_{m_k})$ the elements of the set of remaining indices $I_k \setminus \{d_k\}$, where $m_k = |I_k \setminus \{d_k\}|$. Conditioned on $(d_k, s_{1:k-1}, x)$, sample a binary vector $b_k = (b_1, \dots, b_{m_k}) \in \{0, 1\}^{m_k}$, with $b_i = 1$ if the point $x_{a_i}$ joins cluster $k$.

These two steps (see Figure 2 for an example) are iterated until there are no available indices left, and have probability

$$
p(d_k, b_k \mid s_{1:k-1}, x) = p(d_k \mid s_{1:k-1})\, p(b_k \mid d_k, s_{1:k-1}, x),
\tag{13}
$$

where $p(d_k|s_{1:k-1}) = 1/|I_k|$ for $d_k \in I_k$, $0$ for $d_k \notin I_k$, and $|I_k| = m_k + 1$.

Figure 2. Clusterwise sampling. Left: After sampling cluster $s_1$ (orange), the first element of $s_2$, $d_2$, is sampled uniformly (green). Middle: All unassigned points $a_2$ (grey) are candidates to join $d_2$. Right: By sampling $b_2$, cluster $s_2$ is completed.

The event indicated by $s_k$ is the union of $N_k$ disjoint events $(d_k, b_k)$, and we have

$$
p(s_k \mid s_{1:k-1}, x) = \frac{1}{|I_k|} \sum_{d_k \in s_k} p(b_k \mid d_k, s_{1:k-1}, x),
\tag{14}
$$

where $b_k$ has a 1 for each element in $s_k$ except $d_k$. Our major challenge is therefore to model the conditional $p(b_k|d_k, s_{1:k-1}, x)$, which we address next.

### 4.1. Factorized posterior

The information contained in $(d_k, s_{1:k-1}, x)$ is better represented by splitting the dataset as $x^k = (x_a, x_{d_k}, x_s)$, where

- $x_a = (x_{a_1}, \dots, x_{a_{m_k}})$: the $m_k$ available points for cluster $k$,
- $x_{d_k}$: the first data point in cluster $k$,
- $x_s = (x_{s_1}, \dots, x_{s_{k-1}})$: the points already assigned to clusters.

Thus $p(b_k|x^k) \equiv p(b_k|d_k, s_{1:k-1}, x)$. Note now that this factor has a form of conditional exchangeability,

$$
p(b_1, \dots, b_{m_k} \mid x_{a_1}, \dots, x_{a_{m_k}}, x_{d_k}, x_s)
= p(b_{\sigma_1}, \dots, b_{\sigma_{m_k}} \mid x_{a_{\sigma_1}}, \dots, x_{a_{\sigma_{m_k}}}, x_{d_k}, x_s),
$$

where $\sigma$ is an arbitrary permutation of the elements of $b_k$ and $x_a$. Based on this, we assume a conditional version of de Finetti's theorem and propose

$$
p(b_k|x^k) \simeq \int dz_k \prod_{i=1}^{m_k} p_i(b_i \mid z_k, x^k)\, p(z_k \mid x^k),
\tag{15}
$$

where we approximate the integrands as

$$
p_\theta(z_k \mid x^k) = \mathcal{N}(z_k \mid x^k),
\tag{16}
$$

$$
p_{\theta, i}(b_i \mid z_k, x^k) = \mathrm{sigmoid}\big[\rho_i(z_k, x^k)\big].
\tag{17}
$$

(More precisely, de Finetti's theorem (de Finetti, 1931; Hewitt & Savage, 1955) holds for infinite sequences. For finite sequences, as in our case, the result has been shown to hold only approximately and for discrete variables, both in the unconditional (Diaconis, 1977; Diaconis & Freedman, 1980) and conditional cases (Christandl & Toner, 2009).)

Crucially, the posterior distributions of the $b_i$'s are conditionally independent. Therefore, after sampling $p(z_k|x^k)$, all the $b_i$'s can be sampled in parallel. Thus, while a full sample of (12) of course has cost $O(N)$, the heaviest computational burden, from network evaluations, scales as $O(K)$, since each factor in (12) needs $O(1)$ forward calls. As in NCP, we can get hundreds of full samples via GPU parallelization.

To summarize, the elements of $s_k$ are generated in a process with latent variables $d_k, z_k$ and joint distribution

$$
p_\theta(s_k, z_k, d_k \mid s_{1:k-1}, x) = p_\theta(b_k \mid z_k, x^k)\, p_\theta(z_k \mid x^k)\, p(d_k \mid s_{1:k-1}),
\qquad
p_\theta(b_k \mid z_k, x^k) = \prod_{i=1}^{m_k} p_{\theta, i}(b_i \mid z_k, x^k).
\tag{18}
$$

In order to learn these functions, we introduce an encoder $q_\phi(z_k, d_k|s_{1:k}, x)$ to approximate the intractable posterior, and train the functions as a conditional variational autoencoder (VAE) (Sohn et al., 2015), as we condition everything on $x$. The dependence of all the functions on the components of $x$ should respect the symmetries imposed by the conditioning clusters $s_{1:k-1}$ (or $s_{1:k}$ for $q_\phi$); this can be achieved using encodings similar to those used above in Section 3 (see SM Section A for details).

Let us stress the double role of $p_\theta(z_k|x^k)\, p(d_k|s_{1:k-1})$ and $p_\theta(b_k|z_k, x^k)$. In the VAE framework, they are the prior and likelihood of a generative model for $s_k$. On the other hand, after marginalizing $d_k, z_k$ as in (14)-(15), they represent a factor of the posterior expansion (12). We call this approach the Clusterwise Clustering Process (CCP).
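Before turning to the training objective, here is a schematic, untrained PyTorch sketch of this clusterwise sampling loop: a uniformly chosen anchor $d_k$, a latent $z_k$ drawn from a Gaussian conditioned on a summary of $x^k$, and all membership bits $b_i$ sampled in parallel. The mean-pooled context, the network sizes, and the class name `CCPSampler` are our own simplifications; the paper's actual encodings are described in SM Section A.

```python
import torch
import torch.nn as nn

def mlp(d_in, d_out, d_hid=128):
    return nn.Sequential(nn.Linear(d_in, d_hid), nn.ReLU(), nn.Linear(d_hid, d_out))

class CCPSampler(nn.Module):
    """Illustrative clusterwise sampler: anchor d_k uniform, then all b_i in parallel."""
    def __init__(self, d_x=2, d_e=64, d_z=32):
        super().__init__()
        self.enc = mlp(d_x, d_e)                 # shared point encoder
        self.prior = mlp(3 * d_e, 2 * d_z)       # -> mean and log-variance of z_k
        self.rho = mlp(d_e + 3 * d_e + d_z, 1)   # per-point membership logit, eq. (17)
        self.d_z = d_z

    @torch.no_grad()
    def forward(self, x):
        N = x.shape[0]
        e = self.enc(x)                                        # (N, d_e)
        avail, labels, k = list(range(N)), torch.full((N,), -1, dtype=torch.long), 0
        assigned_sum = torch.zeros_like(e[0])
        while avail:
            d = avail[int(torch.randint(len(avail), (1,)))]    # step 1: uniform anchor d_k
            rest = [i for i in avail if i != d]
            # crude symmetric context for x^k = (available points, anchor, assigned points)
            ctx = torch.cat([e[rest].mean(0) if rest else torch.zeros_like(e[0]),
                             e[d], assigned_sum])
            mean, logvar = self.prior(ctx).chunk(2)
            z = mean + torch.exp(0.5 * logvar) * torch.randn(self.d_z)   # z_k ~ p(z_k|x^k)
            members = [d]
            if rest:
                logits = self.rho(torch.cat([e[rest],
                                             ctx.expand(len(rest), -1),
                                             z.expand(len(rest), -1)], dim=1)).squeeze(-1)
                b = torch.bernoulli(torch.sigmoid(logits))     # step 2: all b_i in parallel
                members += [i for i, bi in zip(rest, b.tolist()) if bi == 1.0]
            labels[members] = k
            assigned_sum = assigned_sum + e[members].sum(0)
            avail = [i for i in avail if i not in members]
            k += 1
        return labels

sampler = CCPSampler()
print(sampler(torch.randn(15, 2)))
```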
### 4.2. Objective Function

Similar to the NCP case in (11), we want an approximation $p_\theta(s_{1:K}|x)$ to $p(s_{1:K}|x)$ that minimizes

$$
E_{p(x)}\, \mathrm{KL}\big[p(s_{1:K}|x) \,\|\, p_\theta(s_{1:K}|x)\big]
= -E_{p(x, s_{1:K})} \sum_{k=1}^{K} \log p_\theta(s_k \mid s_{1:k-1}, x) + \text{const.},
\tag{19}
$$

where we expanded $p_\theta(s_{1:K}|x)$ as in (12). Using now the variational posterior $q_\phi$, we can bound the expected log-likelihood in (19) from below, which leads us to maximize the ELBO

$$
\sum_{k=1}^{K} E_{q_\phi(z_k, d_k | s_{1:k}, x)} \log \frac{p_\theta(s_k, z_k, d_k \mid s_{1:k-1}, x)}{q_\phi(z_k, d_k \mid s_{1:k}, x)}.
$$

To use the reparametrization trick (Kingma & Welling, 2013), we use a Gumbel-Softmax relaxation for $d_k$ (Jang et al., 2016; Maddison et al., 2016); see SM Section A.

### 4.3. Estimating sample probabilities

Unlike NCP, CCP samples do not come with a probability estimate. The latter can be estimated using (12) and

$$
p(b_k|x^k) \simeq \frac{1}{M} \sum_{j=1}^{M} p_\theta(b_k \mid z_{k,j}, x^k),
\qquad z_{k,j} \sim p_\theta(z_k \mid x^k).
\tag{20}
$$

## 5. Examples

**2D Gaussian models:** The generative model is

$$
\begin{aligned}
\alpha &\sim \mathrm{Exp}(1), \qquad N \sim \mathrm{Uniform}[5, 100], \\
c_{1:N} &\sim \mathrm{CRP}(\alpha), \\
\mu_k &\sim \mathcal{N}(0, \sigma_\mu^2 \mathbf{1}_2), \\
x_i &\sim \mathcal{N}(\mu_{c_i}, \sigma^2 \mathbf{1}_2),
\end{aligned}
$$

where CRP stands for the Chinese Restaurant Process with concentration parameter $\alpha$, $\sigma_\mu = 10$, and $\sigma = 1$.

Figure 3. Mixture of 2D Gaussians. Given the observations in the first panel, we show samples from the NCP posterior. Note that less-reasonable samples are assigned lower probability by the NCP. The dotted ellipses indicate departures from the first, highest-probability sample. Our GPU implementation gives thousands of samples in less than a second. CCP results are similar.

Figure 3 shows that the NCP captures the posterior uncertainty inherent in clustering this data. Since we have unlimited samples, there is no distinction here between training and test sets.

**MNIST digits:** We consider next a DPMM over MNIST digits, with generative model

$$
\begin{aligned}
\alpha &\sim \mathrm{Exp}(1), \qquad N \sim \mathrm{Uniform}[5, 100], \\
c_{1:N} &\sim \mathrm{CRP}_{10}(\alpha), \\
l_k &\sim \mathrm{Unif}[0, 9] \text{ without replacement}, \qquad k = 1, \dots, K, \\
x_i &\sim \mathrm{Unif}\big[\text{MNIST digits with label } l_{c_i}\big], \qquad i = 1, \dots, N,
\end{aligned}
$$

where $\mathrm{CRP}_{10}$ is a Chinese Restaurant Process truncated to at most 10 clusters, and $d_x = 28 \times 28$. Training was performed by sampling $x_i$ from the MNIST training set.
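Concretely, training in both examples minimizes the expected negative log-likelihood (11) by teacher forcing on labeled datasets streamed from the generative model. Below is a minimal, unbatched PyTorch sketch of that loss and a single optimizer step on a toy 2D dataset; the stand-in MLPs, sizes, and the toy data are our own choices, not the architectures of SM Section G, and labels are assumed to be numbered by order of first appearance.

```python
import torch
import torch.nn as nn

def mlp(d_in, d_out, d_hid=128):
    return nn.Sequential(nn.Linear(d_in, d_hid), nn.ReLU(), nn.Linear(d_hid, d_out))

def ncp_nll(x, c, h, g, u, f):
    """Teacher-forced loss -sum_n log q_theta(c_n | c_{1:n-1}, x), as in eq. (11).
    x: (N, d_x); c: (N,) ground-truth labels, numbered by first appearance."""
    N = x.shape[0]
    hs, us = h(x), u(x)
    H, G, U = [hs[0].clone()], g(hs[0]), us[1:].sum(0)
    nll = x.new_zeros(())
    for n in range(1, N):
        U = U - us[n]
        K = len(H)
        logits = []
        for k in range(K + 1):       # one candidate logit per existing cluster + "new"
            Hk = H[k] if k < K else torch.zeros_like(hs[n])
            gHk = g(Hk) if k < K else torch.zeros_like(G)     # convention g(0) := 0
            logits.append(f(torch.cat([G + g(Hk + hs[n]) - gHk, U])))
        logp = torch.log_softmax(torch.stack(logits).squeeze(-1), dim=0)
        cn = int(c[n])
        nll = nll - logp[cn]         # teacher forcing with the true label
        if cn == K:
            H.append(torch.zeros_like(hs[n]))
        else:
            G = G - g(H[cn])
        H[cn] = H[cn] + hs[n]
        G = G + g(H[cn])
    return nll

# one training step on a toy labeled dataset (x, c)
h, g, u, f = mlp(2, 32), mlp(32, 64), mlp(2, 32), mlp(96, 1)
opt = torch.optim.Adam([p for m in (h, g, u, f) for p in m.parameters()], lr=1e-4)
x = torch.cat([torch.randn(10, 2) - 5.0, torch.randn(10, 2) + 5.0])
c = torch.cat([torch.zeros(10, dtype=torch.long), torch.ones(10, dtype=torch.long)])
loss = ncp_nll(x, c, h, g, u, f)
loss.backward()
opt.step()
print(float(loss))
```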
Figure 4. NCP trained on MNIST clusters. Top row: 20 images from the MNIST test set. Below: five samples of $c_{1:20}$ from the NCP posterior. Note that each sample captures some ambiguity suggested by the form of particular digits. CCP results are similar.

Figure 4 shows posterior samples for a set of digits from the MNIST test set, illustrating how the estimated model correctly captures the shape ambiguity of some of the digits. Note that in this case the generative model has no analytical expression, but this presents no problem: a set of labeled samples is all we need for training. See SM Section G for details of the network architectures used.

## 6. Related works

Most works on neural network-based clustering focus on learning features as inputs to traditional clustering algorithms, as reviewed in (Du, 2010; Aljalbout et al., 2018; Min et al., 2018). Our approach differs from these works because it leverages deep learning to improve algorithmic aspects of clustering, via amortization.

Permutation-invariant neural architectures have been explored recently in (Ravanbakhsh et al., 2017; Korshunova et al., 2018; Lee et al., 2018; Bloem-Reddy & Teh, 2019; Wagstaff et al., 2019). The representation of a set via a sum (or mean) of encoding vectors was also used in (Guttenberg et al., 2016; Ravanbakhsh et al., 2016; Edwards & Storkey, 2017; Zaheer et al., 2017; Garnelo et al., 2018a; Kim et al., 2019).

A conditional form of de Finetti's theorem was also assumed for Neural Processes (NP) (Garnelo et al., 2018b), but it differs from our assumed form in (15): our prior $p_\theta(z_k|x^k)$ depends symmetrically on the available points $x_a$, in order to keep the correct dependency of the marginal $p(c_{1:n}, x)$ on all the $N$ components of $x$, while for NPs the prior is independent of the available data points.

Amortized inference of Gaussian mixtures has been studied recently in (Le et al., 2016; Lee et al., 2018; Kalra et al., 2019). In these works the network outputs the mixture parameters instead of sampled discrete labels, and the number of components is either bounded (Le et al., 2016) or fixed (Lee et al., 2018; Kalra et al., 2019). Closer to our CCP is the DAC approach (Lee et al., 2019), which uses the set attention mechanism of (Lee et al., 2018) in the encoder to iteratively isolate and eliminate one cluster per iteration, in $O(K)$ network evaluations. But the clusters have no clear interpretation in terms of the generative model, as they come from hard thresholding of sigmoids, and the eliminated clusters do not appear as a conditioning context to find new clusters. We summarize these comparisons in Table 1.

| Property | CCP | NCP | DAC | MoG |
|---|---|---|---|---|
| Unlimited components | ✓ | ✓ | ✓ | ✗ |
| Amortized labels | ✓ | ✓ | ✓ | ✗ |
| Any generative model | ✓ | ✓ | ✗ | ✗ |
| Well-defined posterior | ✓ | ✓ | – | ✗ |
| Forward passes | O(K) | O(N) | O(K) | O(1) |

Table 1. Comparing amortized clustering approaches. We compare NCP/CCP (our methods) with DAC (Lee et al., 2019) and amortization for mixtures of Gaussians (MoG) (Le et al., 2016; Lee et al., 2018; Kalra et al., 2019).

## 7. Evaluations and diagnostics

The examples in Section 5 provide strong qualitative evidence that our approximations to the true posteriors in these models capture the uncertainty inherent in the observed data. But we would like to go further and ask quantitatively how well our approximations match the exact posterior. Unfortunately, for sample sizes much larger than $N = O(10)$ it is impossible to compute the exact posterior in these models. Nonetheless, there are several quantitative metrics we can examine to check the accuracy of the model output. Note that the diagnostics below that rely on the probabilistic nature of the inferred clusters are not applicable to the other methods compared in Table 1.
Figure 5. Quantitative evaluations. Upper left: Two clusters of 20 points each, and a line over possible locations of a 41st, last point. Upper right: Assuming the 2D Gaussian model from Section 5, the posterior $p(c_{41}|c_{1:40}, x)$ can be computed exactly, and we compare it to the NCP estimate as a function of the horizontal coordinate of $x_{41}$ as this point moves over the gray line in the upper left panel. Geweke's tests. Lower left: The curves compare the exact mean (± std.) of the number of clusters $K$ for different $N$'s from the CRP prior ($\alpha = 0.7$) with CCP sampled estimates using eq. (21). Lower right: Similar comparison for the histogram of $K$ for $N = 30$ points.

**NCP vs. CCP:** The results from the two approaches were similar in all the examples we considered, such as those in Section 5. Training CCP, however, presents the usual challenges of VAEs. We found it useful to use multiple-sample objectives (Burda et al., 2015) and to estimate the gradient using double reparametrization (Tucker et al., 2019).

**Global symmetry from exchangeability:** From the exchangeability of $p(c_{1:N}|\alpha_1)$, the expansion (3) should not depend on the order of the data points, but this symmetry is not enforced explicitly. If our model learns the conditional probabilities correctly, this symmetry should be (approximately) satisfied, as we show in SM Section C.

**Estimated vs. analytical probabilities:** Some conditional probabilities can be computed analytically and compared with the estimates output by the network; in the example shown in Figure 5, upper right, the estimated probabilities are in close agreement with their exact values.

**Geweke's tests:** A popular family of tests that check the correctness of MCMC implementations (Geweke, 2004) can also be applied in our case: verify the (approximate) identity between the prior $p(c_{1:N})$ and

$$
q_\theta(c_{1:N}) \equiv \int dx\, q_\theta(c_{1:N}|x)\, p(x),
\tag{21}
$$

where $p(x)$ is the marginal from the generative model. Figure 5 shows such a comparison for the 2D Gaussian DPMM from Section 5, showing excellent agreement.

**Comparison with MCMC:** NCP/CCP have some advantages over MCMC approaches. First, unlike MCMC, we get a probability estimate for each sample, either directly (NCP) or with minimal computation (CCP). Secondly, NCP/CCP enjoy higher efficiency, due to parallelization of iid samples. For example, in the 2D Gaussian example of Section 5, in the time a collapsed Gibbs sampler produces one (correlated) sample, our GPU-based NCP implementation produces more than 100 iid approximate samples. Finally, NCP/CCP do not need a burn-in period.

**Comparison with variational inference:** Below we compare NCP with a variational approach on spike sorting. For 2000 spikes, the latter returned one clustering estimate in 0.76 seconds, but it does not properly handle the uncertainty about the number of clusters; NCP produced 150 clustering configurations in 10 seconds, efficiently capturing clustering uncertainty. In addition, the variational approach requires a preprocessing step that projects the samples to lower dimensions, whereas NCP directly consumes the high-dimensional data by learning an encoder function $h$.
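A harness for the Geweke-style check of eq. (21) is easy to script. The NumPy sketch below compares the number of clusters K under the CRP prior with K under labels drawn from an approximate sampler applied to data from the generative model, loosely mirroring the lower panels of Figure 5. The callable `sample_labels` is a placeholder for a trained NCP/CCP sampler (e.g., a wrapper around the `ncp_sample` sketch of Section 3); all other names and defaults are our own.

```python
import numpy as np

def crp_labels(N, alpha, rng):
    """Sample labels c_1..c_N from a CRP(alpha) prior."""
    counts, labels = [], []
    for _ in range(N):
        p = np.array(counts + [alpha], dtype=float)
        k = rng.choice(len(p), p=p / p.sum())
        if k == len(counts):
            counts.append(0)
        counts[k] += 1
        labels.append(k)
    return np.array(labels)

def geweke_check(sample_labels, N=30, alpha=0.7, sigma_mu=10.0, sigma=1.0,
                 n_trials=200, seed=0):
    """Compare E[K] under the CRP prior with E[K] under c ~ q_theta(c|x), x ~ p(x),
    i.e. under the aggregated posterior of eq. (21)."""
    rng = np.random.default_rng(seed)
    prior_K, agg_K = [], []
    for _ in range(n_trials):
        prior_K.append(len(np.unique(crp_labels(N, alpha, rng))))
        c_true = crp_labels(N, alpha, rng)            # draw x ~ p(x), then discard labels
        mu = rng.normal(0.0, sigma_mu, size=(c_true.max() + 1, 2))
        x = mu[c_true] + rng.normal(0.0, sigma, size=(N, 2))
        agg_K.append(len(np.unique(sample_labels(x))))
    return float(np.mean(prior_K)), float(np.mean(agg_K))

if __name__ == "__main__":
    # sanity check with an oracle that simply resamples from the prior:
    # the two means should then agree up to Monte Carlo error.
    oracle = lambda x: crp_labels(len(x), 0.7, np.random.default_rng())
    print(geweke_check(oracle))
```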
## 8. Application: spike sorting with NCP

Large-scale neural population recordings using multielectrode arrays (MEAs) are crucial for understanding neural circuit dynamics. Each MEA electrode reads the signals from many neurons, and each neuron is recorded by multiple nearby electrodes. As a key analysis step, spike sorting converts the raw signal into a set of spike trains belonging to individual neurons (Pachitariu et al., 2016; Chung et al., 2017; Jun et al., 2017; Lee et al., 2017; Chaure et al., 2018; Carlson & Carin, 2019). At the core of many spike sorting pipelines is a clustering algorithm that groups the detected spikes into clusters, each representing a putative neuron (Figure 6). However, clustering spikes can be challenging: (1) spike waveforms form highly non-Gaussian clusters in spatial and temporal dimensions, and it is unclear what the optimal features for clustering are; (2) it is unknown a priori how many clusters there are; (3) existing methods do not perform well on spikes with low signal-to-noise ratios (SNR) due to increased clustering uncertainty, and fully Bayesian approaches proposed to handle this uncertainty (Wood & Black, 2008; Carlson et al., 2013) do not scale to large datasets.

Figure 6. Clustering multi-channel spike waveforms using NCP. Each row is an electrode channel; the panels show the observations, the clusters, the cluster averages, and their overlay. Spikes of the same color belong to the same cluster. (Scale bar: 5 noise s.d.)

To address these challenges, we propose a novel approach to spike clustering using NCP. We consider the spike waveforms as generated from a Mixture of Finite Mixtures (MFM) distribution (Miller & Harrison, 2018), which can be effectively modeled by NCP. (1) Rather than selecting arbitrary features for clustering, the spike waveforms are encoded with a convolutional neural network (ConvNet), which is learned end-to-end jointly with the NCP network to ensure optimal feature encoding. (2) Using the variable-input softmax function, NCP is able to perform inference on cluster labels without assuming a fixed or maximum number of clusters. (3) NCP allows for efficient probabilistic clustering by GPU-parallelized posterior sampling, which is particularly useful for handling the clustering uncertainty of ambiguous small spikes. (4) The computational cost of NCP training can be highly amortized, since neuroscientists often sort spikes from many statistically similar datasets.

We trained NCP for spike clustering using synthetic spikes from a simple yet effective generative model that mimics the distribution of real spikes, and evaluated the spike sorting performance on labeled synthetic data, unlabeled real data, and hybrid test data by comparing NCP against two other methods: (1) vGMFM, variational inference on a Gaussian MFM (Hughes & Sudderth, 2013); (2) Kilosort, a state-of-the-art spike sorting pipeline described in (Pachitariu et al., 2016). In Supplementary Material (SM) Section D, we describe the dataset, the neural architecture, and the training/inference pipeline of NCP spike sorting. In SM Section E, we show that NCP spike sorting achieves high clustering quality, and matches or outperforms a state-of-the-art method on synthetic, real and hybrid data.

**Probabilistic clustering of ambiguous small spikes.** Sorting small spikes has been challenging due to the low SNR and increased uncertainty of cluster assignment. By efficient GPU-parallelized posterior sampling of cluster labels, NCP is able to handle the clustering uncertainty by producing multiple plausible clustering configurations. Figure 7 shows examples where NCP separates spike clusters with amplitude as low as 3-4 times the standard deviation of the noise into plausible units that are not mere scaled versions of each other but have distinct shapes on different channels.

Figure 7. Clustering ambiguous small spikes. In both examples, multiple plausible clustering configurations of the small spikes (e.g., one cluster vs. two clusters) were produced by sampling from the NCP posterior (scale bar = 5 noise s.d.).
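In this setting, the encoder $h$ of Section 3 must map a multichannel waveform snippet to a feature vector before the NCP machinery takes over; the architecture actually used is described in SM Section D. Purely as a rough illustration of that interface, a 1D ConvNet over time with channel mixing and temporal pooling could look like the following sketch, where all layer sizes and the class name `WaveformEncoder` are our own placeholder choices.

```python
import torch
import torch.nn as nn

class WaveformEncoder(nn.Module):
    """Illustrative h(x): maps a spike snippet of shape (n_channels, n_timesteps)
    to a d_h-dimensional feature vector usable by an NCP-style sampler."""
    def __init__(self, n_channels=7, n_timesteps=32, d_h=128):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(n_channels, 32, kernel_size=5, padding=2), nn.ReLU(),
            nn.Conv1d(32, 64, kernel_size=5, padding=2), nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),              # pool over the time axis
        )
        self.out = nn.Linear(64, d_h)

    def forward(self, x):
        # x: (batch, n_channels, n_timesteps) -> (batch, d_h)
        return self.out(self.conv(x).squeeze(-1))

# usage: encode a batch of 100 detected spikes, 7 channels x 32 time samples each
h = WaveformEncoder()
features = h(torch.randn(100, 7, 32))
print(features.shape)                             # torch.Size([100, 128])
```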
## 9. Conclusion

We introduced neural architectures to amortize posterior sampling of generative clustering models in O(N) or O(K) forward passes. The performance is excellent in simple examples. In a realistic spike-sorting application, our results show that NCP spike sorting provides high clustering quality, matches or outperforms a state-of-the-art method, and handles clustering uncertainty by efficient posterior sampling (a task that is not solved by currently available methods), demonstrating substantial promise for incorporating these methods into production-scale pipelines.

## Acknowledgements

We thank Sean Bittner, Alessandro Ingrosso, Scott Linderman, Aaron Schein and Ruoxi Sun for helpful conversations. This work was supported by the Simons Foundation, the DARPA NESD program, ONR N00014-17-1-2843, NIH/NIBIB R01 EB22913, NSF NeuroNex Award DBI-1707398 and The Gatsby Charitable Foundation.

## References

Aljalbout, E., Golkov, V., Siddiqui, Y., and Cremers, D. Clustering with deep learning: Taxonomy and new methods. arXiv preprint arXiv:1801.07648, 2018.

Blei, D. M. and Jordan, M. I. Variational methods for the Dirichlet process. In Proceedings of the Twenty-First International Conference on Machine Learning (ICML), 2004.

Bloem-Reddy, B. and Teh, Y. W. Probabilistic symmetry and invariant neural networks. arXiv preprint arXiv:1901.06082, 2019.

Burda, Y., Grosse, R., and Salakhutdinov, R. Importance weighted autoencoders. arXiv preprint arXiv:1509.00519, 2015.

Calabrese, A. and Paninski, L. Kalman filter mixture model for spike sorting of non-stationary data. Journal of Neuroscience Methods, 196(1):159–169, 2011.

Carlson, D. and Carin, L. Continuing progress of spike sorting in the era of big data. Current Opinion in Neurobiology, 55:90–96, 2019.

Carlson, D. E., Vogelstein, J. T., Wu, Q., Lian, W., Zhou, M., Stoetzner, C. R., Kipke, D., Weber, D., Dunson, D. B., and Carin, L. Multichannel electrophysiological spike sorting via joint dictionary learning and mixture modeling. IEEE Transactions on Biomedical Engineering, 61(1):41–54, 2013.

Chaure, F. J., Rey, H. G., and Quian Quiroga, R. A novel and fully automatic spike-sorting implementation with variable number of features. Journal of Neurophysiology, 120(4):1859–1871, 2018. doi: 10.1152/jn.00339.2018.

Chichilnisky, E. J. and Kalmar, R. S. Functional asymmetries in ON and OFF ganglion cells of primate retina. Journal of Neuroscience, 22(7):2737–2747, 2002. doi: 10.1523/JNEUROSCI.22-07-02737.2002. URL http://www.jneurosci.org/content/22/7/2737.

Christandl, M. and Toner, B. Finite de Finetti theorem for conditional probability distributions describing physical theories. Journal of Mathematical Physics, 50(4):042104, 2009.

Chung, J. E., Magland, J. F., Barnett, A. H., Tolosa, V. M., Tooker, A. C., Lee, K. Y., Shah, K. G., Felix, S. H., Frank, L. M., and Greengard, L. F. A fully automated approach to spike sorting. Neuron, 95(6):1381–1394, 2017.

de Finetti, B. Funzione caratteristica di un fenomeno aleatorio. Atti della R. Accademia Nazionale dei Lincei, Serie 6, Memorie, Classe di Scienze Fisiche, Matematiche e Naturali, 4, pp. 251–299, 1931.

Diaconis, P. Finite forms of de Finetti's theorem on exchangeability. Synthese, 36(2):271–281, 1977.
Diaconis, P. and Freedman, D. Finite exchangeable sequences. The Annals of Probability, pp. 745–764, 1980.

Du, K.-L. Clustering: A neural network approach. Neural Networks, 23(1):89–107, 2010.

Edwards, H. and Storkey, A. Towards a neural statistician. In ICLR, 2017.

Garnelo, M., Rosenbaum, D., Maddison, C. J., Ramalho, T., Saxton, D., Shanahan, M., Teh, Y. W., Rezende, D. J., and Eslami, S. Conditional neural processes. In International Conference on Machine Learning, 2018a.

Garnelo, M., Schwarz, J., Rosenbaum, D., Viola, F., Rezende, D. J., Eslami, S., and Teh, Y. W. Neural processes. In ICML 2018 Workshop on Theoretical Foundations and Applications of Deep Generative Models, 2018b.

Gershman, S. and Goodman, N. Amortized inference in probabilistic reasoning. In Proceedings of the Annual Meeting of the Cognitive Science Society, volume 36, 2014.

Geweke, J. Getting it right: Joint distribution tests of posterior simulators. Journal of the American Statistical Association, 99(467):799–804, 2004.

Graves, A. Sequence transduction with recurrent neural networks. CoRR, abs/1211.3711, 2012.

Gui, S., Zhang, X., Zhong, P., Qiu, S., Wu, M., Ye, J., Wang, Z., and Liu, J. PINE: Universal deep embedding for graph nodes via partial permutation invariant set functions. arXiv preprint arXiv:1909.12903, 2019.

Guttenberg, N., Virgo, N., Witkowski, O., Aoki, H., and Kanai, R. Permutation-equivariant neural networks applied to dynamics prediction. arXiv preprint arXiv:1612.04530, 2016.

He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. In IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2016.

Hewitt, E. and Savage, L. J. Symmetric measures on Cartesian products. Transactions of the American Mathematical Society, 80(2):470–501, 1955.

Hinton, G. E., Dayan, P., Frey, B. J., and Neal, R. M. The wake-sleep algorithm for unsupervised neural networks. Science, 268(5214):1158–1161, 1995.

Hughes, M., Kim, D. I., and Sudderth, E. Reliable and scalable variational inference for the hierarchical Dirichlet process. In Artificial Intelligence and Statistics, pp. 370–378, 2015.

Hughes, M. C. and Sudderth, E. Memoized online variational inference for Dirichlet process mixture models. In Advances in Neural Information Processing Systems 26, pp. 1133–1141, 2013.

Jain, S. and Neal, R. M. A split-merge Markov chain Monte Carlo procedure for the Dirichlet process mixture model. Journal of Computational and Graphical Statistics, 13(1):158–182, 2004.

Jang, E., Gu, S., and Poole, B. Categorical reparameterization with Gumbel-softmax. arXiv preprint arXiv:1611.01144, 2016.

Jun, J. J., Mitelut, C., Lai, C., Gratiy, S. L., Anastassiou, C. A., and Harris, T. D. Real-time spike sorting platform for high-density extracellular probes with ground-truth validation and drift correction. bioRxiv, 2017.

Kalra, S., Adnan, M., Taylor, G., and Tizhoosh, H. Learning permutation invariant representations using memory networks. arXiv preprint arXiv:1911.07984, 2019.

Kim, H., Mnih, A., Schwarz, J., Garnelo, M., Eslami, A., Rosenbaum, D., Vinyals, O., and Teh, Y. W. Attentive neural processes. In ICLR, 2019.

Kingma, D. P. and Ba, J. Adam: A method for stochastic optimization. In ICLR, 2015.

Kingma, D. P. and Welling, M. Auto-encoding variational Bayes. arXiv preprint arXiv:1312.6114, 2013.

Korshunova, I., Degrave, J., Huszár, F., Gal, Y., Gretton, A., and Dambre, J. BRUNO: A deep recurrent model for exchangeable data. In Advances in Neural Information Processing Systems 31, 2018.
Kurihara, K., Welling, M., and Teh, Y. W. Collapsed variational Dirichlet process mixture models. In IJCAI, volume 7, pp. 2796–2801, 2007.

Le, T. A., Baydin, A. G., and Wood, F. Inference compilation and universal probabilistic programming. arXiv preprint arXiv:1610.09900, 2016.

Lee, J., Lee, Y., Kim, J., Kosiorek, A. R., Choi, S., and Teh, Y. W. Set transformer. arXiv preprint arXiv:1810.00825, 2018.

Lee, J., Lee, Y., and Teh, Y. W. Deep amortized clustering. arXiv preprint arXiv:1909.13433, 2019.

Lee, J. H., Carlson, D. E., Razaghi, H. S., Yao, W., Goetz, G. A., Hagen, E., Batty, E., Chichilnisky, E., Einevoll, G. T., and Paninski, L. YASS: Yet another spike sorter. In Advances in Neural Information Processing Systems, pp. 4002–4012, 2017.

Maddison, C. J., Mnih, A., and Teh, Y. W. The concrete distribution: A continuous relaxation of discrete random variables. arXiv preprint arXiv:1611.00712, 2016.

McLachlan, G. J. and Basford, K. E. Mixture Models: Inference and Applications to Clustering, volume 84. Marcel Dekker, 1988.

Miller, J. W. and Harrison, M. T. Mixture models with a prior on the number of components. Journal of the American Statistical Association, 113(521):340–356, 2018.

Min, E., Guo, X., Liu, Q., Zhang, G., Cui, J., and Long, J. A survey of clustering with deep learning: From the perspective of network architecture. IEEE Access, 6:39501–39514, 2018.

Minka, T. P. Expectation propagation for approximate Bayesian inference. In Proceedings of the Seventeenth Conference on Uncertainty in Artificial Intelligence, pp. 362–369. Morgan Kaufmann Publishers Inc., 2001.

Neal, R. M. Markov chain sampling methods for Dirichlet process mixture models. Journal of Computational and Graphical Statistics, 9(2):249–265, 2000.

Pachitariu, M. Kilosort2. https://github.com/MouseLand/Kilosort2, 2019.

Pachitariu, M., Steinmetz, N., Kadir, S., Carandini, M., and Harris, K. D. Kilosort: realtime spike-sorting for extracellular electrophysiology with hundreds of channels. bioRxiv, pp. 061481, 2016.

Paige, B. and Wood, F. Inference networks for sequential Monte Carlo in graphical models. In International Conference on Machine Learning, pp. 3040–3049, 2016.

Pakman, A. and Paninski, L. Amortized Bayesian inference for clustering models. In BNP@NeurIPS 2018 Workshop: All of Bayesian Nonparametrics, 2018.

Pakman, A., Wang, Y., and Paninski, L. Neural permutation processes. In Symposium on Advances in Approximate Bayesian Inference, pp. 1–7, 2019.

Parthasarathy, N., Batty, E., Falcon, W., Rutten, T., Rajpal, M., Chichilnisky, E., and Paninski, L. Neural networks for efficient Bayesian decoding of natural images from retinal neurons. In Advances in Neural Information Processing Systems 30, pp. 6434–6445, 2017.

Ravanbakhsh, S., Schneider, J., and Póczos, B. Deep learning with sets and point clouds. arXiv preprint arXiv:1611.04500, 2016.

Ravanbakhsh, S., Schneider, J., and Póczos, B. Equivariance through parameter-sharing. In Proceedings of the 34th International Conference on Machine Learning, 2017.

Ritchie, D., Horsfall, P., and Goodman, N. D. Deep amortized inference for probabilistic programs. arXiv preprint arXiv:1610.05735, 2016.

Rodriguez, A. and Mueller, P. Nonparametric Bayesian Inference. NSF-CBMS Regional Conference Series in Probability and Statistics, 9:i–110, 2013.

Shan, K. Q., Lubenov, E. V., and Siapas, A. G. Model-based spike sorting with a mixture of drifting t-distributions. Journal of Neuroscience Methods, 288:82–98, 2017.
Sohn, K., Lee, H., and Yan, X. Learning structured output representation using deep conditional generative models. In Advances in Neural Information Processing Systems, pp. 3483–3491, 2015.

Stuhlmüller, A., Taylor, J., and Goodman, N. Learning stochastic inverses. In Advances in Neural Information Processing Systems, pp. 3048–3056, 2013.

Sun, R. and Paninski, L. Scalable approximate Bayesian inference for particle tracking data. In Proceedings of the 35th International Conference on Machine Learning, 2018.

Sutskever, I., Vinyals, O., and Le, Q. V. Sequence to sequence learning with neural networks. In NIPS, 2014.

Tucker, G., Lawson, D., Gu, S., and Maddison, C. J. Doubly reparameterized gradient estimators for Monte Carlo objectives. In International Conference on Learning Representations, 2019. URL https://openreview.net/forum?id=HkG3e205K7.

Vinh, N. X., Epps, J., and Bailey, J. Information theoretic measures for clusterings comparison: Variants, properties, normalization and correction for chance. Journal of Machine Learning Research, 11(Oct):2837–2854, 2010.

Wagstaff, E., Fuchs, F. B., Engelcke, M., Posner, I., and Osborne, M. On the limitations of representing functions on sets. arXiv preprint arXiv:1901.09006, 2019.

Wang, Y., Pakman, A., Mitelut, C., Lee, J., and Paninski, L. Spike sorting using the neural clustering process. In Real Neurons & Hidden Units Workshop @ NeurIPS 2019, 2019.

Wood, F. and Black, M. J. A nonparametric Bayesian alternative to spike sorting. Journal of Neuroscience Methods, 173(1):1–12, 2008.

Zaheer, M., Kottur, S., Ravanbakhsh, S., Póczos, B., Salakhutdinov, R., and Smola, A. J. Deep sets. In Advances in Neural Information Processing Systems, 2017.