# Learning Partitions from Context

Simon Buchholz
Department for Empirical Inference, Max Planck Institute for Intelligent Systems
Tübingen AI Center, Tübingen, Germany
sbuchholz@tue.mpg.de

## Abstract

In this paper, we study the problem of learning the structure of a discrete set of $N$ tokens based on their interactions with other tokens. We focus on a setting where the tokens can be partitioned into a small number of classes, and there exists a real-valued function $f$ defined on certain sets of tokens. This function, which captures the interactions between tokens, depends only on the class memberships of its arguments. The goal is to recover the class memberships of all tokens from a finite number of samples of $f$. We begin by analyzing this problem from both complexity-theoretic and information-theoretic viewpoints. We prove that it is NP-complete in general, and for random instances, we show that on the order of $N \ln(N)$ samples, implying very sparse interactions, suffice to identify the partition. We then investigate the conditions under which gradient flow dynamics of token embeddings can reveal the class structure, finding that this is achievable in certain settings when given on the order of $N^2 \ln^2(N)$ samples.

## 1 Introduction

Modern machine learning systems are able to learn extremely complicated relations from data. They often rely on learned embeddings of discrete tokens in a continuous space. This is notably true for Large Language Models (LLMs) [9, 17, 25, 23], which encode their input by converting text into a sequence of discrete tokens. These tokens are embedded in a high-dimensional embedding space, and the embeddings are fed into, e.g., a transformer architecture [28] that predicts the next token. In other domains as well, e.g., in vision, discrete embeddings are frequently used as a component of deep learning architectures [27, 10], as this enables capturing complex concepts that are often discrete.
38th Conference on Neural Information Processing Systems (NeurIPS 2024).

It was observed that after training, these word embeddings exhibit many interesting structures. The most prominent example is probably the observation that the difference of the Word2Vec embedding vectors of the nouns king and queen approximately equals the difference of the embeddings of man and woman [21, 22]. Similarly, it was found that using word similarity as an inductive bias to structure latent spaces helps downstream performance [4]. Thus, a properly structured latent space seems to be an important ingredient to capture the intricate correlations in complex data.

A proper theoretical understanding of such complex models currently remains an elusive goal. However, there have been various attempts to understand individual components of deep learning models. Many works investigated the behavior of feedforward networks, in particular focusing on shallow networks [2, 8] and asymptotic regimes [15, 29]. More recently, several works investigated transformer architectures (often focusing on the one-layer case with a linearized attention mechanism) [16, 1, 30, 12]. On a technical level, [11] is closely related, as they study the large-depth limit of transformers through the lens of particle systems (however, in their case time corresponds to depth, while in our case it corresponds to training time). A feature shared by many of those works is that little structure is assumed on the input, i.e., fixed token embeddings are assumed, and often those are even assumed to be isotropic Gaussian. Here we instead focus on the dynamics of the token embeddings and study how these can recover structure present in the data. There is no ground truth target for the embeddings of the large-scale models used in deep learning, and their embeddings need to capture a variety of nuanced correlations and relations that are hard to formalize.
Therefore, we focus on a simplified problem that nevertheless shares important features with more complex real-world settings. One general heuristic is that the embeddings contain information about the similarity of tokens. We focus here on the strongest form of similarity, namely equality. Indeed, our central assumption is that tokens can be clustered into a small number of groups such that tokens within a cluster behave exactly the same, i.e., they interact in the same manner with other tokens. The central question is then which assumptions allow us to recover this hidden structure. The crucial feature of this setting is that we only get information about a token through its interaction with other tokens, about which we also only learn through their interaction behavior. Moreover, those interactions are typically sparse, i.e., we observe only a small subset of all possible interactions. Note that this setting resembles observations made in the context of collective behavior, where a global structure emerges from local interactions [13, 24, 3].

A related question was investigated in [6], where the authors study associative memory and also want to identify a hidden structure; however, they learn the class memberships directly through the interaction with a class embedding (and they train the interaction instead of the embedding). In [19] the dynamics of word embeddings and transformer layers were investigated when the data follows a topic model. That work shares the crucial feature that membership of a word in a certain topic is only transmitted through the co-occurrence with other words from the same topic. In contrast to their work, we do not focus on learning class membership from token frequencies and in fact consider a uniform token distribution. Instead, we view this problem as a logical inference problem: given a set of facts about a set of tokens, can the hidden structure of the tokens be discovered?
We summarize the main contributions of the paper as follows:

- We introduce a learning problem that shares important features with learning through interaction behavior, which is crucial for LLMs and complex systems.
- We analyze this problem from a complexity-theoretic viewpoint, where we show that it is in general hard, and from an information-theoretic viewpoint, where we show, roughly, that for $N$ different tokens in the alphabet, order $N \ln(N)$ samples are sufficient to identify the latent structure.
- We then carefully investigate the gradient dynamics of token embeddings, finding local recovery of the cluster structure and global recovery for tensor-product functions on the tokens if we have more than $N^2 \ln(N)$ samples for an alphabet with $N$ tokens.

Notation. We write $[N] = \{1, \ldots, N\}$ for the set of the first $N$ integers. The cardinality of a finite set $A$ is denoted by $|A|$, and we also denote the standard Euclidean norm of any vector $v \in \mathbb{R}^d$ by $|v|$. The expressions $\lambda_{\max}(A)$ and $\lambda_{\min}(A)$ denote the largest and smallest eigenvalue of a symmetric matrix $A$, respectively. We denote the uniform distribution over a set $A$ by $U(A)$. For two subsets $A, B \subset \mathbb{R}^d$ we denote by $A + B = \{a + b : a \in A, b \in B\}$ their Minkowski sum. We denote the permutation group on $k$ elements by $S_k$. An overview of the used variable names can be found in Appendix A.

## 2 Setting and Motivation

In this section, we illustrate our problem with an example and define the setup more formally. Consider the set of all animals. Those can be grouped into classes such as mammals, birds, or reptiles (in fact there is a rich hierarchical structure, which we ignore here). Those groups were conceived by finding sets of animals that share many properties. Once these groups are found, we can predict unobserved properties by first identifying the cluster to which an animal belongs and then predicting that the property is shared with animals in the same cluster.
Note that this is a specific instance of a general problem in scientific inference, where we want to uncover a hidden grouping of similar entities from sparse observations about these entities.

Figure 1: Illustration of the setting for $I = 3$ different groups ($i = 1$, $i = 2$, $i = 3$) clustered in 3, 2, and 3 subgroups respectively. Samples consist of one element of each group; the dashed lines indicate the samples $(1, 3, 1)$ and $(3, 7, 6)$.

Here our main motivation, however, stems from the analysis of large language models, where a similar problem arises implicitly during training. They are trained by next-token prediction, so we do not expect them to learn structure by deductive reasoning such as "cows are mammals, and mammals have lungs, so cows have lungs". Instead, their learning signal is whether a token can be replaced by another token for a given context. Thus, it is a natural question whether gradient descent-based training on token embeddings can uncover a hidden cluster structure of the data. Note that if the hidden structure is recovered, then generalization to unseen prompts is possible.

We now introduce our formal setup that captures key aspects of the discussion. We consider $I$ sets of $N_1, N_2, \ldots, N_I$ tokens or entities (such as words). For simplicity, we identify these with tokens from the set $[N_i]$. For each of the sets $[N_i]$ there is a partition $P_i$ into $K_i$ classes, which we can identify with the set $[K_i]$. Then we can encode the partitions through maps $\Pi_i : [N_i] \to [K_i]$ so that the partition is given by $P_i = (\Pi_i^{-1}(k_i))_{k_i \in [K_i]}$, i.e., $\Pi_i$ encodes the class membership. We consider the map $\Pi = \Pi_1 \otimes \ldots \otimes \Pi_I$, i.e.,

$$\Pi(n_1, \ldots, n_I) = (\Pi_1(n_1), \ldots, \Pi_I(n_I)). \tag{1}$$

This structure is illustrated in Figure 1. Now we assume that there is a function $g : [K_1] \times [K_2] \times \cdots \times [K_I] \to \mathbb{R}$ which depends only on the classes. The map $f = g \circ \Pi$ extends this map to the tokens such that it only depends on the group a token belongs to.
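The setup can be made concrete with a minimal sketch, assuming uniformly drawn class maps and a binary $g$ (both choices are for illustration only). The names `make_instance` and `f` are hypothetical helpers, not part of the paper, and indices are 0-based here whereas the text uses $[N] = \{1, \ldots, N\}$.

```python
import itertools
import random


def make_instance(N, K, I, seed=0):
    """Draw class maps Pi_i: [N] -> [K] uniformly and a random binary g on [K]^I."""
    rng = random.Random(seed)
    # Pi[i][n] is the class of token n in slot i (0-based indices).
    Pi = [[rng.randrange(K) for _ in range(N)] for _ in range(I)]
    # g assigns a value to every class tuple (k_1, ..., k_I).
    g = {k: rng.randrange(2) for k in itertools.product(range(K), repeat=I)}
    return Pi, g


def f(n, Pi, g):
    """Evaluate f = g o Pi on a token tuple n = (n_1, ..., n_I)."""
    return g[tuple(Pi[i][n_i] for i, n_i in enumerate(n))]


Pi, g = make_instance(N=6, K=2, I=3)
n = (0, 3, 5)
# All tokens that share the class of n[0] in the first slot.
same_class = [m for m in range(6) if Pi[0][m] == Pi[0][n[0]]]
# Swapping the first token for any token of the same class never changes f.
values = {f((m, n[1], n[2]), Pi, g) for m in same_class}
print(len(values))  # 1, because f depends on tokens only through their classes
```

The final check is exactly the property the learning problem exploits: tokens of one class are interchangeable in every context.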
In the case where $g$ (and thus $f$) maps to $\{0, 1\}$ this can be interpreted as truth assignments, i.e., a statement consisting of a sequence $(n_1, \ldots, n_I)$ is true if $f(n_1, \ldots, n_I) = 1$ and false otherwise; this case is the main motivation of our work. More generally, $f$ could output the index of a suitable next token, where in the $\{0, 1\}$ case $0$ could correspond to a negation and $1$ to an end-of-sentence token. Our goal is to learn the partitions $P_i$ or, equivalently, $\Pi$ up to a permutation, and thereby identify the hidden structure. We assume that we are given data in the form of samples $(n^s, f(n^s))$ where $n^s = (n^s_1, \ldots, n^s_I)$ and $n^s_i \in [N_i]$. In other words, we try to learn the underlying structure from the interactions of a token with the other tokens, which are the same for every element of a class of the partition. To simplify the notation and statements, we assume in the following that $N_i = N$ and $K_i = K$ for some $N$, $K$ and all $1 \le i \le I$. Our main interest concerns the case where $N$ is large, i.e., there are many entities, and $K$ and $I$ are small, i.e., the number of groups is small.

Let us summarize several features that this model shares with real-world problems, such as learning suitable embeddings in latent space.

- Hierarchical structures, i.e., groups of objects that share certain features, as discussed here, are abundant in language and science.
- We only receive an indirect learning signal for the value of $\Pi_i(n)$ through its interaction with other tokens.
- Interactions can be very complex, i.e., here the output depends on the interaction of $I$ different tokens, and ignoring parts of the context makes learning infeasible.
On the other hand, many important features are abstracted away, e.g.:

- Here we assume that tokens from the same element of the partition interact in exactly the same way with other tokens, while in reality there are many different partitions of the tokens depending on the broader context (e.g., we can group species by habitat, color, or size, each resulting in different partitions), or there are exceptions, e.g., mammals generally do not lay eggs but the platypus does.
- Many more complex notions of similarity or further properties of embeddings, such as a vector space structure, are not covered. Also, there can be many uninformative features.
- We do not consider noisy data or errors in this work, although handling them is crucial for real-world applications.

## 3 Complexity-Theoretic and Information-Theoretic Analysis

We now study this learning problem in different settings. Let us first briefly discuss complexity-theoretic and information-theoretic properties of the learning problem to understand the general boundaries of this learning task. We first study the information-theoretic viewpoint, i.e., the question of how many samples are necessary to identify $\Pi$ (and potentially $g$). We focus on the case where we sample $\Pi$ and the data samples uniformly at random. To learn an unstructured map $[N]^I \to \mathbb{R}$ we generally need of the order of $N^I \ln(N^I)$ independent samples ($N^I$ when sampling without replacement). For the structured setting we show that if $\Pi$ is drawn uniformly at random, then generally order $K^I N \ln(N)$ samples are sufficient to learn $f$ and the partition induced by $\Pi$. In other words, for every token $n_i$ and each of the $K^I$ classes $\Pi^{-1}(k)$ we need of the order of $\ln(N)$ samples $n'$ such that $\Pi(n') = k$ and $n'_i = n_i$. In particular, for $N \gg K^I$ any token will interact only with $K^I \ln(N) \ll N$ other tokens, i.e., a very sparse subset of the other tokens.

We require the following necessary condition for identifiability: for every $i \in [I]$ and every $k_i \neq k_i'$ there are $k_1, \ldots, k_{i-1}, k_{i+1}, \ldots, k_I \in [K]$ such that

$$g(k_1, \ldots, k_{i-1}, k_i, k_{i+1}, \ldots, k_I) \neq g(k_1, \ldots, k_{i-1}, k_i', k_{i+1}, \ldots, k_I). \tag{2}$$

Note that this condition is indeed necessary: if it is not satisfied, then it is not possible to distinguish $\Pi_i^{-1}(k_i)$ and $\Pi_i^{-1}(k_i')$. Clearly, we can generally only identify $\Pi$ up to a permutation symmetry, i.e., we can only find $\tilde\Pi$ such that there are permutations $\pi_i \in S_K$ with $\tilde\Pi_i = \pi_i \circ \Pi_i$. We have the following result.

Theorem 1. Assume that $g : [K]^I \to \mathbb{R}$ is a function satisfying the assumption (2). Assume we randomly sample maps $\Pi_i$ such that $\Pi_i(n_i) = k_i$ with probability $K^{-1}$ for all $i$, $n_i$, and $k_i$, and such that $(\Pi_i(n_i))_{i \in [I], n_i \in [N]}$ are independent. Assume we are given $S$ samples $(n, g \circ \Pi(n))$ where $n \sim U([N]^I)$. Then there is a constant $N_0(I, K, \eta)$ such that with probability at least $1 - 2e^{-\eta}$, for $N \ge N_0(I, K, \eta)$ and

$$S \ge 2^{2I+3} I K^I N \ln(N), \tag{3}$$

we can recover $\Pi$ and $g$ up to permutations of $[K]$.

This result is a special case of Theorem 6 in Appendix B, which shows similar bounds for arbitrary maps $\Pi$ that are not necessarily random. In the more general setting, there are additional dependencies on the size of the preimages $\Pi_i^{-1}(k)$. Note that this dependency cannot be avoided: if there is a $k$ such that $|\Pi^{-1}(k)| = 1$, $g(k) = 1$, and $g(k') = 0$ for $k' \neq k$, then order $N^I$ samples are necessary to find $n$ such that $\Pi(n) = k$ and thus to find $\Pi$. The general proof idea is to bound the probability that any fixed $\tilde\Pi \neq \Pi$ is compatible with the dataset. It turns out that it is possible to bound this probability in terms of the partition distance of the partitions induced by $\Pi$ and $\tilde\Pi$. Then we are left with bounding the number of partitions, and we conclude with the union bound. We now show that this bound is essentially tight.

Theorem 2. Let $g : [K]^I \to \mathbb{R}$ be a function such that $g(1, k_2, \ldots, k_I) = g(2, k_2, \ldots, k_I)$ for all $k_2, \ldots, k_I \in [K]$ except when $k_2 = k_3 = \ldots = k_I = 1$. Assume that $N$ is divisible by $K$ and that $|\Pi_i^{-1}(k)| = N/K$ for all $i \in [I]$ and $k \in [K]$.
Assume we are given $3 \le S \le N K^{I-1} \ln(N/K)/4$ samples $(n^s, g \circ \Pi(n^s))$ where $n^s \sim U([N]^I)$ i.i.d. Then the function $\Pi$ is identifiable with probability at most $2e^{[\ldots]}$.

The proof of this result can be found in Appendix B. Next, we emphasize that while typically a rather small number of samples is sufficient to learn $\Pi$, it can generally be very hard to do this in practice. More concretely, we show that for $I \ge 3$ even deciding whether there is a map $\Pi = \Pi_1 \otimes \ldots \otimes \Pi_I : [N]^I \to [K]^I$ such that $f = g \circ \Pi$, given access to samples of the form $(n^s, t^s)$, is NP-complete. We show that this is true even if $g$ is known, $I = 3$, and $K = 2$.

Theorem 3. Consider the map $g : \{0, 1\}^3 \to \{0, 1\}$ given by

$$g(k_1, k_2, k_3) = \mathbf{1}_{k_1 + k_2 + k_3 = 2}. \tag{4}$$

Then it is an NP-complete problem to decide, given samples of the form $(n^s_1, n^s_2, n^s_3, t^s) \in [N]^3 \times \{0, 1\}$, whether there is a map $\Pi = \Pi_1 \otimes \Pi_2 \otimes \Pi_3$ with $\Pi_i : [N] \to \{0, 1\}$ such that $t^s = g \circ \Pi(n^s_1, n^s_2, n^s_3)$ for all samples.

The proof of this result can be found in Appendix C. Now that we have established under which conditions $\Pi$ can in principle be learned, and clarified that this might be hard in general, we next discuss how we can find $\Pi$ in practice. First, we remark that Theorem 3 rules out the existence of any general fast algorithm to learn $\Pi$ (unless P = NP). Given the combinatorial nature and the hardness of the problem, it is natural to reformulate the task as a constraint satisfaction problem, which can then be solved using standard SAT solvers (see, e.g., [14] for a review). Indeed, we can introduce Boolean variables $t^i_{kn}$ for $i \in [I]$, $k \in [K]$, and $n \in [N]$ which encode whether $\Pi_i(n) = k$, and $r_{kv}$ for every $k \in [K]^I$ and $v \in \mathrm{Im}(f)$ in the (finite) image of $f$ which encode whether $g(k) = v$. It is then relatively straightforward to express the conditions for the map $\Pi$ as a constraint satisfaction problem which is satisfiable if and only if there are maps $\Pi$ and $g$ such that $t^s = g(\Pi(n^s))$ holds for all samples. We outline the construction in more detail in Appendix C.
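To make the decision problem of Theorem 3 tangible, the following is a minimal sketch of an exhaustive search over all candidate maps, assuming tiny $N$ and 0-based token indices. It is not the SAT construction from Appendix C and not an efficient algorithm; `find_partition` is a hypothetical helper introduced only for illustration.

```python
import itertools


def g(k1, k2, k3):
    """The function from Theorem 3: 1 iff k1 + k2 + k3 == 2."""
    return 1 if k1 + k2 + k3 == 2 else 0


def find_partition(N, samples):
    """Search all maps Pi_i: [N] -> {0, 1} (i = 1, 2, 3) for one that fits the samples.

    samples is a list of (n1, n2, n3, t) with 0-based tokens and labels t in {0, 1}.
    Returns (Pi1, Pi2, Pi3) as tuples, or None if no consistent maps exist.
    The loop visits up to 2**(3 * N) candidates, so it is only usable for tiny N.
    """
    for bits in itertools.product((0, 1), repeat=3 * N):
        Pi1, Pi2, Pi3 = bits[:N], bits[N:2 * N], bits[2 * N:]
        if all(g(Pi1[a], Pi2[b], Pi3[c]) == t for a, b, c, t in samples):
            return Pi1, Pi2, Pi3
    return None


consistent = [(0, 0, 0, 1), (1, 0, 0, 0)]
# The same token tuple cannot carry two different labels.
inconsistent = [(0, 0, 0, 1), (0, 0, 0, 0)]
print(find_partition(2, consistent) is not None)
print(find_partition(2, inconsistent) is None)
```

The exponential size of the search space is what the NP-completeness result suggests cannot be avoided in the worst case; a SAT solver applied to the Boolean encoding described above would be the practical alternative.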
We leave the task of developing and studying efficient algorithms for the considered problem for future work, because the main motivation of this paper is rather to understand how complex statistical patterns can be extracted using simple gradient-based algorithms. This will be investigated in the next section.

## 4 Analysis of Gradient Descent Dynamics

In this section, we investigate under what conditions the clustering induced by $\Pi$ can be learned using gradient descent on token embeddings. Our main finding is that for uniformly random $\Pi$ and $S$ sufficiently large, gradient descent can be used to uncover or at least preserve the cluster structure of the embeddings. This shows that while the general problem is NP-hard, typical random instances with sufficiently many samples can be solved quickly with straightforward algorithms. This is similar in spirit to the results found in, e.g., [5]. Let us start by introducing the setting.

Setting. We assume that we have token embeddings for each of the $I$ sets $[N_1]$ to $[N_I]$, i.e., we assume that there are vectors $v(i, n) \in \mathbb{R}^D$ for some $D$ and all $i \in [I]$, $n \in [N]$. Based on these embeddings, we assume that we are given a function $\hat f : \mathbb{R}^{ID} \to \mathbb{R}$ that transforms the embeddings into a prediction. We will abuse notation and write for $n \in [N]^I$

$$\hat f(n) = \hat f(v(1, n_1), \ldots, v(I, n_I)), \tag{5}$$

i.e., we will suppress the map from tokens to embeddings in the notation. Now we consider gradient descent for the embedding vectors using the least squares loss on training samples, i.e., the loss of a sample $(n, t = f(n))$ is $(\hat f(n) - f(n))^2$. We assume that we are given a dataset $\mathcal{D} = \{n^1, \ldots, n^S\} \sim U([N]^I)^S$. Then the empirical loss reads (the division by 2 is convenient for the gradient evaluation later)

$$\hat L\big((v(i, n))_{i \in [I], n \in [N]}\big) = \frac{1}{2} \sum_{s=1}^{S} \big(\hat f(n^s) - f(n^s)\big)^2. \tag{6}$$

We also define shorthands for certain concatenations of embeddings. For a sample $n \in [N]^I$ we denote the collection of embeddings by $v(n) = (v(1, n_1), \ldots, v(I, n_I))$.
Using the convention (5) we can then write $\hat f(v(n)) = \hat f(n)$. Moreover, we define $\bar v(i) \in \mathbb{R}^{DN}$ as the concatenation of the vectors $v(i, 1), \ldots, v(i, N)$, i.e., the combined embedding for the $i$-th slot, and $\bar v \in \mathbb{R}^{DNI}$ as the concatenation of the vectors $\bar v(i)$ for $1 \le i \le I$, i.e., all token embeddings concatenated. We consider the regularized loss given by

$$\hat R_\lambda(\bar v) = \frac{N}{S} \hat L(\bar v) + \frac{\lambda}{2} |\bar v|^2. \tag{7}$$

Note that the scaling by $N/S$ (instead of the usual $1/S$) is natural because every token embedding $v(i, n)$ occurs in approximately $S/N$ of the samples, so the scaling ensures that the gradient of $\hat R_\lambda$ with respect to the token embedding $v(i, n)$ is of order one. Now, we consider the continuous gradient descent of the loss with respect to the embeddings, i.e., we consider (omitting the time variable from the notation)

$$\dot v(i, n) = -\frac{d}{dv(i, n)} \hat R_\lambda(\bar v) = -\frac{N}{S} \frac{d}{dv(i, n)} \hat L(\bar v) - \lambda v(i, n). \tag{8}$$

This introduces a time dynamics on the token embeddings. We indicate the time dependence by $v(i, n, t)$ but we drop $t$ if not necessary. Our main goal is to understand which conditions ensure that the token embeddings $v(i, n, t)$ and $v(i, n', t)$ converge to each other as $t \to \infty$ if $\Pi_i(n) = \Pi_i(n')$. To investigate this we define the center of the class embeddings by

$$w(i, k, t) = \frac{1}{|\Pi_i^{-1}(k)|} \sum_{n \in \Pi_i^{-1}(k)} v(i, n, t) \tag{9}$$

and we consider the deviations from the class centers given by

$$\delta(i, n, t) = v(i, n, t) - w(i, \Pi_i(n), t). \tag{10}$$

Thus the vectors $\delta(i, n, t)$ capture whether we recover the cluster structure; in particular, if all norms $|\delta(i, n, t)|$ are small then we have essentially recovered the hidden structure. Therefore, we define

$$\delta_{\max}(t) = \max_{i \in [I]} \max_{n \in [N]} |\delta(i, n, t)|. \tag{11}$$

Similarly to the notation introduced before, we consider for $k \in [K]^I$ the vector $w(k) = (w(1, k_1), \ldots, w(I, k_I))$. As in (5) we abuse notation and write

$$\hat g(k) = \hat f(w(k)) = \hat f(w(1, k_1), \ldots, w(I, k_I)). \tag{12}$$

Similar to $\bar v(i)$ and $\bar v$ we introduce $\bar w(i)$ as the concatenation of $(w(i, k))_{k \in [K]}$ and $\bar w$ as the concatenation of the $\bar w(i)$.

Assumptions.
Our first result for the gradient dynamics states that clusters are stable under the dynamics if the initial loss is sufficiently small. More precisely, this means that we assume that $v(i, n)$ and $v(i, n')$ are close initially whenever $\Pi_i(n) = \Pi_i(n')$, i.e., $\delta_{\max}(0)$ is small. In addition, we assume that $|\hat g(k) - g(k)|$ is small. To capture this we define

$$r_{\max}(t) = \max_{k \in [K]^I} |\hat f(w(k, t)) - g(k)|. \tag{13}$$

Then the result shows that $\delta_{\max}$ stays small for all times if $S \gg N^2$ under mild additional assumptions. In other words, if we start from the correctly learned substructures and $\hat g(k) \approx g(k)$ for all $k \in [K]^I$, then this remains true for all times. Note that while we phrase smallness as an assumption on the mean embeddings $w(i, k)$, this is generally a consequence of $\delta_{\max}$ being small and a small empirical loss $\hat L(\bar v)$. Let us now state the required assumptions.

Assumption 1. We assume that the map $\Pi : [N]^I \to [K]^I$ is approximately balanced, which means that for all $i \in [I]$, $k \in [K]$

$$\frac{N}{2K} \le |\Pi_i^{-1}(k)| \le \frac{2N}{K}. \tag{14}$$

This assumption ensures that clusters are of approximately equal size. We have already seen in Section 3 that different cluster sizes increase the sample complexity of learning $\Pi$.

Assumption 2. We assume that there is a convex set $\Omega \subset \mathbb{R}^D$ and a constant $M \ge 1$ such that the following bound holds:

$$\sup_{v \in \Omega^I} \; \sup_{i_1, i_2, i_3 \in [DI]} \max\Big( |\hat f(v)|, \; |\partial_{i_1} \hat f(v)|, \; |\partial_{i_1} \partial_{i_2} \hat f(v)|, \; |\partial_{i_1} \partial_{i_2} \partial_{i_3} \hat f(v)| \Big) \le \bar M = \frac{\sqrt{M}}{4}. \tag{15}$$

Here it is convenient to introduce $M = 16 \bar M^2$ so that later certain errors in a Taylor approximation are bounded by $M$. We also assume that

$$\max_{k \in [K]^I} |g(k)| \le \bar M = \frac{\sqrt{M}}{4}. \tag{16}$$

This is a rather mild assumption. For $C^3$ functions $\hat f$ and bounded $\Omega$ this is always true. The next assumption entails a rigidity of approximate minimizers of the loss.

Assumption 3. We assume that for all embeddings $w(i, k) \in \mathbb{R}^D$ for $i \in [I]$, $k \in [K]$ that satisfy

$$r_{\max} = \max_{\mathbf{k} \in [K]^I} |\hat f(w(\mathbf{k})) - g(\mathbf{k})| \le 1 \tag{17}$$

the bound

$$\omega_0 \le \min_{k \in [K],\, i \in [I]} \lambda_{\min}\Bigg( \sum_{\mathbf{k} \in [K]^I,\; \mathbf{k}_i = k} \nabla_{w(i, k)} \hat f(w(\mathbf{k})) \otimes \nabla_{w(i, k)} \hat f(w(\mathbf{k})) \Bigg) \tag{18}$$

holds for some positive constant $\omega_0$ (here $\mathbf{k}$ denotes a class tuple and $k$ a single class). Of course, the bound $r_{\max} \le 1$ could be replaced by any other constant.
Note that this condition can only hold if $D \le K^{I-1}$, i.e., the latent space dimension cannot be too large. The high-level intuition of this assumption is essentially that (at least if $\sum_{\mathbf{k}} (\hat f(w(\mathbf{k})) - g(\mathbf{k}))^2$ is small) there is no direction $v \in \mathbb{R}^D$ such that $v \cdot \nabla_{w(i, k)} \hat f(w(\mathbf{k})) \approx 0$ for all $\mathbf{k}$ such that $\mathbf{k}_i = k$, i.e., we cannot move one single embedding without changing the output $\hat f(w(\mathbf{k}))$ for at least one $\mathbf{k}$. If this condition does not hold, then we cannot guarantee that $\sum_{\mathbf{k}} (\hat f(w(\mathbf{k})) - g(\mathbf{k}))^2$ is minimized for a unique $w(i, k)$ (with all other embeddings fixed). This generally prevents concentration of $v(i, n)$. Note that this condition does not ensure that there is a unique minimizer $\bar w$; in particular, there could still be a rotationally invariant family of embeddings $w(i, k)$ such that $\hat f(w(\mathbf{k})) = g(\mathbf{k})$ for all $\mathbf{k} \in [K]^I$.

Finally, we need a further mild assumption that ensures that the mean token embeddings $w(i, k)$ stay bounded in some set if the loss is small. This can be achieved, e.g., if $\hat f \to \infty$ as $|w(\mathbf{k})| \to \infty$.

Assumption 4. We assume that for all collections of mean embeddings $w(i, k) \in \mathbb{R}^D$ for $i \in [I]$, $k \in [K]$ that satisfy

$$r_{\max} = \max_{\mathbf{k} \in [K]^I} |\hat f(w(\mathbf{k})) - g(\mathbf{k})| \le 1 \tag{19}$$

there is a convex set $\Omega_0 \subset \mathbb{R}^D$ such that $w(i, k) \in \Omega_0$ for all $i \in [I]$, $k \in [K]$. Again, the right-hand side of the bound $r_{\max} \le 1$ could be replaced by any other constant.

Results. The first stability theorem can then be stated as follows.

Theorem 4. Let $\Pi : [N]^I \to [K]^I$ be approximately balanced as stated in Assumption 1. Assume that the functions $g : [K]^I \to \mathbb{R}$ and $\hat f : \mathbb{R}^{ID} \to \mathbb{R}$ satisfy Assumption 3 for some $\omega_0 > 0$ and Assumption 4 for some convex set $\Omega_0$. Assume that Assumption 2 holds for some $M$ and the set $\Omega = \Omega_0 + B_2(0)$. Then there are constants $c_1, C_2, C_3, C_4 > 0$ depending on $I$, $M$, $D$, and $\omega_0$ such that for all initial embeddings $v(i, n, t = 0) \in \mathbb{R}^D$ for $i \in [I]$ and $n \in [N]$ satisfying

$$\delta_{\max}(0) \le C_2 K^{-3I/2}, \qquad r_{\max}(0) \le C_3 K^{-3I/2} \tag{20}$$

and sample size

$$S \ge c_1 \max\big( K^{3I} N^2 \ln^2(N), \; N \ln(N) K^{9I/2} \big) \tag{21}$$

the following holds with probability at least $1 - S^{-1}$ over the randomness of the dataset.
When considering the gradient dynamics of the embeddings given by (8), the bound

$$R = \limsup_{t \to \infty} r_{\max}(t) \le 1 \tag{22}$$

holds and moreover

$$\limsup_{t \to \infty} \delta_{\max}(t) \le C_4 K^{3I/2} \, [\ldots]. \tag{23}$$

In particular $\delta_{\max}(t) \to 0$ if $r_{\max}(t) \to 0$, i.e., all token embeddings for one fixed class converge to the same point.

This result shows that for order $N^2 \ln(N)$ samples and initialization sufficiently close to a global minimum, the cluster structure remains stable. We do not provide conditions that ensure $r_{\max}(t) \to 0$, which would guarantee convergence to zero loss and perfect recovery of the clusters. Next we note that we cannot expect the clustering to be stable in general, even if $\delta_{\max}(0)$ is arbitrarily small, if the initialization is not close to a minimum. This is true even in the simplest case where $I = K = 1$, i.e., we consider gradient descent for a single function value. Then gradient descent does not necessarily converge to a global minimum, and nearby points do not necessarily stay close because gradient descent is not well posed. Let us clarify this by an example.

Figure 2: Simulation of the setting in Theorem 5 with $N = 1000$, $K = I = 3$, $D = 2$, $\lambda = 0$, $S = 100{,}000$. (left) Trajectories of 50 randomly sampled tokens from 6 different classes. (right) Average distance of token embeddings within a class for different classes (colored) and average distance between all pairs of embeddings (black), plotted against the number of gradient steps.

Example 1. Assume $I = K = D = 1$, $N > 1$, $\hat f(x) = x^2 - x^3$ and $f(n) = 2$. Consider the dataset $\mathcal{D} = \{1, \ldots, N\}$. Assume that $v(1, n, t = 0) \sim N(0, \sigma^2)$ for any $\sigma^2 > 0$. Then the gradient dynamics introduced in (8) reads

$$\dot v(1, n, t) = -\hat f'(v(1, n, t)) \big( \hat f(v(1, n, t)) - f(n) \big) \approx 4 v(1, n, t) \quad \text{if } |v(1, n, t)| \text{ is small}. \tag{24}$$

We find that $v(1, n, t) \to 2/3$ as $t \to \infty$ if $v(1, n, t = 0) > 0$ (the point $2/3$ is a local maximum of $\hat f$ and hence a local minimum of the loss). On the other hand, $v(1, n, t) \to -1$ if $v(1, n, t = 0) < 0$. So in this case $\delta_{\max}(0) = O(\sigma \sqrt{\ln(N)})$, but the two groups of embeddings end up at distance $5/3$ from each other, so $\delta_{\max}(t)$ remains of order one as $t \to \infty$. Slight modifications show that also $\delta_{\max} \to \infty$ is possible.
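The splitting described in Example 1 can be checked numerically. The following is a minimal sketch, assuming the sign convention $\hat f(x) = x^2 - x^3$ (the choice under which the limits $2/3$ and $-1$ are attained) and replacing the continuous flow by explicit Euler steps. The helper `flow` and the step size are illustrative assumptions, not taken from the paper.

```python
def f_hat(x):
    # Assumed sign convention: f_hat(x) = x^2 - x^3.
    return x * x - x * x * x


def f_hat_prime(x):
    return 2 * x - 3 * x * x


def flow(v0, target=2.0, step=1e-3, n_steps=20000):
    """Explicit Euler discretization of dv/dt = -f_hat'(v) * (f_hat(v) - target)."""
    v = v0
    for _ in range(n_steps):
        v -= step * f_hat_prime(v) * (f_hat(v) - target)
    return v


# Two embeddings of the same class that start very close to each other.
v_pos = flow(+1e-3)
v_neg = flow(-1e-3)
print(v_pos, v_neg, v_pos - v_neg)  # approximately 2/3, -1, and 5/3
```

Both starting points lie within $2 \cdot 10^{-3}$ of each other, yet the positive one settles at a local minimum of the loss with nonzero residual while the negative one reaches the global minimum where $\hat f(-1) = 2$.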
The previous example shows that without additional assumptions, we cannot expect to recover the structure of the data. Therefore, we impose additional restrictions on the function $\hat f$. As is apparent from Example 1 and also from the bound in Lemma 1 in Appendix E, it is the curvature of the function $\hat f$ that can push token embeddings of the same class apart. We therefore consider the function class of slot-wise linear functions defined as follows.

Definition 1. We call $\hat f$ slot-wise linear if for every $v = (v(1), \ldots, v(I)) \in \mathbb{R}^{D \cdot I}$ and any $i \in [I]$, $\alpha, \beta \in [D]$ the relation

$$\frac{d}{dv_\alpha(i)} \frac{d}{dv_\beta(i)} \hat f(v) = 0 \tag{25}$$

holds.

Let us denote for $v \in \mathbb{R}^D$ by $\tilde v \in \mathbb{R}^{D+1}$ the vector $v$ with a $1$ appended. The most general slot-wise linear function is then of the form

$$\hat f(v) = T\big( \tilde v(1) \otimes \tilde v(2) \otimes \ldots \otimes \tilde v(I) \big) \tag{26}$$

where $T : \mathbb{R}^{(D+1)^I} \to \mathbb{R}$ is a linear map. Note that this class covers linearized attention, where the embeddings $v(i, n)$ are split into three separate parts that are used to form key, query, and value vectors. For this function class we can show stronger clustering results.

Theorem 5. Let $\Pi$ be approximately balanced, i.e., assume that Assumption 1 holds. Let $\hat f : \mathbb{R}^{ID} \to \mathbb{R}$ be slot-wise linear. Assume that $v(i, n, t)$ follow the gradient dynamics (8) and that $v(i, n, t) \in \Omega$ for all $i \in [I]$, $n \in [N]$ and $t > 0$ for some convex set $\Omega$. Assume that Assumption 2 holds with constant $M$ for the set $\Omega$. Let $C_1$ be the constant from Lemma 1 (which depends on $I$, $D$, $M$). Assume that at initialization

$$|v(i, n, t = 0)| \le \frac{\lambda}{8 C_1}. \tag{27}$$

Then there are constants $C_2, C_3 \ge 0$ depending on $M$, $I$, $D$, such that for

$$S \ge \max\left( C_2 \frac{N^2 K^{I-1}}{\lambda^2} \ln^2(N/\lambda), \; C_3 \frac{N K^{I-1}}{\lambda^4} \ln(N/\lambda) \right) \tag{28}$$

the bound

$$\delta_{\max}(t) \le \max\left( \delta_{\max}(0) e^{-\lambda t/8}, \; 4 C_1 K^{(I-1)/2} \, [\ldots] \right) \tag{29}$$

holds for all $t \ge 0$.

The high-level summary of this result is that for order $N^2 \ln(N)$ samples the clusters can be recovered up to an error of order $\lambda^{-1} \sqrt{\ln(S) N / S}$. Note that the a priori assumption that the gradient flow is restricted to the set $\Omega$ might appear difficult to guarantee in practice.
However, we conjecture that the results extend to gradient dynamics clamped at the boundary. Moreover, in Lemma 2 we prove that the mean embeddings $w(i, k, t)$ stay within a ball of radius $R = O(\lambda^{-1})$ for all times. This allows us to prove Theorem 8, which does not require any a priori bounds on the evolution of $v(i, n, t)$ but comes at the price that the constants $C_1$, $C_2$, and $C_3$ depend on $\lambda$, so we cannot infer the explicit $\lambda$ dependence. Let us make an important remark.

Remark 1. While we state our results for a fixed function $\hat f$, this function could in principle be time-dependent, e.g., $\hat f$ could be given by a neural network, and we consider gradient descent not only on the token embeddings but also on the network parameters. The only requirement is that the assumptions hold uniformly for all times $t$. In particular, for Theorem 5 we only need to ensure that the derivatives of $\hat f$ stay uniformly bounded in time. This can, e.g., be guaranteed by clipping the parameters of the slot-wise linear map $\hat f$.

Our results so far show that we can recover the right cluster structure in the sense that the embeddings from the same group cluster. However, this leaves open whether there is any non-trivial dynamics at all, i.e., all embeddings could cluster at the same point. This is in general not the case, as can be seen from Corollary 1, which states that in the setting of Theorem 5 and for large times the dynamics of the cluster means follow the equation

$$\dot w(i, k) = -\sum_{\mathbf{k} \in [K]^I,\; \mathbf{k}_i = k} \frac{\alpha_{\mathbf{k}, i}}{2} \nabla_{w(i, \mathbf{k}_i)} \big( \hat g(\mathbf{k}) - g(\mathbf{k}) \big)^2 - \lambda w(i, k) + O\big([\ldots]\big) \tag{30}$$

where $(2K)^{-I} \le \alpha_{\mathbf{k}, i} \le (2/K)^I$ are positive numbers (here $\mathbf{k}$ denotes a class tuple and $k$ a single class). This shows that the $w(i, k)$ generally follow a non-trivial dynamics (and this also justifies the scaling, as this expression is of order 1). So in the generic case the cluster structure will be revealed if the number of samples is sufficiently large; however, there is no general guarantee that the clusters are well separated. As an illustration of this result we refer to Figure 2, where the clustering of the embeddings becomes apparent.
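The quantities entering these results can be sketched in a few lines for a slot-wise linear $\hat f$ with $I = 2$. The following is a minimal illustration under stated assumptions: the regularized loss (7), its gradient as used in (8), and $\delta_{\max}$ from (11) are implemented in plain Python; the continuous flow is replaced by explicit Euler steps; and all sizes, the random $T$, the step size, and the helper names (`tilde`, `reg_loss`, `grad`, `delta_max`) are hypothetical choices, not the configuration behind Figure 2.

```python
import math
import random


def tilde(v):
    """Append a constant 1 to an embedding, as in the slot-wise linear form (26)."""
    return list(v) + [1.0]


def f_hat(T, v1, v2):
    """Slot-wise linear prediction for I = 2: sum_ab T[a][b] * tilde(v1)[a] * tilde(v2)[b]."""
    t1, t2 = tilde(v1), tilde(v2)
    return sum(T[a][b] * t1[a] * t2[b] for a in range(len(t1)) for b in range(len(t2)))


def reg_loss(T, V, data, lam):
    """Regularized loss (7): (N/S) * empirical loss (6) + (lam/2) * |v|^2."""
    N, S = len(V[0]), len(data)
    fit = 0.0
    for n1, n2, t in data:
        res = f_hat(T, V[0][n1], V[1][n2]) - t
        fit += 0.5 * res * res
    norm2 = sum(x * x for slot in V for v in slot for x in v)
    return N / S * fit + 0.5 * lam * norm2


def grad(T, V, data, lam):
    """Gradient of reg_loss with respect to every embedding V[i][n]."""
    N, S, D = len(V[0]), len(data), len(V[0][0])
    G = [[[lam * x for x in v] for v in slot] for slot in V]
    for n1, n2, t in data:
        t1, t2 = tilde(V[0][n1]), tilde(V[1][n2])
        r = N / S * (f_hat(T, V[0][n1], V[1][n2]) - t)
        for a in range(D):
            G[0][n1][a] += r * sum(T[a][b] * t2[b] for b in range(D + 1))
            G[1][n2][a] += r * sum(T[b][a] * t1[b] for b in range(D + 1))
    return G


def delta_max(V, Pi, K):
    """Largest distance (11) of an embedding from the mean embedding (9) of its class."""
    worst = 0.0
    for slot, classes in zip(V, Pi):
        for k in range(K):
            members = [v for v, c in zip(slot, classes) if c == k]
            if not members:
                continue
            center = [sum(col) / len(members) for col in zip(*members)]
            worst = max([worst] + [math.dist(v, center) for v in members])
    return worst


rng = random.Random(0)
N, K, D, S, lam, step = 20, 2, 2, 400, 0.1, 0.01
Pi = [[rng.randrange(K) for _ in range(N)] for _ in range(2)]  # hidden classes
g = [[rng.randrange(2) for _ in range(K)] for _ in range(K)]   # class-level target
T = [[rng.uniform(-1, 1) for _ in range(D + 1)] for _ in range(D + 1)]
V = [[[rng.gauss(0, 0.1) for _ in range(D)] for _ in range(N)] for _ in range(2)]
data = []
for _ in range(S):
    n1, n2 = rng.randrange(N), rng.randrange(N)
    data.append((n1, n2, float(g[Pi[0][n1]][Pi[1][n2]])))

before = delta_max(V, Pi, K)
for _ in range(200):  # explicit Euler steps for the gradient flow (8)
    G = grad(T, V, data, lam)
    V = [[[x - step * gx for x, gx in zip(v, gv)] for v, gv in zip(slot, gslot)]
         for slot, gslot in zip(V, G)]
print(before, delta_max(V, Pi, K))
```

Tracking `delta_max` along the iteration gives a direct numerical handle on the quantity bounded in Theorem 5; whether it shrinks for a given run depends on the sample size, $\lambda$, and the initialization, as the theorem's conditions indicate.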
Proof idea and overview. Let us give a quick summary of the main steps of the proof and where to find them. The first important ingredient is an expansion of the loss gradient. We Taylor expand the loss of a sample $v(n)$ around the point $w(\Pi(n))$ to second order with remainder terms; the relevant calculations can be found in Appendix D (see Proposition 1 for the outcome). A second ingredient is a set of concentration bounds for certain datapoint statistics, random variables, and random matrices. Those are derived in Appendix G, with the necessary results collected in Theorem 9. Combining the Taylor expansion with the concentration results, we can extract the dominant terms of the expansion (see Appendix E). Moreover, we obtain such an expansion for $\dot w(i, k)$ (see Corollary 1) and thus for the displacements $\dot\delta(i, n)$ (see Corollary 2). This expansion can then be used to control $\partial_t |\delta(i, n)|^2$ (see Lemma 1), which is sufficient to control $\delta_{\max}$.

Discussion of assumptions. Let us contemplate the differences and similarities to training token embeddings in neural networks. For the first main result, Theorem 4, we make minimal assumptions on $\hat f$, so this could in principle be a neural network applied to the token embeddings. The second result, Theorem 5, is more restrictive but covers subclasses of linearized attention. An important feature of the results is that $\hat f$ itself could be time-dependent (see Remark 1). Differences to standard training of neural networks are the following. We use continuous-time gradient descent instead of stochastic gradient descent. This is a frequently used modification, and in suitable limits those converge (see, e.g., [20]). We use mean squared error, while sequence modelling usually relies on cross entropy loss. This simplification is frequently used in theoretical analysis, but it is expected that results generally extend to the non-convex cross entropy loss.
A more crucial difference is that, in practice, the embedding dimension is usually chosen large to provide large representation capacity. Here $D$ has to be rather small to allow a unique optimal solution for the token embeddings, which is what allows us to recover the cluster structure.

5 Conclusion

In this paper, we considered a learning problem where we try to recover a partition of tokens from their interaction with other tokens. This can be seen as a toy problem for next token prediction in LLMs, but also more broadly as a problem in scientific inference. We studied this problem from different perspectives, namely from an information-theoretic, a complexity-theoretic, and a gradient descent based viewpoint. We found that order $N \ln(N)$ samples are sufficient to recover the partition for $N$ tokens, while we showed that $N^2 \ln(N)$ samples are sufficient for gradient based methods.

There are several natural open follow-up questions. First, there are some open questions regarding the tightness of our analysis of the gradient descent. In particular, it is a natural question whether $\Omega(N \ln(N))$ samples are already sufficient to control $\delta_{\max}$; this is the information-theoretic threshold, and such a result would be similar to the optimal results for matrix completion [18, 7]. Another interesting question for future research is whether Theorem 5 holds for standard initialization schemes for the token embeddings. Secondly, it is of interest to relax the notion of clustering of embeddings to more general notions that still allow recovering some structure but are also applicable to high dimensional latent spaces and potentially to multiple partitions and relations on the same tokens (e.g., tokens belonging to different clusters). Thirdly, it is a natural question whether this work can be connected more closely to empirical findings.

References

[1] Kwangjun Ahn, Xiang Cheng, Hadi Daneshmand, and Suvrit Sra. Transformers learn to implement preconditioned gradient descent for in-context learning.
In Thirty-seventh Conference on Neural Information Processing Systems, 2023.
[2] Francis Bach. Breaking the curse of dimensionality with convex neural networks. Journal of Machine Learning Research, 18(19):1–53, 2017.
[3] William Bialek, Andrea Cavagna, Irene Giardina, Thierry Mora, Edmondo Silvestri, Massimiliano Viale, and Aleksandra M. Walczak. Statistical mechanics for natural flocks of birds. Proceedings of the National Academy of Sciences, 109(13):4786–4791, 2012.
[4] Piotr Bojanowski, Edouard Grave, Armand Joulin, and Tomás Mikolov. Enriching word vectors with subword information. Trans. Assoc. Comput. Linguistics, 5:135–146, 2017.
[5] Alon Brutzkus and Amir Globerson. Globally optimal gradient descent for a ConvNet with Gaussian inputs. In Doina Precup and Yee Whye Teh, editors, Proceedings of the 34th International Conference on Machine Learning, volume 70 of Proceedings of Machine Learning Research, pages 605–614. PMLR, 06–11 Aug 2017.
[6] Vivien Cabannes, Berfin Simsek, and Alberto Bietti. Learning associative memories with gradient descent. CoRR, abs/2402.18724, 2024.
[7] Emmanuel J. Candes and Terence Tao. The power of convex relaxation: Near-optimal matrix completion. IEEE Transactions on Information Theory, 56(5):2053–2080, 2010.
[8] Lénaïc Chizat and Francis Bach. On the global convergence of gradient descent for overparameterized models using optimal transport. In S. Bengio, H. Wallach, H. Larochelle, K. Grauman, N. Cesa-Bianchi, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 31. Curran Associates, Inc., 2018.
[9] Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. BERT: pre-training of deep bidirectional transformers for language understanding.
In Jill Burstein, Christy Doran, and Thamar Solorio, editors, Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, NAACL-HLT 2019, Minneapolis, MN, USA, June 2-7, 2019, Volume 1 (Long and Short Papers), pages 4171–4186. Association for Computational Linguistics, 2019.
[10] Patrick Esser, Robin Rombach, and Björn Ommer. Taming transformers for high-resolution image synthesis. In IEEE Conference on Computer Vision and Pattern Recognition, CVPR 2021, virtual, June 19-25, 2021, pages 12873–12883. Computer Vision Foundation / IEEE, 2021.
[11] Borjan Geshkovski, Cyril Letrouit, Yury Polyanskiy, and Philippe Rigollet. A mathematical perspective on transformers, 2024.
[12] Mor Geva, Roei Schuster, Jonathan Berant, and Omer Levy. Transformer feed-forward layers are key-value memories. In Marie-Francine Moens, Xuanjing Huang, Lucia Specia, and Scott Wen-tau Yih, editors, Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing, EMNLP 2021, Virtual Event / Punta Cana, Dominican Republic, 7-11 November, 2021, pages 5484–5495. Association for Computational Linguistics, 2021.
[13] Irene Giardina. Collective behavior in animal groups: theoretical models and empirical studies. HFSP J, 2(4):205–219, Aug 2008.
[14] Carla P. Gomes, Henry Kautz, Ashish Sabharwal, and Bart Selman. Chapter 2 satisfiability solvers. In Frank van Harmelen, Vladimir Lifschitz, and Bruce Porter, editors, Handbook of Knowledge Representation, volume 3 of Foundations of Artificial Intelligence, pages 89–134. Elsevier, 2008.
[15] Arthur Jacot, Franck Gabriel, and Clement Hongler. Neural tangent kernel: Convergence and generalization in neural networks. In S. Bengio, H. Wallach, H. Larochelle, K. Grauman, N. Cesa-Bianchi, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 31. Curran Associates, Inc., 2018.
[16] Samy Jelassi, Michael Sander, and Yuanzhi Li.
Vision transformers provably learn spatial structure. In S. Koyejo, S. Mohamed, A. Agarwal, D. Belgrave, K. Cho, and A. Oh, editors, Advances in Neural Information Processing Systems, volume 35, pages 37822–37836. Curran Associates, Inc., 2022.
[17] Albert Q. Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Singh Chaplot, Diego de las Casas, Florian Bressand, Gianna Lengyel, Guillaume Lample, Lucile Saulnier, Lélio Renard Lavaud, Marie-Anne Lachaux, Pierre Stock, Teven Le Scao, Thibaut Lavril, Thomas Wang, Timothée Lacroix, and William El Sayed. Mistral 7b, 2023.
[18] Raghunandan H. Keshavan, Andrea Montanari, and Sewoong Oh. Matrix completion from a few entries. IEEE Transactions on Information Theory, 56(6):2980–2998, 2010.
[19] Yuchen Li, Yuanzhi Li, and Andrej Risteski. How do transformers learn topic structure: Towards a mechanistic understanding. In Andreas Krause, Emma Brunskill, Kyunghyun Cho, Barbara Engelhardt, Sivan Sabato, and Jonathan Scarlett, editors, Proceedings of the 40th International Conference on Machine Learning, volume 202 of Proceedings of Machine Learning Research, pages 19689–19729. PMLR, 23–29 Jul 2023.
[20] Song Mei, Theodor Misiakiewicz, and Andrea Montanari. Mean-field theory of two-layers neural networks: dimension-free bounds and kernel limit. CoRR, abs/1902.06015, 2019.
[21] Tomás Mikolov, Kai Chen, Greg Corrado, and Jeffrey Dean. Efficient estimation of word representations in vector space. In Yoshua Bengio and Yann LeCun, editors, 1st International Conference on Learning Representations, ICLR 2013, Scottsdale, Arizona, USA, May 2-4, 2013, Workshop Track Proceedings, 2013.
[22] Tomás Mikolov, Wen-tau Yih, and Geoffrey Zweig. Linguistic regularities in continuous space word representations.
In Lucy Vanderwende, Hal Daumé III, and Katrin Kirchhoff, editors, Human Language Technologies: Conference of the North American Chapter of the Association of Computational Linguistics, Proceedings, June 9-14, 2013, Westin Peachtree Plaza Hotel, Atlanta, Georgia, USA, pages 746–751. The Association for Computational Linguistics, 2013.
[23] OpenAI. GPT-4 technical report, 2023.
[24] P. Romanczuk, M. Bär, W. Ebeling, B. Lindner, and L. Schimansky-Geier. Active Brownian particles. The European Physical Journal Special Topics, 202(1):1–162, 2012.
[25] Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, Aurélien Rodriguez, Armand Joulin, Edouard Grave, and Guillaume Lample. Llama: Open and efficient foundation language models. CoRR, abs/2302.13971, 2023.
[26] Joel A. Tropp. User-friendly tail bounds for sums of random matrices. Foundations of Computational Mathematics, 12(4):389–434, 2012.
[27] Aäron van den Oord, Oriol Vinyals, and Koray Kavukcuoglu. Neural discrete representation learning. In Isabelle Guyon, Ulrike von Luxburg, Samy Bengio, Hanna M. Wallach, Rob Fergus, S. V. N. Vishwanathan, and Roman Garnett, editors, Advances in Neural Information Processing Systems 30: Annual Conference on Neural Information Processing Systems 2017, December 4-9, 2017, Long Beach, CA, USA, pages 6306–6315, 2017.
[28] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. Advances in Neural Information Processing Systems, 30, 2017.
[29] Greg Yang. Wide feedforward or recurrent neural networks of any architecture are Gaussian processes. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d'Alché-Buc, E. Fox, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 32. Curran Associates, Inc., 2019.
[30] Ruiqi Zhang, Spencer Frei, and Peter L. Bartlett.
Trained transformers learn linear models in-context. Journal of Machine Learning Research, 25(49):1–55, 2024.

Supplementary Material

This supplementary material is structured as follows. We first review the notation used in the paper in Appendix A. Then we provide the proofs for the information-theoretic results of Section 3 in Appendix B, and the proof of Theorem 3 together with a reduction to a constraint satisfaction problem in Appendix C. The proofs of the main results concerning the gradient flow rely on a careful Taylor expansion of the loss gradient. This expansion can be found in Appendix D, and bounds for this expansion are derived in Appendix E. Based on these bounds, we prove our main results in Appendix F. An important ingredient in bounding the Taylor expansion are concentration results for the dataset statistics, which can be found in Appendix G. Finally, we review some results on random matrices in Appendix H, which are necessary for the concentration bounds, and we review the definition of Kronecker products in Appendix I.

A Overview of Notation Used

Let us here collect important notation used throughout the paper and the proofs.

General notation.
Numbers up to $n$: $[n] = \{1, \ldots, n\}$
Eigenvalues of a matrix: $\lambda_i(A)$
Largest eigenvalue: $\lambda_{\max}(A)$
Operator norm of a matrix: $\|A\|$

Notation used in the learning problem.
Number of slots: $I$
Number of tokens for each slot: $N$
Number of classes: $K$
Number of samples: $S$
Map defining the partition in subclasses: $\Pi = \Pi_1 \otimes \ldots \otimes \Pi_I : [N]^I \to [K]^I$
Function on classes: $g : [K]^I \to \mathbb{R}$
Induced function on tokens: $f : [N]^I \to \mathbb{R}$, $f = g \circ \Pi$

Notation used in the gradient descent analysis.
Dimension of latent space: $D$
Embedding for token $n$ of slot $i$: $v(i, n)$
Token embeddings for a sample $\mathbf{n}$: $v(\mathbf{n}) = (v(1, n_1), \ldots, v(I, n_I))$
Token embeddings of a slot $i$: $\bar v(i) = (v(i, n))_{n \in [N]}$
All token embeddings: $\bar v = (\bar v(i))_{i \in [I]}$
Mean of cluster embeddings: $w(i, k)$
Sample version: $w(\mathbf{k}) = (w(1, k_1), \ldots, w(I, k_I))$
Mean token embeddings for slot $i$: $\bar w(i) = (w(i, k))_{k \in [K]}$
All mean embeddings: $\bar w = (\bar w(i))_{i \in [I]}$
Displacements from cluster center: $\delta(i, n) = v(i, n) - w(i, \Pi_i(n))$
Displacements of a sample: $\delta(\mathbf{n}) = (\delta(1, n_1), \ldots, \delta(I, n_I))$
Displacements of all tokens for slot $i$: $\bar\delta(i) = (\delta(i, n))_{n \in [N]}$
Function on token embeddings: $\hat f : \mathbb{R}^{DI} \to \mathbb{R}$, identified with $\hat f : [N]^I \to \mathbb{R}$, $\hat f(\mathbf{n}) = \hat f(v(\mathbf{n}))$
Function on classes: $\hat g : [K]^I \to \mathbb{R}$, $\hat g(\mathbf{k}) = \hat f(w(\mathbf{k}))$
Regularization: $\lambda$

B Proofs for Information-Theoretic Results

In this section, we provide the proofs and additional results for Section 3. Let us first state the general information-theoretic bound that handles the case where $|\Pi_i^{-1}(k)|$ might be of arbitrary size.

Theorem 6. Let $g : [K]^I \to \mathbb{R}$ satisfy (2) and let $\Pi = \Pi_1 \otimes \ldots \otimes \Pi_I : [N]^I \to [K]^I$ be a projection. Assume there is $L > 0$ such that for every $\mathbf{k} \in [K]^I$ we have $|\Pi^{-1}(\mathbf{k})| \ge N^I / L$. Moreover, assume for all $i \in [I]$ and $k \in [K]$ the bound $|\Pi_i^{-1}(k)| \ge M$, i.e., every class has at least $M$ members. Assume we are given $S$ samples $\{(\mathbf{n}^s, g \circ \Pi(\mathbf{n}^s)) : s \in [S]\}$ where $\mathbf{n}^s \sim U([N]^I)$. If
$$S \ge 2 \max\left( \frac{K^{I+1} L}{M} \max\big(\ln(K) I N, \eta\big),\; 2^I L N \max\big(2 \ln(NK) I, \eta\big) \right) \tag{31}$$
for some $\eta \ge 2$, then with probability at least $1 - 6e^{-\eta}$ we can recover $\Pi$ and $g$ up to permutations of $[K]$. In other words, for any map $\tilde\Pi = \tilde\Pi_1 \otimes \ldots \otimes \tilde\Pi_I$ and any $\tilde g$ such that $\tilde g \circ \tilde\Pi(\mathbf{n}^s) = g \circ \Pi(\mathbf{n}^s)$ for all $s \in [S]$, there are permutations $\pi_i \in S_K$ such that $\Pi_i = \pi_i \circ \tilde\Pi_i$ and the corresponding relation holds for $g$, $\tilde g$.

A few remarks to explain this result are in order.

Remark 2. The scaling of $S$ might appear slightly complicated, so let us comment on this. We are mostly interested in the case where $N$ is large and $K$ stays bounded. In addition, we are primarily interested in the regime where $L \le C K^I$ is bounded and $M \ge cN/K$ (which holds for random $\Pi$), i.e., the sampling probability of each class $\mathbf{k} \in [K]^I$ is of similar size. In this case, the first term in condition (31) for $S$ stays constant as $N \to \infty$. The second term dominates, and we see that we need only $O(N \ln(N))$ samples to identify $\Pi$ and $g$.
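To make the scaling in condition (31) concrete, the two terms inside the outer maximum can be evaluated numerically. The snippet below is a sketch only: the function names are hypothetical, and the values $L = (2K)^I$ and $M = N/(2K)$ are assumed as a stand-in for a roughly balanced partition.

```python
import math

def sample_bound_terms(N, K, I, L, M, eta):
    """Return the two terms inside the outer max of condition (31)."""
    first = (K ** (I + 1) * L / M) * max(math.log(K) * I * N, eta)
    second = (2 ** I) * L * N * max(2 * math.log(N * K) * I, eta)
    return first, second

def sample_bound(N, K, I, L, M, eta):
    """Right-hand side of condition (31)."""
    return 2 * max(sample_bound_terms(N, K, I, L, M, eta))

K, I, eta = 2, 2, 2.0
for N in (10**3, 10**6):
    L = (2 * K) ** I      # assumption: L <= (2K)^I (roughly balanced classes)
    M = N / (2 * K)       # assumption: every class has at least N/(2K) members
    first, second = sample_bound_terms(N, K, I, L, M, eta)
    fraction = sample_bound(N, K, I, L, M, eta) / N ** I
    print(f"N={N}: first={first:.1f}, second={second:.3e}, bound / N^I = {fraction:.3e}")
```

Under these assumptions the first term does not change with $N$, while the second grows like $N \ln(N)$, so for large $N$ the required number of samples is a vanishing fraction of the $N^I$ possible inputs.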
This is the setting studied in Theorem 1 in the main text. The result is essentially tight in this limit (up to the dependence on $I$), as stated in Theorem 2. Note that, as remarked in the main text, the dependence on $L$ cannot be avoided. Indeed, consider the extreme case where $\Pi_i(n) = 1$ for $n = 1$ and $\Pi_i(n) = 2$ for $n > 1$, and $g(k_1, \ldots, k_I) = 1$ iff $k_1 = \ldots = k_I = 1$ and $g(k_1, \ldots, k_I) = -1$ otherwise. Then we have $L = N^I$, and we also need of order $N^I$ samples to sample the point $(1, 1, \ldots, 1) \in [N]^I$, which is necessary to identify $\Pi$.

Proof. Let us first introduce some notation. We denote the set of samples by
$$\mathcal{D} = \{\mathbf{n}^1, \ldots, \mathbf{n}^S\}. \tag{32}$$
Consider two partitions $\mathcal{P}$, $\mathcal{Q}$ of a set $[N]$. We denote by $\mathcal{P}^A$ for $A \subset [N]$ the restriction of a partition to a subset (i.e., $\{P \cap A : P \in \mathcal{P}\}$). The partition distance is defined by
$$D(\mathcal{P}_1, \mathcal{P}_2) = \min\{|A^c| : \mathcal{P}_1^A = \mathcal{P}_2^A,\ A \subset [N]\}, \tag{33}$$
i.e., the minimal number of elements that need to be removed such that the partitions agree. We call $\tilde\Pi$ compatible with the datapoints $\mathcal{D}$ if there is a $\tilde g$ such that $\tilde g \circ \tilde\Pi = g \circ \Pi$ on the data. Now the general strategy is to consider any other candidate map $\tilde\Pi$ and upper bound the probability that it is compatible, i.e., that for some function $\tilde g$ the functions $\tilde f = \tilde g \circ \tilde\Pi$ and $f$ agree on the data. We will then conclude by applying the union bound over maps $\tilde\Pi$. Thus, we need to prove an upper bound on the number of partitions with partition distance at most $\Delta$ and a lower bound on the error probability for a given $\Delta$. Let us start with the latter. Denote by
$$\mathcal{P}_i = \{\Pi_i^{-1}(k) : k \in [K]\} \tag{34}$$
the partition generated by $\Pi_i$. Consider any other map $\tilde\Pi = \tilde\Pi_1 \otimes \ldots \otimes \tilde\Pi_I : [N]^I \to [K]^I$ with corresponding partitions $\tilde{\mathcal{P}}_i$. We define for any such $\tilde\Pi$ the quantity
$$\Delta(\tilde\Pi) = \max_{i \in [I]} D(\mathcal{P}_i, \tilde{\mathcal{P}}_i). \tag{35}$$
Assume now that $\Delta = \Delta(\tilde\Pi) = D(\mathcal{P}_1, \tilde{\mathcal{P}}_1) \le M/2$ and in particular $D(\mathcal{P}_i, \tilde{\mathcal{P}}_i) \le \Delta$. Let $\mathcal{P}_1 = \{P_1, \ldots, P_K\}$ where $P_k = \Pi_1^{-1}(k)$. Let $A$ be a set of maximal size such that $\mathcal{P}_1^A = \tilde{\mathcal{P}}_1^A$. After relabeling we can assume that $\tilde{\mathcal{P}}_1 = \{\tilde P_1, \ldots, \tilde P_K\}$ and $P_k \cap A = \tilde P_k \cap A$.
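As a brief aside, the partition distance (33) and the relabeling step just used can be made concrete. A standard way to compute the distance, ignoring empty blocks, is to match the blocks of the two partitions one-to-one so that the total overlap is maximal; the tokens outside the matched overlaps then form a minimal $A^c$. The sketch below does this by brute force over all matchings, so it is only meant for small $K$; the function name and the label-list input format are illustrative assumptions.

```python
from itertools import permutations

def partition_distance(labels_p, labels_q):
    """Partition distance (33) between two labelings of the same tokens (brute force)."""
    assert len(labels_p) == len(labels_q)
    classes_p = sorted(set(labels_p))
    classes_q = sorted(set(labels_q))
    size = max(len(classes_p), len(classes_q))
    # overlap[a][b] = number of tokens in block a of P and block b of Q.
    # If the block counts differ, the smaller partition is padded with empty blocks.
    overlap = [[0] * size for _ in range(size)]
    for p, q in zip(labels_p, labels_q):
        overlap[classes_p.index(p)][classes_q.index(q)] += 1
    # The largest set A on which both partitions agree has size
    # max over bijections sigma of sum_a overlap[a][sigma(a)].
    best = max(
        sum(overlap[a][sigma[a]] for a in range(size))
        for sigma in permutations(range(size))
    )
    return len(labels_p) - best

# Relabeling the classes does not change the partition, so the distance is zero.
print(partition_distance([0, 0, 1, 1], [1, 1, 0, 0]))
# Moving one token to the other class changes the partition by one element.
print(partition_distance([0, 0, 1, 1], [0, 1, 1, 1]))
```

The maximizing matching plays the role of the relabeling of the blocks $\tilde P_k$ in the proof.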
By composing Π 1 with a permutation we can also assume that P k = Π 1 1 (k). Moreover, |P k A| = |Pk A| |Pk| |Ac| |Pk| |Pk|/2 = |Pk|/2. (37) The same applies to all partitions Pi and P i. Next, we claim that if f = g Π and f = g Π agree on all samples then with probability KIe S 2I L over the randomness of the samples g = g . Define Ek := Π 1(k) Π 1(k). Using (37) and the assumption on Π we have that for k [K]N |Ek| := |Π 1(k) Π 1(k)| 2 I|Π 1(k)| N I The probability that none of the S samples is in Ek can then be bounded as follows P(D Ek = ) = 1 |Ek| S e S 2I L . (39) Applying the union bound over k we find that PD(Π is consistent with D for some g = g) KIe S 2I L . (40) Now we bound the probability that g = g is compatible, i.e., g Π(ns) = g Π (ns) for all s [S]. Let now n1 [N]. Denote by En1 [N]I the set of all vectors n = (n1, n2, . . . , n I) such that g Π(n) = g Π (n) (41) We now lower bound the size of En1 under the assumption that there is k [K] such that n1 P k Ac (where A is as above a set such that PA 1 = P A 1 ). Then Π 1(n1) = k and (by minimality of Ac) we find n1 / Pk and thus Π1n1 = k = k for some k. By assumption, we can find k2, . . . , k I such that g(k, k2, . . . , k I) = g( k, k2, . . . , k I). Consider now any vector n = (n1, n2, . . . , n I) such that ni Π 1 i (ki) Π 1 i (ki). Then, for any such n g Π(n) = g( k, k2, . . . , k I) = g(k, k2, . . . , k I) = g Π (n) (42) we conclude that all such n are in En1 By assumption we have [N] i=2 (Π 1 i (ki) Π 1 i (ki)) This implies that |En1| N I 1 and therefore P(En1) 1 2ILN . (45) The sets En1 are disjoint so we conclude that E = n1 [N]En1 satisfies n1 [N] P(En1) |Ac| 1 2ILN = 2ILN . (46) Now we can upper bound the probability that D is compatible with Π and g by PD(D is compatible with Π , g) PD(E D = ) 1 2ILN S e S 2I LN . (47) Clearly the same reasoning applies to any other index instead of 1. The next step is to upper bound the number of such candidate partitions. 
The number of partitions such that D(P1, P 1) can be bounded by (NK) , i.e., we times select one of N tokens and assign it to another (or the same) class. This bound can be applied to all indices 1 i I and by the union bound, we find that |{(P 1, . . . , P I) : max D(P i, Pi) }| (NK)I . (48) Thus we can bound the the probability that any of the maps Π (there are more such maps than partitions because there is a label assigned to each class, but consistency with the data depends only on the underlying permutation) with 0 < (Π ) M/2 is consistent with the data is bounded by PD( Π with (Π ) M/2 consistent with D) I =1 (NK)I e S 2I LN + KIe S 2I L . (49) Here the first term corresponds to the upper bound for the probability that (after permutation) g = g is compatible with the data and the second term bounds the probability that any other function is compatible with the data and I accounts for the fact that any i [I] can be arg maxi D(Pi, P i). We now consider the remaining case that there is an index i [I] such that D(Pi, P i) M/2 where we can assume w.l.o.g. that i = 1. As above, let P1 = {P1, . . . , PK} and P 1 = {P 1, . . . , P K} such that Pk A = P k A and A is a set of maximal size. We now claim that we can find an index k = k such that 2K and |P k P k| M Indeed, first assume that there is an index k such that |P k Pk| M/K. Since |P k| M there is another index k such that |P k Pk| M/K. By maximality of A we have Pk A = P k A = Pk P k and moreover |P k Pk| + |P k P k| |P k Pk| + |P k P k| (51) because otherwise exchanging P k and P k would allow picking a larger A. By our assumption, we conclude from here that |P k P k| |Pk P k| + |P k P k| |Pk P k| M Thus, in this case |P k P k| M/(2K) and |P k Pk| M/K and (50) holds. Assume now to the contrary that there is no index k such that |P k Pk| < M/K. 
The assumption |Ac| M/2 implies that there is k such that |Ac P k| M/(2K) and by minimality of Ac we find Ac P k Pk = and thus there is k = k such that |P k P k| = |Ac P k P k| M/(2K2). This finishes the proof of (50). We now fix k and k such that |Pk P k| M/(2K2) and |P k P k| M/(2K2). Then, by assumption, we can find k2, . . . , k I [K] such that k = (k, k2, . . . , k I) and k = ( k, k2, . . . , k I) satisfy g(k) = g( k). Moreover |Π 1(k)| N I/L. Now for every ki there is k i such that |Π 1 i (ki) Π 1 i (k i)| K 1|Π 1 i (ki)|. Thus, there is k [K]I such that Π 1 i (k i) Π 1 i (ki) K I+1 i=2 Π 1 i (ki) KI 1L. (53) Define Ai = Π 1 i (k i) Π 1 i (ki) and A1 = Pk P k and A1 = P k P k. Define A = A1 A2 . . . AI and A = A1 A2 . . . AI. Note that by (50) |A|, | A| N I M 2KI+1L. (54) Clearly Π is constant (and equal to (k, k 2, . . . , k I)) on A A and therefore also g Π is constant for any g . On the other hand, Π(A) = {k} is constant and Π( A) = k is constant but g Π(n) = g(k) = g( k) = g Π( n) (55) for n A and n A. Now Π can only be consistent with the data (for any g ) if there is no sample in A or no sample in A, i.e., and thus PD((Π , g ) consistent with D for a g ) PD(D A = or D A = ) PD(D A = ) + PD(D A = ) (56) where we used the union bound in the last step. Then we find using (54) and (56) PD ((Π , g ) consistent with D for a g ) 2 1 M 2KI+1L S 2e SM 2KI+1L . (57) Finally, we can bound the number of partitions by the number of maps Πi : [N] [K] which implies that |{P1, . . . , PI : Pi partition of [N] in at most K classes}| (KN)I = KNI. (58) So we get the upper bound PD( Π such that (Π ) M/2 and Π compatible with D) 2IKNIe MS KI+1L . (59) Combining (49) and (59) and using the union bound we find PD((Π, g) not identifiable up to permutations) PD( Π with (Π ) > 0 is compatible with D) 2KNIIe MS KI+1L + I =1 (NK)I e S 2I LN + KIe S 2I L . Note that clearly the assumptions imply M N/K < N and thus e S/(2IL) e S /(2ILN) for M/2. 
We then find (using the simple bound IKI (2K)I (NK)I that =1 (NK)I e S 2I LN + KIe S 2I L 2 =1 e(2 ln(NK)I S 2I LN ) 4e η (61) S max(4 ln(NK)I2ILN, 2I+1ηLN) (62) and η > 2 where we bound the geometric sum by twice its largest term. The first summand in (60) can be bounded by 2IKNIe MS KI+1L = 2e2 ln(K)NI MS KI+1L 2e η (63) S max 4 ln(K)IKI+1LN M , 2ηKI+1L The proof of Theorem 1 is now a direct consequence of the previous result because when assuming that a uniformly random map Π is chosen we can estimate the quantities M and L in the previous theorem with high probability. Proof of Theorem 1. We observe that |π 1 i (ki)| Bin(N, K 1). Applying a Chernoff bound on the tail of the binomial variable we obtain P |π 1 i (ki)| < N By the union bound we get P min i [I],k [K] |π 1 i (k)| < N/(2K) KIe N 8K = eln(K)I N 8K e η (66) if N N0(η, K, I). Note that M N/(2K) implies that L (N/M)I (2K)I. Assuming that Π is such that the bounds for M and L hold we can apply Theorem 6 and find that for N N0 sufficiently large (depending on I, K, η) and S 22I+3IKIN ln(N) (67) (we bounded ln(NK) 2 ln(N)) the maps Π and g are identifiable with probability at least 1 e η. Here we used that as N the term indicated above is dominating in (31). Now the union bound over the bad events ends the proof. Proof of Theorem 2. As before, we note D = {n1, . . . , n S}. Assume there is n1 such that Π1(n1) = 1 and such that D {n1} Π 1 2 (1) . . . Π 1 I (1) = . (68) Then Π = Π is compatible with D where Π i = Πi for i 2 and Π1(n) = Π 1(n) for n = n1 and Π 1(n1) = 2 = 1 = Π1(n1) (by assumption on g). Let us denote An1 = {n1} Π 1 2 (1) . . . Π 1 I (1) [N]I. (69) P(ns An1) = 1 NKI 1 . (70) To estimate the probability of the event S n1{An1 D = } we use Poissonization to make the events independent. Consider datasets D whose distribution is generated by first sampling S Poi(2S) and then conditional on S sample a dataset D as before, i.e., ns U([N]I) for s [ S]. 
Then we find that |An1 D| Poi 2S NKI 1 and those events are independent for n1 = n 1. Thus, we find that P D(|An1 D| = 0) = e 2S NKI 1 . (72) By independence we now find that n1 Π 1 1 (1) {An1 D = } n1 Π 1 1 (1) P D An1 D = = 1 e 2S NKI 1 N/K 2 ln(N/K) N/K = where we used the upper bound for S and (1 x) e x. Assume Ds U([N]I)s and define n1 Π 1 1 (1) {An1 Ds = } Note that p S is an upper bound on the probability that Π is identifiable as explained above. We have shown that E(p Poi(2S)) e Clearly ps is decreasing in s. This implies that N/K E(p Poi(2S)) p S P(Poi(2S) S). (76) A Chernoff bound for the Poisson distribution reads P(Poi(λ) x) (eλ)xe λ which implies with λ = 2S and x = S that P(Poi(2S) S) (2Se)Se 2S We thus conclude that for S 3 (implying 1 (2/e)S 1/2). C Proofs of Complexity-Theoretic Analysis and Constraint Satisfaction Reduction In this section, we provide the proof of Theorem 3 and a general reduction to a constraint satisfaction problem. We start with the proof of Theorem 3. An important property we will use frequently in the proof is that T has the property that for all values x1 and x2 there is x3(x1, x2) such that T(x1, x2, x3(x1, x2)) = 0. Indeed, x3(x1, x2) = 0 except for x1 = x2 = 1 where x3(1, 1) = 0 has this property. Proof. We reduce the problem to 3SAT. We consider an arbitrary formula ϕ(x) = Ci(x) (80) on ℓvariables with m clauses Ci. We denote f = g Π. To improve readability, we consider a symbol set S instead of [N] which is used for all 3 slots. We proceed in three steps. First we show that we can write down a set of equations that ensures that two symbols s0, s1 S satisfy Πi(sj) = j for all i (if f = g Π). Suppose the following relations hold f(s0, s0, s1) = f(s0, s1, s0) = f(s1, s0, s0) = 0, f(s0, s1, s1) = f(s1, s0, s1) = f(s1, s1, s0) = 1. (81) Then it is easy to conclude that Πi(sj) = j holds. Indeed, suppose that Π1(s1) = 0. 
Then the lower part of the previous display combined with the definition of g imply Π2(s0) = Π3(s1) = Π2(s1) = Π3(s0) = 1. But this is a contradiction to f(s1, s0, s0) = 0. Thus Π1(s1) = 1 and the same reasoning implies Πi(s1) = 1. Then the second equation in (81) directly implies Πi(s0) = 0. We now consider a symbol set S = {s0, s1} X X C C T T F (82) X = {X1, . . . , Xℓ} (83) will encode the variables in the formula ϕ and C = {C1,1, C1,2, C2,1, C2,2, . . . , Cm,1, Cm,2} (84) are auxiliary variables for the clauses. In addition, we need further auxiliary variables X = { X1, . . . , Xℓ}, C = { C1,1, C1,2, C2,1, C2,2, . . . , Cm,1, Cm,2}, and O = {o1, . . . , om}. (85) We now add a set of equations that will ensure that for some value xi {0, 1} Π1(Xi) = Π2(Xi) = Π3(Xi) = xi Π1( Xi) = Π2( Xi) = Π3( Xi) = 1 xi. (86) This can be achieved by adding the following relations for all i f(Xi, Xi, s1) = f( Xi, Xi, s1) = f(s1, Xi, Xi) = 1, f(s1, Xi, Xi) = f(Xi, s1, Xi) = f( Xi, s1, Xi) = 1. (87) Note that the these relations indeed imply that Πa(Xi) = Πb( Xi) for a = b, a, b {1, 2, 3}. This implies then that (86) holds. We add the similar relations for Ci,k which again ensure that Πj(Ci,k) = Πj (Ci,k) = Πj( Ci,k). (88) Now we encode the clauses of the formula. Consider a clause Cl of the form xi1 xi2 xi3. Then we add the relations f( Xi1, Xi2, Cl,1) = 0 f( Xi2, Xi3, Cl,2) = 0 f( Cl,1, Cl,2, ol) = 1. For a given choice of Π we set xi = Π1(Xi). We now claim that the relations in the last display can hold if and only if xi1 xi2 xi3 evaluates to true. Suppose the formula evaluates to false. Then Πj(Xi1) = Πj(Xi2) = Πj(Xi3) = 0 (90) and therefore Πj( Xi1) = Πj( Xi2) = Πj( Xi3) = 1. (91) Then relation (89) implies Π3(Cl,i) = 1 (92) for i = 1, 2. Using (88) we find that Π1( Cl,1) = Π2( Cl,2) = 0 which implies f( Cl,1, Cl,2, ol) = f(0, 0, ol) = 0 (93) for any Π3(ol). Therefore, the equations cannot all hold. Suppose to the contrary that the formula evaluates to true. 
Let us first assume that Π1(Xi1) = 1. Then Π1( Xi1) = 0 and we can set Π3(Cl,1) = 0 which ensures that the first equation is satisfied (for any Π2( Xi2)). By definition of g we can also find a value cl,2 such that for cl,2 = Π3(Cl,2) the relation f( Xi2, Xi3, Cl,2) = f(xi2, xi3, cl,2) = 0 holds. Now Π1( Cl,1) = 1 Π3(Cl,1) = 1 ensures that for any value of cl,2 = Π2( Cl,2) we can choose Π3(ol) such that the relation f( Cl,1, Cl,2, ol) = f(1, cl,2, ol) = 1 holds. The same reasoning applies for Π1(Xi3) = 1 and a similar argument applies if Π1(Xi2) = 1. For clauses containing negations the same construction works except that Xij has to be replaced by Xij for the negated variables in equation (89). Putting everything together, we have shown that for a given formula ϕ there is a set of relations of the form f(ns) = ts where ns S3 given by (81), (87) (plus similar equations for Cl,i), and (89) which has the following properties. For an assignment x = (x1, . . . , xn) {0, 1}n such that ϕ(x) = 1 evaluates to true, there is a map Π such that for all s the relations ts = f(ns) hold and where Π1(Xi) = xi. On the other hand, if for some Π all the relations ts = f(ns) hold, then the Boolean variables xi = Π(Xi) will satisfy the formula ϕ. This ends the proof. We now show how the problem of finding Π and g can be generally expressed as a constraint satisfaction problem. Recall that we want to find (if they exist) for a given K a projection map Π : [N]I [K]I and g : KI R such that for all given samples (ns, ts) the relation ts = g Π(ns) holds. The general strategy is to introduce Boolean variables encoding the maps Π and g and then express all conditions as suitable constraints for these variables. We first introduce Boolean variables ti kn for i [I], k [K], and n [N] which are 1 if the token n is in cluster k, i.e., Πi(n) = k and 0 otherwise. Then the expression ti 1n ti 2n . . . ti kn (94) for i [I] and n [N] is true if and only if n is assigned to to at least one cluster. 
In addition, we consider variables $r_{\mathbf{k}v}$ for every $\mathbf{k} \in [K]^I$ and $v \in \mathrm{Im}(f)$ in the (finite) image of $f$. These variables encode whether $g(\mathbf{k}) = v$ is true or not. Then the constraints
$$\neg r_{\mathbf{k}v} \vee \neg r_{\mathbf{k}v'} \tag{95}$$
for $v \ne v'$ and all $\mathbf{k}$ ensure that each cluster is assigned at most one value $v$. Finally, we add for every datapoint $(\mathbf{n}^s, f(\mathbf{n}^s))$ and every $\mathbf{k}$ the constraint
$$\neg t^1_{k_1 n^s_1} \vee \ldots \vee \neg t^I_{k_I n^s_I} \vee r_{\mathbf{k} f(\mathbf{n}^s)}. \tag{96}$$
This ensures that if $\Pi(\mathbf{n}^s) = \mathbf{k}$ then this cluster must be assigned the value $f(\mathbf{n}^s)$. We then consider the constraint satisfaction problem consisting of the conjunction of all the conditions in (94), (95), and (96). Then any satisfying assignment gives rise to a map $\Pi$ and a function $g$ such that $g \circ \Pi(\mathbf{n}^s) = t_s$. Indeed, we set $g(\mathbf{k}) = v$ if $r_{\mathbf{k}v} = 1$ for some $v$, and arbitrarily otherwise. Note that for every satisfying assignment and every $\mathbf{k}$ there is at most one such $v$, so $g$ is well-defined. Moreover, we set $\Pi_i(n) = k$ for any $k$ such that $t^i_{kn} = 1$; at least one such $k$ exists by (94). On the other hand, we can easily construct a satisfying assignment given $\Pi$ and $g$, so there are no solutions $\Pi$ and $g$ if the constraint satisfaction problem has no solution. Note that we did not ensure that each token is assigned to only a single cluster, but this can be achieved in post-processing or by adding additional constraints (such as $\neg t^i_{k_1 n} \vee \neg t^i_{k_2 n}$ for $k_1 \ne k_2$).

D Taylor Expansion of the Loss Gradient

The goal of this and the following section is to lay the groundwork for the proof of the main results. The general strategy of our proofs is to Taylor expand the loss of each term around the mean token embedding $w(i, \Pi_i(n_i))$ to first order plus a remainder term. We can then extract the dominating terms of these expansions using concentration results for the datapoint statistics. It turns out that the linearized dynamics has favorable properties, while the remainder terms can be bounded.
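The idea of a first-order expansion with a quadratically small remainder can be checked numerically on a toy example. The sketch below makes several assumptions that are not part of the paper: $D = 1$, two slots, the map $\hat f(x, y) = xy$, and hand-picked values for the cluster means and displacements. It compares the exact derivative of the sample loss with its linearization around the cluster means.

```python
import numpy as np

def h(v, g):
    """h(v) = d/dx [0.5 * (f_hat(v) - g)^2] for the toy choice f_hat(x, y) = x * y."""
    x, y = v
    return y * (x * y - g)

def grad_h(v, g):
    """Gradient of h with respect to (x, y), computed by hand for this toy f_hat."""
    x, y = v
    return np.array([y * y, 2.0 * x * y - g])

def linearization_error(w, delta, g):
    """|h(w + delta) - h(w) - grad_h(w) . delta|, the size of the remainder term."""
    return abs(h(w + delta, g) - h(w, g) - grad_h(w, g) @ delta)

w = np.array([1.0, 0.5])        # stand-in for the cluster means w(k)
delta = np.array([0.1, -0.2])   # stand-in for the displacements delta(n)
for scale in (1.0, 0.5, 0.25):
    err = linearization_error(w, scale * delta, g=1.0)
    print(f"scale {scale}: remainder {err:.6f}")
```

Halving the displacement should reduce the remainder by roughly a factor of four, which is the quadratic behaviour that the remainder bounds rely on.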
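In this section, we derive the expansion of the loss gradient, while the required bounds, in particular Theorem 7 and Lemma 1, are derived in the next section, Appendix E (they rely on concentration results which are deferred to Appendix G).

As pointed out above, the goal of this section is to Taylor expand the sample loss for one sample, where we expand the loss around the point $w(\Pi(\mathbf{n}))$. Let us first introduce some notation. Recall that we denoted $v(\mathbf{n}) = (v(1, n_1), \ldots, v(I, n_I))$ and by $\bar v(i)$ the concatenation of the embeddings $(v(i, n))_{n \in [N]}$. We define similarly
$$\delta(\mathbf{n}) = (\delta(1, n_1), \ldots, \delta(I, n_I)) \tag{97}$$
and we denote by $\bar\delta(i) \in \mathbb{R}^{DN}$ the concatenation of $\delta(i, 1), \ldots, \delta(i, N)$, i.e.,
$$\bar\delta(i) = (\delta(i, n))_{n \in [N]} \tag{98}$$
and by $\bar\delta \in \mathbb{R}^{IDN}$ the concatenation of $\bar\delta(1), \ldots, \bar\delta(I)$. Consider a datapoint $\mathbf{n}$ with $\Pi(\mathbf{n}) = \mathbf{k}$. We now consider the derivative of the mean squared error of this term, i.e., of
$$L(\mathbf{n}) = \tfrac{1}{2}\big(\hat f(\mathbf{n}) - f(\mathbf{n})\big)^2 = \tfrac{1}{2}\big(\hat f(\mathbf{n}) - g(\mathbf{k})\big)^2. \tag{99}$$
We use Greek indices for the latent space dimension. Fix an $i \in [I]$ as the slot with respect to which we take the derivative. We then find for $\alpha \in [D]$
$$\frac{d}{dv_\alpha(i, n_i)} \tfrac{1}{2}\big(\hat f(v(\mathbf{n})) - g(\mathbf{k})\big)^2 = \Big(\frac{d}{dv_\alpha(i, n_i)} \hat f(v(\mathbf{n}))\Big)\big(\hat f(v(\mathbf{n})) - g(\mathbf{k})\big) =: h(v(\mathbf{n})). \tag{100}$$
Here we introduced the shorthand $h(v(\mathbf{n}))$ for this function, which also depends on $i$, $\mathbf{k}$, and $\alpha$. Now we estimate this function by Taylor expanding it around the point
$$w(\mathbf{k}) = (w(1, k_1), \ldots, w(I, k_I)). \tag{101}$$
Then we find the following expansion to second order
$$h(v(\mathbf{n})) = h(w(\mathbf{k})) + \sum_{j \in [I]} \sum_{\beta \in [D]} \frac{d}{dw_\beta(j, k_j)} h(w(\mathbf{k}))\, \delta_\beta(j, n_j) + \sum_{j_1, j_2 \in [I]} \sum_{\beta_1, \beta_2 \in [D]} R^{\mathbf{n},i,\alpha}_{(j_1,\beta_1),(j_2,\beta_2)}(w(\mathbf{k}))\, \delta_{\beta_1}(j_1, n_{j_1})\, \delta_{\beta_2}(j_2, n_{j_2}) \tag{102}$$
where $R_{(j_1,\beta_1),(j_2,\beta_2)}$ denotes the remainder terms. We remark that if we assume that $v(i, n) \in \Omega$ for some convex set $\Omega$ and all $i \in [I]$, $n \in [N]$, then (by convexity of $\Omega$) also $w(i, k_i) \in \Omega$ and
$$\big|R^{\mathbf{n},i,\alpha}_{(j_1,\beta_1),(j_2,\beta_2)}\big| \le \max_{v \in \Omega^I} \frac{1}{2} \Big| \frac{d}{dv_{\beta_1}(j_1, n_{j_1})} \frac{d}{dv_{\beta_2}(j_2, n_{j_2})} h(v) \Big|. \tag{103}$$
Let us now (abusing notation again; it will be clear from context which function we refer to) write
$$L(\mathbf{k}) = \tfrac{1}{2}\big(\hat f(w(\mathbf{k})) - g(\mathbf{k})\big)^2. \tag{104}$$
We introduce some more quantities to get a concise representation of the Taylor expansion. We define the matrices $D_{j_1,j_2} L(\mathbf{k}) \in \mathbb{R}^{D \times D}$ containing the second derivatives of a function $L$ with respect to $w(j_1, k_{j_1})$ and $w(j_2, k_{j_2})$, i.e., we consider
$$(D_{j_1,j_2} L(\mathbf{k}))_{\alpha,\beta} = \frac{d}{dw_\alpha(j_1, k_{j_1})} \frac{d}{dw_\beta(j_2, k_{j_2})} L(\mathbf{k}). \tag{105}$$
Similarly we define the vector $D_j L(\mathbf{k}) \in \mathbb{R}^D$ by
$$(D_j L(\mathbf{k}))_\alpha = \frac{d}{dw_\alpha(j, k_j)} L(\mathbf{k}). \tag{106}$$
For the diagonal entries $D_{j,j}$ we need a more fine-grained decomposition. Note that by the product rule we have
$$\frac{d}{dw_\alpha(j, k_j)} \frac{d}{dw_\beta(j, k_j)} L(\mathbf{k}) = \Big(\frac{d}{dw_\alpha(j, k_j)} \frac{d}{dw_\beta(j, k_j)} \hat f(w(\mathbf{k}))\Big)\big(\hat f(w(\mathbf{k})) - g(\mathbf{k})\big) + \frac{d}{dw_\alpha(j, k_j)} \hat f(w(\mathbf{k}))\, \frac{d}{dw_\beta(j, k_j)} \hat f(w(\mathbf{k})). \tag{107}$$
Using the notation we introduced, we can therefore write (recall $\hat g(\mathbf{k}) = \hat f(w(\mathbf{k}))$)
$$D_{j,j} L(\mathbf{k}) = (D_{j,j}\hat g(\mathbf{k}))\,(\hat g(\mathbf{k}) - g(\mathbf{k})) + D_j \hat g(\mathbf{k}) \otimes D_j \hat g(\mathbf{k}). \tag{108}$$
Finally, collecting all remainder terms as
$$R_\alpha(i, \mathbf{n}) = \sum_{j_1, j_2 \in [I]} \sum_{\beta_1, \beta_2 \in [D]} R^{\mathbf{n},i,\alpha}_{(j_1,\beta_1),(j_2,\beta_2)}(w(\mathbf{k}))\, \delta_{\beta_1}(j_1, n_{j_1})\, \delta_{\beta_2}(j_2, n_{j_2}) \tag{109}$$
we can summarize the Taylor expansion result as follows:
$$D_i L(\mathbf{n}) = D_i L(\mathbf{k}) + \big((D_{i,i}\hat g(\mathbf{k}))\,(\hat g(\mathbf{k}) - g(\mathbf{k})) + D_i \hat g(\mathbf{k}) \otimes D_i \hat g(\mathbf{k})\big)\, \delta(i, n_i) + \sum_{j \ne i} (D_{i,j} L(\mathbf{k}))\, \delta(j, n_j) + R(i, \mathbf{n}). \tag{110}$$
Based on the expansion (110) we now want to get an expression for the gradient of the total loss on the entire dataset. We decompose the dataset as follows:
$$\mathcal{D}_{\mathbf{k},n,i} = \{\mathbf{n}^s \in \mathcal{D} \mid \Pi(\mathbf{n}^s) = \mathbf{k},\ n^s_i = n\}. \tag{111}$$
Note that if $\Pi_i(n) \ne k_i$ then $\mathcal{D}_{\mathbf{k},n,i} = \emptyset$. We also define similarly
$$\mathcal{D}_{n,i} = \{\mathbf{n}^s \in \mathcal{D} \mid n^s_i = n\}. \tag{112}$$
We find the following expansion:
$$\frac{d}{dv(i, n)} \hat L(\bar v) = \frac{d}{dv(i, n)} \sum_{s=1}^{S} L(\mathbf{n}^s) = \sum_{\mathbf{n} \in \mathcal{D}_{n,i}} \frac{d}{dv(i, n)} L(\mathbf{n}) = \sum_{\mathbf{k} \in [K]^I} \sum_{\mathbf{n} \in \mathcal{D}_{\mathbf{k},n,i}} \Big( D_i L(\mathbf{k}) + \big((D_{i,i}\hat g(\mathbf{k}))\,(\hat g(\mathbf{k}) - g(\mathbf{k})) + D_i \hat g(\mathbf{k}) \otimes D_i \hat g(\mathbf{k})\big)\, \delta(i, n_i) + \sum_{j \ne i} (D_{i,j} L(\mathbf{k}))\, \delta(j, n_j) \Big) + \sum_{\mathbf{n} \in \mathcal{D}_{n,i}} R(i, \mathbf{n}). \tag{113}$$
We now rewrite or bound the four summands. First, we relate those expressions to the datapoint statistics matrices. We define
$$B^{\mathbf{k},i}_n = |\mathcal{D}_{\mathbf{k},n,i}| = |\{\mathbf{n} \in \mathcal{D} : \Pi(\mathbf{n}) = \mathbf{k},\ n_i = n\}|, \tag{114}$$
$$A^{\mathbf{k},i,j}_{n,n'} = |\{\mathbf{n} \in \mathcal{D}_{\mathbf{k},n,i} \mid n_j = n'\}| = |\{\mathbf{n} \in \mathcal{D} : \Pi(\mathbf{n}) = \mathbf{k},\ n_i = n,\ n_j = n'\}|. \tag{115}$$
Then the first term equals
$$\sum_{\mathbf{k} \in [K]^I} \sum_{\mathbf{n} \in \mathcal{D}_{\mathbf{k},n,i}} D_i L(\mathbf{k}) = \sum_{\mathbf{k} \in [K]^I} B^{\mathbf{k},i}_n\, D_i L(\mathbf{k}).$$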
(116) The second term is similarly given by X n Dk,n,i ((Di,iˆg(k)) (ˆg(k) g(k)) + Diˆg(k) Diˆg(k)) δ(i, ni) k [K]I Bk,i n (Di,iˆg(k)) (ˆg(k) g(k)) + Diˆg(k) Diˆg(k) δ(i, n) (117) The third term is given by X j =i (Di,j L(k))δ(j, nj) = X n [N] Ak,i,j n,n (Di,j L(k))δ(j, n ) Ak,i,j Di,j L(k) δ(j) where δ was introduced in (98). Let us summarize those findings as a proposition. Proposition 1. Assume that the bound (15) holds for some Ωand v(i, n) Ωfor all i [I], n [N]. Then the following expansion holds d dv(i, n) ˆL( v) = X k [K]I Bk,i n Di L(k) k [K]I Bk,i n ((Di,iˆg(k)) (ˆg(k) g(k)) + Diˆg(k) Diˆg(k)) δ(i, n) Ak,i,j Di,j L(k) δ(j) n Dn,i R(i, n) Proof. The result follows from the calculations above. E Bounds for the Loss Gradient The goal of this section is to extract the asymptotically (with high probability) dominating terms of the expansion from the previous section as S . The crucial observation is that the appearing averages can be essentially replaced by their expectation. This is a consequence of concentration properties of the datapoint statistics. These properties will be derived in Appendix G where they are collected in Theorem 9. Here, we just prove the result conditional on the bound in Theorem 9. Recall that we defined δmax = maxi [I],n [N] |δ(i, n)| in (11) as the maximal deviation norm and rmax = maxk [K]I |ˆg(k) g(k)| in (13) as the maximal residual. Finally, we introduce the notation p(k, i) = |Π 1(k)| N I|Π 1 i (ki)|. (120) Note that Bk,i n Bin(S, p(k, i)) if Πi(n) = ki and Bk,i n = 0 otherwise. In particular, EBk,i n = S p(k, i). Note that if Π is approximately balanced if the following bounds hold Np(k, i) = N |Π 1(k)| N I|Π 1 i (ki)| = Y |Π 1 j (kj)| This implies 1 I Np(k, i) 2 Theorem 7. Assume that the bounds in Theorem 9 hold and assume ˆf satisfies Assumption 2 for some set Ωand v(i, n) Ωfor i [I] and [N]. 
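Then we obtain the following expansion for the loss gradient
$$\begin{aligned}\frac NS\frac{d}{dv(i,n)}\hat L(\bar v)&=\sum_{\mathbf{k}\in[K]^I:\,k_i=\Pi_i(n)}Np(\mathbf{k},i)\,D_iL(\mathbf{k})+\sum_{\mathbf{k}\in[K]^I:\,k_i=\Pi_i(n)}Np(\mathbf{k},i)\big((\hat g(\mathbf{k})-g(\mathbf{k}))D_{i,i}\hat g(\mathbf{k})+D_i\hat g(\mathbf{k})\otimes D_i\hat g(\mathbf{k})\big)\delta(i,n)\\&\quad+E^F(i,n)+E^{S,1}(i,n)+E^{S,2}(i,n)+E^I(i,n)+E^T(i,n).\end{aligned} \tag{123}$$
Here the error terms are bounded by
$$|E^F_\alpha(i,n)|\le4M\min(1,r_{\max})\,(2K)^{(I-1)/2}\sqrt{\frac{N\ln(S)}{S}}, \tag{124}$$
$$|E^{S,1}_\alpha(i,n)|\le4DM\min(1,r_{\max})\,(2K)^{(I-1)/2}\sqrt{\frac{N\ln(S)}{S}}\,\delta_{\max}, \tag{125}$$
$$|E^{S,2}_\alpha(i,n)|\le4DM\,(2K)^{(I-1)/2}\sqrt{\frac{N\ln(S)}{S}}\,\delta_{\max}, \tag{126}$$
$$|E^I(i,n)|\le6DM\,(2K)^{(I-1)/2}\frac{N\ln(S)}{\sqrt S}\,\delta_{\max}, \tag{127}$$
$$|E^T_\alpha(i,n)|\le MI^2D\,\delta_{\max}^2. \tag{128}$$

Proof. The proof is a bit technical and essentially combines the assumptions and the concentration results in a straightforward fashion. First we remark that for $n\in\Pi_i^{-1}(k_i)$ we have, as stated before,
$$E_DB^{\mathbf{k},i}_n=S\cdot p(\mathbf{k},i) \tag{129}$$
and $B^{\mathbf{k},i}_n=0$ otherwise. Then Proposition 1 implies that (123) holds with the definitions of the error terms given below. We now define and bound the error terms in the decomposition. Note that we have
$$|(D_jL(\mathbf{k}))_\alpha|=\Big|\frac{d}{dw_\alpha(j,k_j)}L(\mathbf{k})\Big|=\Big|\frac{d}{dw_\alpha(j,k_j)}\hat g(\mathbf{k})\,(\hat g(\mathbf{k})-g(\mathbf{k}))\Big|\le M\min(1,r_{\max}) \tag{130}$$
where we either bound both factors using Assumption 2, or bound the second factor by $r_{\max}$ and the first one by Assumption 2. This then implies that the error term capturing the fluctuations of the occupation statistics
$$E^F(i,n)=\frac NS\sum_{\mathbf{k}\in[K]^I}\big(B^{\mathbf{k},i}_n-E_DB^{\mathbf{k},i}_n\big)D_iL(\mathbf{k}) \tag{131}$$
can be bounded by
$$|E^F_\alpha(i,n)|\le4M\min(1,r_{\max})\,(2K)^{(I-1)/2}\sqrt{\frac{N\ln(S)}{S}}\le C(I,K,M)\min(1,r_{\max})\sqrt{\frac{N\ln(S)}{S}}. \tag{132}$$
Here we used that there are $K^{I-1}$ non-vanishing terms in the sum (if $\Pi_i(n)\neq k_i$ we have $B^{\mathbf{k},i}_n=EB^{\mathbf{k},i}_n=0$). Next, we consider the self-interaction error terms
$$E^{S,1}(i,n)=\frac NS\sum_{\mathbf{k}\in[K]^I}\big(B^{\mathbf{k},i}_n-EB^{\mathbf{k},i}_n\big)(\hat g(\mathbf{k})-g(\mathbf{k}))\,(D_{i,i}\hat g(\mathbf{k}))\,\delta(i,n), \tag{133}$$
$$E^{S,2}(i,n)=\frac NS\sum_{\mathbf{k}\in[K]^I}\big(B^{\mathbf{k},i}_n-EB^{\mathbf{k},i}_n\big)\big(D_i\hat g(\mathbf{k})\otimes D_i\hat g(\mathbf{k})\big)\delta(i,n). \tag{134}$$
For $E^{S,1}$ we get (similar to above, and using that the operator norm of a $D\times D$ matrix $A$ is bounded by $D\max|a_{ij}|$) the bound
$$|E^{S,1}_\alpha(i,n)|\le4DM\min(1,r_{\max})\,(2K)^{(I-1)/2}\sqrt{\frac{N\ln(S)}{S}}\,\delta_{\max}\le C(I,K,M,D)\min(1,r_{\max})\sqrt{\frac{N\ln(S)}{S}}\,\delta_{\max}. \tag{135}$$
The error term $E^{S,2}$ could also be absorbed in other terms later on, but since it is not dominant, we just bound it.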
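Similarly to our treatment of $E^{S,1}$ we obtain the bound
$$|E^{S,2}_\alpha(i,n)|\le4DM\,(2K)^{(I-1)/2}\sqrt{\frac{N\ln(S)}{S}}\,\delta_{\max}\le C(I,K,M,D)\sqrt{\frac{N\ln(S)}{S}}\,\delta_{\max}. \tag{136}$$
Next, we consider the interaction error term
$$E^I(i,n)=\frac NS\sum_{\mathbf{k}\in[K]^I}\sum_{j\neq i}\Big(\big(A^{\mathbf{k},i,j}\otimes D_{i,j}L(\mathbf{k})\big)\bar\delta(j)\Big)_{nD:(n+1)D}. \tag{137}$$
We note that $E(A^{\mathbf{k},i,j}_{n_1,n_2})=c$ for some constant $c$ if $\Pi_i(n_1)=k_i$ and $\Pi_j(n_2)=k_j$, and $0$ otherwise. This implies that if $\Pi_i(n)=k_i$
$$\Big(\big(EA^{\mathbf{k},i,j}\otimes D_{i,j}L(\mathbf{k})\big)\bar\delta(j)\Big)_{nD:(n+1)D}=c\sum_{n'\in\Pi_j^{-1}(k_j)}D_{i,j}L(\mathbf{k})\,\delta(j,n')=0 \tag{138}$$
where we used in the last step
$$\sum_{n\in\Pi_i^{-1}(k)}\delta(i,n)=0 \tag{139}$$
which follows from the definition (10). If $\Pi_i(n)\neq k_i$ then this expression is clearly zero, as the corresponding row of $A^{\mathbf{k},i,j}$ is zero. Thus, we find that
$$|E^I(i,n)|=\Big|\frac NS\sum_{\mathbf{k}\in[K]^I}\sum_{j\neq i}\Big(\big((A^{\mathbf{k},i,j}-EA^{\mathbf{k},i,j})\otimes D_{i,j}L(\mathbf{k})\big)\bar\delta(j)\Big)_{nD:(n+1)D}\Big|\le\frac NS\sum_{\mathbf{k}:\,k_i=\Pi_i(n)}\sum_{j\neq i}\max_{\mathbf{k},i,j}\big\|A^{\mathbf{k},i,j}-EA^{\mathbf{k},i,j}\big\|\cdot\big\|D_{i,j}L(\mathbf{k})\big\|\cdot\max_j|\bar\delta(j)|. \tag{140}$$
Here we used the submultiplicativity of the operator norm of Kronecker products stated in (281). Now we use $|\bar\delta(j)|\le\sqrt N\,\delta_{\max}$, the concentration bound (252), and bound $\|D_{i,j}L(\mathbf{k})\|\le DM$ similarly as before to find
$$|E^I(i,n)|\le6DM\,(2K)^{(I-1)/2}\frac{N\ln(S)}{\sqrt S}\,\delta_{\max}. \tag{141}$$
Finally, we bound the Taylor expansion remainder term
$$E^T(i,n)=\frac NS\sum_{\mathbf{n}\in D_{n,i}}R(i,\mathbf{n}). \tag{142}$$
Recall that
$$R_\alpha(i,\mathbf{n})=\sum_{j_1,j_2\in[I]}\sum_{\beta_1,\beta_2\in[D]}R^{\mathbf{n},i,\alpha}_{(j_1,\beta_1),(j_2,\beta_2)}(w(\mathbf{k}))\,\delta_{\beta_1}(j_1,n_{j_1})\delta_{\beta_2}(j_2,n_{j_2}), \tag{143}$$
$$\big|R^{\mathbf{n},i,\alpha}_{(j_1,\beta_1),(j_2,\beta_2)}\big|\le\max_{v\in\Omega^I}\frac12\Big|\frac{d}{dv_{\beta_1}(j_1,n_{j_1})}\frac{d}{dv_{\beta_2}(j_2,n_{j_2})}h(v)\Big| \tag{144}$$
where $h=h^{\mathbf{n},i,\alpha}$ was introduced in (100). Assuming the bounds (15) and (16) we find
$$\big|R^{\mathbf{n},i,\alpha}_{(j_1,\beta_1),(j_2,\beta_2)}\big|\le\max_{v\in\Omega^I}\frac12\Big|\frac{d}{dv_{\beta_1}(j_1,n_{j_1})}\frac{d}{dv_{\beta_2}(j_2,n_{j_2})}h(v)\Big|\le M. \tag{145}$$
Indeed, the derivatives of $h$ can be decomposed by the product rule into 4 terms, each of which can be bounded using (15) and (16). Then we can bound the remainder term as follows
$$|R_\alpha(i,\mathbf{n})|\le M\sum_{j_1,j_2\in[I]}\sum_{\beta_1,\beta_2\in[D]}|\delta_{\beta_1}(j_1,n_{j_1})\delta_{\beta_2}(j_2,n_{j_2})|\le MID\sum_{j\in[I]}|\delta(j,n_j)|^2\le MI^2D\,\delta_{\max}^2. \tag{146}$$
Recalling that $D_{n,i}$ defined in (112) satisfies $|D_{n,i}|=B^i_n$, we obtain
$$|E^T_\alpha(i,n)|=\Big|\frac NS\sum_{\mathbf{n}\in D_{n,i}}R_\alpha(i,\mathbf{n})\Big|\le\frac NS\,B^i_n\,MI^2D\,\delta_{\max}^2, \tag{147}$$
and the bound (128) then follows from the concentration bound (253) in Theorem 9.

If we consider the gradient descent dynamics, we obtain the following expansion for the time derivative $\dot w(i,k)$.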
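Let us from now on absorb the dependence on $I$ and $D$ in generic constants $C$.

Corollary 1. Under the same assumptions as in Theorem 7 we get
$$\dot w(i,k)=-\sum_{\mathbf{k}\in[K]^I,\,k_i=k}Np(\mathbf{k},i)\,D_iL(\mathbf{k})-\lambda w(i,k)+E^w(i,k) \tag{148}$$
where $E^w$ can be bounded by
$$|E^w_\alpha(i,k)|\le C\min(1,r_{\max})K^{(I-1)/2}\sqrt{\frac{N\ln(S)}{S}}+CK^{(I-1)/2}\frac{N\ln(S)}{\sqrt S}\,\delta_{\max}+C\delta_{\max}^2 \tag{149}$$
for $S\ge N\ln(N)$.

Proof. The result follows from Theorem 7, the gradient dynamics (see (8)), and the definition of $w(i,k)$ (see (9)). We also use the relation $\sum_{n\in\Pi_i^{-1}(k)}\delta(i,n)=0$ (already stated in (139)) to conclude that terms involving $\delta(i,n)$ cancel. The condition $S\ge N\ln(N)$ allows us to bound $\ln(S)N/S\le2$ to control $E^T$. Note that here we use that $E^{S,1}$ and $E^{S,2}$ can be bounded by the bound for $E^I$. Note that we absorbed all terms involving $\delta_{\max}$ into the dominating term.

We therefore get the following bound for $\dot\delta(i,n)$.

Corollary 2. Under the same assumptions as in Theorem 7 we get
$$\dot\delta(i,n)=-\sum_{\mathbf{k}\in[K]^I,\,k_i=\Pi_i(n)}Np(\mathbf{k},i)\big((\hat g(\mathbf{k})-g(\mathbf{k}))D_{i,i}\hat g(\mathbf{k})+D_i\hat g(\mathbf{k})\otimes D_i\hat g(\mathbf{k})\big)\delta(i,n)-\lambda\delta(i,n)+E^\delta(i,n) \tag{150}$$
where $E^\delta$ can be bounded by
$$|E^\delta(i,n)|\le C\min(1,r_{\max})K^{(I-1)/2}\sqrt{\frac{N\ln(S)}{S}}+CK^{(I-1)/2}\frac{N\ln(S)}{\sqrt S}\,\delta_{\max}+C\delta_{\max}^2 \tag{151}$$
for $S\ge N\ln(N)$.

Proof. This follows from Theorem 7, Corollary 1, and the relation
$$\delta(i,n)=v(i,n)-w(i,\Pi_i(n)). \tag{152}$$

Now we are in the position to control the time evolution of $|\delta(i,n)|^2$. We obtain the following relation.

Lemma 1. Under the same assumptions as in Theorem 7 the following bound holds
$$\begin{aligned}\frac{d}{dt}\frac12|\delta(i,n)|^2&\le\Big(-\lambda-\frac{\omega}{(2K)^I}+2^I\sup_{v_1,\ldots,v_I\in\Omega}\max_i\big\|D^2_{i,i}\hat f(v_1,\ldots,v_I)\big\|\,r_{\max}\Big)|\delta(i,n)|^2\\&\quad+C_1\min(1,r_{\max})K^{(I-1)/2}\sqrt{\frac{N\ln(S)}{S}}\,\delta_{\max}+C_1K^{(I-1)/2}\frac{N\ln(S)}{\sqrt S}\,\delta_{\max}^2+C_1\delta_{\max}^3\end{aligned} \tag{153}$$
where $C_1$ is a constant depending on $D$, $I$, and $M$, and $\omega$ is defined by
$$\omega=\min_{k\in[K],\,i\in[I]}\lambda_{\min}\Big(\sum_{\mathbf{k}\in[K]^I,\,k_i=k}D_i\hat g(\mathbf{k})\otimes D_i\hat g(\mathbf{k})\Big). \tag{154}$$

Proof. We have
$$\begin{aligned}\frac{d}{dt}\frac12|\delta(i,n)|^2=\delta(i,n)\cdot\dot\delta(i,n)&=-\sum_{\mathbf{k}\in[K]^I,\,k_i=\Pi_i(n)}Np(\mathbf{k},i)\,(\hat g(\mathbf{k})-g(\mathbf{k}))\,\delta(i,n)^\top(D_{i,i}\hat g(\mathbf{k}))\,\delta(i,n)\\&\quad-\sum_{\mathbf{k}\in[K]^I,\,k_i=\Pi_i(n)}Np(\mathbf{k},i)\,\delta(i,n)^\top\big(D_i\hat g(\mathbf{k})\otimes D_i\hat g(\mathbf{k})\big)\delta(i,n)+\delta(i,n)\cdot E^\delta(i,n)-\lambda|\delta(i,n)|^2.\end{aligned} \tag{155}$$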
Recall that by (122) we have
$$\Big(\frac{1}{2K}\Big)^{I-1}\le Np(\mathbf{k},i)\le\Big(\frac{2}{K}\Big)^{I-1}. \tag{156}$$
We consider $\omega$ as defined in the statement of the lemma, where $\omega\ge0$ follows because we sum positive semi-definite rank-one matrices. Then we find, using the lower bound on the spectrum,
$$\delta(i,n)^\top\Big(\sum_{\mathbf{k}\in[K]^I,\,k_i=\Pi_i(n)}D_i\hat g(\mathbf{k})\otimes D_i\hat g(\mathbf{k})\Big)\delta(i,n)\ge|\delta(i,n)|^2\,\lambda_{\min}\Big(\sum_{\mathbf{k}\in[K]^I,\,k_i=\Pi_i(n)}D_i\hat g(\mathbf{k})\otimes D_i\hat g(\mathbf{k})\Big)\ge\omega\,|\delta(i,n)|^2. \tag{157}$$
Thus we find
$$\frac{d}{dt}\frac12|\delta(i,n)|^2\le-\lambda|\delta(i,n)|^2-\frac{\omega}{(2K)^I}|\delta(i,n)|^2+2^I\sup_{v_1,\ldots,v_I\in\Omega}\max_i\big\|D^2_{i,i}\hat f(v_1,\ldots,v_I)\big\|\,r_{\max}\,|\delta(i,n)|^2+\delta(i,n)\cdot E^\delta(i,n) \tag{158}$$
where we used for the second-to-last contribution that we sum over $K^{I-1}$ terms, which is compensated by the $(2/K)^{I-1}$ factor from (156). We finally bound the error term by
$$|\delta(i,n)\cdot E^\delta(i,n)|\le C\min(1,r_{\max})K^{(I-1)/2}\sqrt{\frac{N\ln(S)}{S}}\,\delta_{\max}+CK^{(I-1)/2}\frac{N\ln(S)}{\sqrt S}\,\delta_{\max}^2+C\delta_{\max}^3, \tag{159}$$
where the constant absorbs the dependence on $D$. This ends the proof.

F Proofs of the Main Results

In this section, we prove our main results, Theorem 4 and Theorem 5. In addition, we state and prove Theorem 8, which is a variant of Theorem 5 without an a-priori bound on the embeddings. Essentially, the proof strategy for both results is to rely on the groundwork from the previous two sections, in particular on Lemma 1. Indeed, we first verify that the conditions of Lemma 1 hold for all times, and then an application of this lemma allows us to control the evolution of $\delta_{\max}$ for all times.

Proof of Theorem 4. The proof proceeds in three steps. First we use the monotonicity of the loss to deduce that $r_{\max}(t)$ can be bounded in terms of $r_{\max}(0)$, $\delta_{\max}(t)$, and $\delta_{\max}(0)$. Then we show that choosing the variables as in the statement of the theorem allows us to bound all the terms in Lemma 1 in Appendix E. Then we apply Lemma 1 and deduce the decay of the maximum of the displacements $|\delta(i,n,t)|$. Here we need to check carefully that the assumptions of the lemma are satisfied for all $t$.

First we assume that the conclusions on the concentration properties from Theorem 9 in Appendix G hold, which occurs with probability at least $1-S^{-1}$ over the randomness of the training data.
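Let us consider any time $T>0$ and assume that $v(i,n,t)\in\Omega$ for all $i\in[I]$, $n\in[N]$, and $0\le t\le T$. This allows us in particular to apply Assumption 2 up to time $T$ for the embeddings $v(i,n)$. We now implement the first step, where we bound $r_{\max}(t)$ for all $0\le t\le T$ in terms of $r_{\max}(0)$, $\delta_{\max}(0)$, and $\delta_{\max}(t)$. The idea is to upper bound the initial sample loss $\hat L(t=0)$ and lower bound the loss $\hat L(t>0)$. First we note that by a first-order Taylor expansion and using Assumption 2
$$|\hat f(v(\mathbf{n}))-\hat f(w(\Pi(\mathbf{n})))|\le\sum_{i\in[I]}M|\delta(i,n_i)|\le MI\delta_{\max}. \tag{160}$$
This implies that
$$\hat L(\bar v)=\frac12\sum_{\mathbf{n}\in D}\big(\hat f(v(\mathbf{n}))-g(\Pi(\mathbf{n}))\big)^2\le\frac12\sum_{\mathbf{n}\in D}\Big(|\hat f(w(\Pi(\mathbf{n})))-g(\Pi(\mathbf{n}))|+\sum_{i\in[I]}M|\delta(i,n_i)|\Big)^2\le\frac12S\,(r_{\max}+IM\delta_{\max})^2. \tag{161}$$
Next, we derive a lower bound on the loss. Let $\mathbf{k}$ be such that $|\hat g(\mathbf{k})-g(\mathbf{k})|=r_{\max}$. By assumption and using concentration bounds (e.g., by summing (254) over $n$ and using the lower bound for $S$) we find that
$$|D_{\mathbf{k}}|=|\{\mathbf{n}\in D:\Pi(\mathbf{n})=\mathbf{k}\}|\ge\frac12\frac{S}{(2K)^I}. \tag{163}$$
Thus we can lower bound the loss using the same Taylor expansion as before by
$$\hat L(\bar v(t))\ge\frac14\frac{S}{(2K)^I}\big(\max(0,\,r_{\max}(t)-MI\delta_{\max}(t))\big)^2. \tag{164}$$
Since gradient descent does not increase the loss, the bounds (161) and (164) together imply for $t\ge0$
$$\frac12S\,(r_{\max}(0)+IM\delta_{\max}(0))^2\ge\hat L(\bar v(0))\ge\hat L(\bar v(t))\ge\frac14\frac{S}{(2K)^I}\big(\max(0,\,r_{\max}(t)-MI\delta_{\max}(t))\big)^2. \tag{165}$$
This implies
$$r_{\max}(t)\le2(2K)^{I/2}(r_{\max}(0)+IM\delta_{\max}(0))+IM\delta_{\max}(t). \tag{166}$$
Then we get the bound
$$r_{\max}(t)\le4(2K)^{I/2}\big(r_{\max}(0)+IM\max(\delta_{\max}(0),\delta_{\max}(t))\big). \tag{167}$$
Now we choose $C_2$ and $C_3$ such that the initialization condition (20) becomes
$$\delta_{\max}(0)\le\min\Big(\frac{\omega_0}{8C_1(2K)^I},\ \frac{\omega_0}{64IM^2D(2K)^{3I/2}2^I},\ \frac{1}{8(2K)^{I/2}IM},\ 1\Big)=:\bar\delta, \tag{168}$$
$$r_{\max}(0)\le\min\Big(\frac{\omega_0}{64MD(2K)^{3I/2}2^I},\ \frac{1}{8(2K)^{I/2}}\Big). \tag{169}$$
Assume in addition that for all times $0\le t\le T$ the bound $\delta_{\max}(t)\le\bar\delta$ holds. Using (167) we then find that for $0\le t\le T$
$$r_{\max}(t)\le4(2K)^{I/2}\Big(\frac{\omega_0}{64MD(2K)^{3I/2}2^I}+IM\frac{\omega_0}{64IM^2D(2K)^{3I/2}2^I}\Big)=\frac{\omega_0}{8MD(2K)^I2^I} \tag{170}$$
and similarly
$$r_{\max}(t)\le4(2K)^{I/2}\Big(\frac{1}{8(2K)^{I/2}}+IM\frac{1}{8(2K)^{I/2}IM}\Big)=1 \tag{171}$$
and therefore
$$r_{\max}(t)\le\min\Big(\frac{\omega_0}{8MD(2K)^I2^I},\ 1\Big). \tag{172}$$
The next step is to bound the various error terms appearing in Lemma 1.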
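Using the last display we can bound, for all $i\in[I]$ and $n\in[N]$, using Assumption 2, for $t\in[0,T]$
$$2^I\sup_{v_1,\ldots,v_I\in\Omega}\max_i\big\|D^2_{i,i}\hat f(v_1,\ldots,v_I)\big\|\,r_{\max}(t)\,|\delta(i,n,t)|^2\le2^IMD\,\frac{\omega_0}{8MD(2K)^I2^I}\,|\delta(i,n,t)|^2=\frac{\omega_0}{8(2K)^I}\,|\delta(i,n,t)|^2. \tag{173}$$
Moreover, we can estimate for $0\le t\le T$
$$C_1\delta_{\max}^3(t)\le C_1\delta_{\max}^2(t)\,\frac{\omega_0}{8C_1(2K)^I}=\frac{\omega_0}{8(2K)^I}\,\delta_{\max}^2(t). \tag{174}$$
Now we observe that for $x=a\ln^2(a)$ and $a>1$ we find (using $\ln(a)\le a$) the bound
$$\frac{\ln(x)}{\sqrt x}\le\frac{\ln(a^3)}{\sqrt a\,\ln(a)}=\frac{3}{\sqrt a}. \tag{175}$$
Hence we consider
$$S\ge S_0=\frac{24^2C_1^2(2K)^{3I}N^2}{\omega_0^2}\,\ln^2\Big(\frac{24^2C_1^2(2K)^{3I}N^2}{\omega_0^2}\Big). \tag{176}$$
Note that tracking only the dependence on $K$ and $N$ (and using $K\le N$) this condition reads
$$S\ge C(2K)^{3I}N^2\ln^2(N). \tag{177}$$
Then we find, using (175) combined with the observation that $\ln(S)/\sqrt S$ is decreasing for $S\ge e^2$, for $S\ge S_0$ the bound
$$C_1K^{(I-1)/2}\frac{N\ln(S)}{\sqrt S}\le C_1K^{(I-1)/2}\frac{N\ln(S_0)}{\sqrt{S_0}}\le\frac{3C_1K^{I/2}N}{24C_1(2K)^{3I/2}N\omega_0^{-1}}\le\frac18\frac{\omega_0}{(2K)^I}. \tag{178}$$
We note that similarly for $x=a\ln(a)$ we get
$$\frac{\ln(x)}{x}\le\frac{\ln(a^2)}{a\ln(a)}=\frac2a. \tag{179}$$
Thus we consider similarly
$$S\ge S_1=\frac{2\cdot16^2N(2K)^{3I}C_1^2}{\omega_0^2\,\bar\delta}\,\ln\Big(\frac{2\cdot16^2N(2K)^{3I}C_1^2}{\omega_0^2\,\bar\delta}\Big) \tag{180}$$
where $\bar\delta$ was introduced in (168). Note that tracking only $N$ and $K$ this bound becomes
$$S\ge CN\ln(N)K^{9I/2}. \tag{181}$$
Then, similar to before (using monotonicity of $\ln(S)/S$ for $S\ge e$), we find for $S\ge S_1$
$$C_1K^{(I-1)/2}\max(1,r_{\max})\sqrt{\frac{N\ln(S)}{S}}\le C_1K^{I/2}\max(1,r_{\max})\sqrt{\frac{N\ln(S_1)}{S_1}}\le\frac1{16}\frac{\omega_0}{(2K)^I}\,\bar\delta\,\max(1,r_{\max}(t)). \tag{182}$$

After all these preliminary estimates we can now finish the proof. Let us define $T\in\mathbb{R}^+_0\cup\{\infty\}$ to be the maximal time such that $v(i,n,t)\in\Omega$ for all $i\in[I]$, $n\in[N]$, and $0\le t\le T$ and the bound (168) holds for $\delta_{\max}(t)$, where we recall for convenience
$$\delta_{\max}(t)\le\bar\delta=\min\Big(\frac{\omega_0}{8C_1(2K)^I},\ \frac{\omega_0}{64IM^2D(2K)^{3I/2}2^I},\ \frac{1}{8(2K)^{I/2}IM},\ 1\Big). \tag{183}$$
We want to show that $T=\infty$. We assume that $T<\infty$ and derive a contradiction. By (172) we know that for $t\le T$ the bound $r_{\max}(t)\le1$ holds, and therefore Assumption 4 implies that $w(i,k,t=T)\in\Omega_0$, and since moreover $\delta_{\max}(t)\le1$ for $t\le T$ we find that $v(i,n,t=T)\in\Omega_0+B_1(0)\subset\Omega$. By continuity of $v(i,n,t)$ in $t$ we conclude that there is $\varepsilon>0$ such that $v(i,n,t)\in\Omega$ for all $0\le t\le T+\varepsilon$, $i\in[I]$, and $n\in[N]$.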
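Thus, we can apply Lemma 1 for $t=T$ and then find, using Assumption 3 and the bounds (173), (174), (178), and (182),
$$\begin{aligned}\frac{d}{dt}\frac12|\delta(i,n,t)|^2&\le\Big(-\lambda-\frac{\omega_0}{(2K)^I}+2^I\sup_{v_1,\ldots,v_I\in\Omega}\max_i\big\|D^2_{i,i}\hat f(v_1,\ldots,v_I)\big\|\,r_{\max}\Big)|\delta(i,n,t)|^2\\&\quad+C_1\max(1,r_{\max})K^{(I-1)/2}\sqrt{\frac{N\ln(S)}{S}}\,\delta_{\max}(t)+C_1K^{(I-1)/2}\frac{N\ln(S)}{\sqrt S}\,\delta_{\max}^2+C_1\delta_{\max}(t)^3\\&\le\frac{\omega_0}{(2K)^I}\Big(-|\delta(i,n,t)|^2+\frac18|\delta(i,n,t)|^2+\frac1{16}\max(1,r_{\max}(t))\,\bar\delta\,\delta_{\max}(t)+\frac18\delta_{\max}(t)^2+\frac18\delta_{\max}(t)^2\Big)\\&\le\frac{\omega_0}{(2K)^I}\Big(-|\delta(i,n,t)|^2+\frac38\delta_{\max}(t)^2+\frac1{16}\max(1,r_{\max}(t))\,\bar\delta\,\delta_{\max}(t)\Big).\end{aligned} \tag{184}$$
If $|\delta(i,n,t)|\ge\frac34\delta_{\max}$ we find $|\delta(i,n,t)|^2\ge\delta_{\max}(t)^2/2$ and we conclude that
$$\frac{d}{dt}\frac12|\delta(i,n,t)|^2\le\frac{\omega_0}{(2K)^I}\,\delta_{\max}(t)\Big(-\frac12\delta_{\max}(t)+\frac38\delta_{\max}(t)+\frac1{16}\max(1,r_{\max}(t))\,\bar\delta\Big)\le\frac18\frac{\omega_0}{(2K)^I}\,\delta_{\max}(t)\big(-\delta_{\max}(t)+\bar\delta/2\big). \tag{185}$$
Now, if $\delta_{\max}(T)\le\bar\delta/2$ we conclude by continuity that there is $\varepsilon>0$ such that $\delta_{\max}(t)<\bar\delta$ for $t\in[T,T+\varepsilon]$. On the other hand, if $\delta_{\max}(T)>\bar\delta/2$ we conclude that $\frac{d}{dt}\frac12|\delta(i,n,T)|^2<0$ for all $i,n$ such that $|\delta(i,n,T)|\ge\frac34\delta_{\max}$. This in particular implies that $\delta_{\max}(t)\le\delta_{\max}(T)$ for $t\in[T,T+\varepsilon]$ and some $\varepsilon>0$ ($\delta_{\max}$ is non-increasing at $T$). This is a contradiction, and we conclude that $T=\infty$.

Finally, we prove the decay of $\delta_{\max}$. Assume that for $t\ge T$ the bound $\min(1,r_{\max}(t))\le R$ holds. Note that the function $\delta_{\max}(t)$ is not necessarily differentiable, but its left and right derivatives exist (as the maximum of finitely many differentiable functions). Then we obtain, similar to (184), for the right derivative the bound
$$\frac{d}{dt^+}\frac12\delta_{\max}(t)^2\le-\frac18\frac{\omega_0}{(2K)^I}\,\delta_{\max}(t)^2+C_1RK^{(I-1)/2}\sqrt{\frac{N\ln(S)}{S}}\,\delta_{\max}(t). \tag{186}$$
For
$$\delta_{\max}(t)>\frac{16C_1(2K)^IK^{(I-1)/2}R}{\omega_0}\sqrt{\frac{N\ln(S)}{S}} \tag{187}$$
the previous bound simplifies to
$$\frac{d}{dt^+}\frac12\delta_{\max}(t)^2\le-\frac1{16}\frac{\omega_0}{(2K)^I}\,\delta_{\max}(t)^2 \tag{188}$$
which implies that for $t\ge T$ by Gronwall's Lemma
$$\delta_{\max}(t)^2\le\exp\Big(-\frac1{16}\frac{\omega_0}{(2K)^I}(t-T)\Big)\,\delta_{\max}(T)^2. \tag{189}$$
Thus in finite time we achieve
$$\delta_{\max}(t)\le\frac{2\cdot16\,C_1(2K)^IK^{(I-1)/2}R}{\omega_0}\sqrt{\frac{N\ln(S)}{S}}. \tag{190}$$
This ends the proof. Note that we also get an exponential rate of convergence.

Remark 3. The exponent $9/2$ for the lower bound of $S$ in (181) is not tight because we could use that $r_{\max}$ is small, but this provides only a small improvement that does not justify the additional technicalities.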
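The last step of the proof can be checked on a scalar comparison problem. The sketch below is an illustration under stated assumptions, not part of the paper's argument: it takes the equality case of (186), which for $\delta_{\max}>0$ reads $\dot\delta=-(a/8)\delta+b$ with $a$ standing for $\omega_0/(2K)^I$ and $b$ for the noise coefficient in (186), and it uses hypothetical numerical values for $a$, $b$, and the initial deviation. It verifies that, as long as the solution stays above the level $16b/a$, it satisfies the exponential bound (189), and that twice this level is reached in finite time.

```python
import math


def delta_worst_case(t, delta0, a, b):
    """Solution of delta' = -(a/8)*delta + b with delta(0) = delta0 (equality case of (186))."""
    limit = 8.0 * b / a
    return limit + (delta0 - limit) * math.exp(-a * t / 8.0)


def hitting_time(delta0, a, b, level):
    """First time the worst-case solution reaches `level`; requires 8b/a < level < delta0."""
    limit = 8.0 * b / a
    return (8.0 / a) * math.log((delta0 - limit) / (level - limit))


# Hypothetical values: a ~ omega_0/(2K)^I, b ~ noise coefficient, delta0 ~ delta_max(T).
a, b, delta0 = 0.5, 0.001, 0.2
threshold = 16.0 * b / a                       # above this level the decay estimate applies
t_hit = hitting_time(delta0, a, b, 2.0 * threshold)

# Check the Gronwall-type bound delta(t)^2 <= exp(-(a/16) t) * delta0^2 on [0, t_hit].
times = [t_hit * k / 200 for k in range(201)]
gronwall_ok = all(
    delta_worst_case(t, delta0, a, b) ** 2 <= math.exp(-a * t / 16.0) * delta0 ** 2 + 1e-15
    for t in times
)
print(t_hit, gronwall_ok)
```

In this toy setting the solution converges to $8b/a$, so the level $2\cdot16b/a$ is crossed after a time that grows only logarithmically in the ratio of the initial deviation to the noise level.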
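Before stating and proving Theorem 8, let us first prove that the weight decay $\lambda>0$ allows us to derive a-priori bounds on the time evolution, as stated in the following lemma.

Lemma 2. Let $\Pi$ be approximately balanced, i.e., assume that Assumption 1 holds. Let $M_1>0$ be such that
$$\sup_{v_1,\ldots,v_I\in B_1(0)\subset\mathbb{R}^D}|\hat f(v_1,\ldots,v_I)|\le\sqrt{\frac{M_1}{2}},\qquad\max_{\mathbf{k}}|g(\mathbf{k})|\le\sqrt{\frac{M_1}{2}}. \tag{191}$$
Let $0<\lambda\le1$ be a fixed number and assume that $v(i,n,t)$ follow the gradient dynamics (8) and are initialized such that
$$|v(i,n,t=0)|\le1 \tag{192}$$
and assume that the dynamics exists for all times. Then for all $i\in[I]$, $k\in[K]$, and all times
$$|w(i,k,t)|\le R=\sqrt{\frac{4KM_1}{\lambda}+2IK}. \tag{193}$$

Proof. Note that by assumption
$$\frac NS\hat L(\bar v(t=0))=\frac NS\sum_{\mathbf{n}\in D}\frac12\big(\hat f(\mathbf{n})-f(\mathbf{n})\big)^2\le\frac N2\Big(2\sqrt{\frac{M_1}{2}}\Big)^2=NM_1. \tag{194}$$
Since moreover $|v(i,n,0)|\le1$, the regularized loss at initialization is at most $NM_1+\lambda IN/2$, and we conclude that
$$NM_1+\frac{\lambda IN}{2}\ge\hat R_\lambda(\bar v(0))\ge\hat R_\lambda(\bar v(t))\ge\frac\lambda2\max_{i\in[I]}\max_{k\in[K]}\sum_{n\in\Pi_i^{-1}(k)}|v(i,n,t)|^2\ge\frac\lambda2\min_{i\in[I],k\in[K]}|\Pi_i^{-1}(k)|\max_{i\in[I],k\in[K]}|w(i,k,t)|^2\ge\frac{N\lambda}{4K}\max_{i\in[I],k\in[K]}|w(i,k,t)|^2. \tag{195}$$
Thus we conclude that
$$\max_{i\in[I],k\in[K]}|w(i,k,t)|^2\le\frac{4KM_1}{\lambda}+2IK. \tag{196}$$
This ends the proof.

Now we state a refined version of Theorem 5 which does not require an a-priori bound but instead proves boundedness of the embeddings $v(i,n,t)$ for all times using Lemma 2.

Theorem 8. Let $\Pi$ be approximately balanced, i.e., assume that Assumption 1 holds. Assume that the function $\hat f:\mathbb{R}^{ID}\to\mathbb{R}$ is slot-wise linear. Assume that there is $M_1>0$ such that
$$\sup_{v_1,\ldots,v_I\in B_1(0)\subset\mathbb{R}^D}|\hat f(v_1,\ldots,v_I)|\le\sqrt{\frac{M_1}{2}},\qquad\max_{\mathbf{k}}|g(\mathbf{k})|\le\sqrt{\frac{M_1}{2}}. \tag{197}$$
Let $0<\lambda\le1$ be a fixed number and assume that for
$$R=\sqrt{\frac{4KM_1}{\lambda}+2IK}+3 \tag{198}$$
Assumption 2 holds with $\Omega=B_R(0)$ and some $M>0$. Note that $M$ depends on $R$ and thus on $K$ and $\lambda$. Let $C_1$ be the constant from Lemma 1. Assume that $v(i,n,t)$ follow the gradient dynamics (8) and are initialized such that
$$|v(i,n,t=0)|\le\frac{\lambda}{8C_1}. \tag{199}$$
Then there are constants $C_2,C_3\ge0$ depending on $M$, $I$, $D$, $\lambda$, and $C_1$ such that for
$$S\ge\max\Big(C_2\frac{N^2K^{I-1}}{\lambda^2}\ln^2(N/\lambda),\ C_3\frac{NK^{I-1}}{\lambda^4}\ln(N/\lambda)\Big) \tag{200}$$
the bound
$$\delta_{\max}(t)\le\max\Big(\delta_{\max}(0)e^{-\lambda t/8},\ \frac{4C_1K^{(I-1)/2}}{\lambda}\sqrt{\frac{N\ln(S)}{S}}\Big)\le\max\Big(\delta_{\max}(0)e^{-\lambda t/8},\ C_{I,D,\lambda,K}\sqrt{\frac{N\ln(S)}{S}}\Big) \tag{201}$$
holds for all $t\ge0$.

Remark 4.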
We emphasize again that since $M$ might depend on $\lambda$ we obtain no explicit rate in $\lambda$. However, the $S$ and $N$ dependence is the same as in Theorem 5.

Proof. The general strategy of the proof is similar to the proof of Theorem 4, but the proof is slightly simpler, as it is easier to obtain a-priori estimates on the evolution of $v(i,n,t)$. We assume that the conclusion of Theorem 9 holds, which occurs with probability at least $1-S^{-1}$ over the randomness of $D$. Now let $T\ge0$ be the largest time such that for all $0\le t\le T$ and all $i\in[I]$ and $n\in[N]$ we have $v(i,n,t)\in\Omega$ and the bound
$$\delta_{\max}(t)\le\frac{\lambda}{4C_1}\le1 \tag{202}$$
holds (we assume w.l.o.g. that $C_1\ge1$). By assumption, those relations hold at $t=0$ (note that $\delta_{\max}(t)\le2\max_{i,n}|v(i,n)|$). Using the a-priori estimate (193) from Lemma 2 we find for $t=T$
$$|v(i,n,T)|\le|w(i,\Pi_i(n),T)|+\delta_{\max}(T)\le\sqrt{\frac{4KM_1}{\lambda}+2IK}+2=R-1. \tag{203}$$
Thus, by continuity, there is $\varepsilon>0$ such that $v(i,n,t)\in B_R(0)$ for all $t\in[0,T+\varepsilon]$. Now we can apply Lemma 1 on the interval $[T,T+\varepsilon]$. We note that by slot-wise linearity we have $D_{i,i}\hat f=0$, and moreover $\omega\ge0$. Therefore, we get
$$\frac{d}{dt}\frac12|\delta(i,n)|^2\le-\lambda|\delta(i,n)|^2+C_1K^{(I-1)/2}\sqrt{\frac{N\ln(S)}{S}}\,\delta_{\max}+C_1K^{(I-1)/2}\frac{N\ln(S)}{\sqrt S}\,\delta_{\max}^2+C_1\delta_{\max}^3. \tag{204}$$
Now (with the same reasoning as in the proof of Theorem 4) for
$$S\ge S_0=\frac{(12)^2C_1^2N^2K^{I-1}}{\lambda^2}\,\ln^2\Big(\frac{(12)^2C_1^2N^2K^{I-1}}{\lambda^2}\Big) \tag{205}$$
the bound
$$C_1K^{(I-1)/2}\frac{N\ln(S)}{\sqrt S}\le\frac\lambda4 \tag{206}$$
holds. Clearly, we get for a suitable constant $C_2$ depending on $D$, $I$, and $M$
$$S_0\le C_2\frac{N^2K^{I-1}}{\lambda^2}\ln^2(N/\lambda). \tag{207}$$
Similarly, we get that for
$$S\ge S_1=\frac{2\cdot16^2C_1^4NK^{I-1}}{\lambda^4}\,\ln\Big(\frac{2\cdot16^2C_1^4NK^{I-1}}{\lambda^4}\Big) \tag{208}$$
the bound
$$C_1K^{(I-1)/2}\sqrt{\frac{N\ln(S)}{S}}\le\frac{\lambda^2}{16C_1} \tag{209}$$
holds. And we find
$$S_1\le C_3\frac{NK^{I-1}}{\lambda^4}\ln(N/\lambda). \tag{210}$$
Now we can continue to bound (204) for $S$ satisfying (205) and (208), using (202), as follows
$$\frac{d}{dt}\frac12|\delta(i,n,T)|^2\le-\lambda|\delta(i,n)|^2+\frac{\lambda^2}{16C_1}\delta_{\max}+\frac\lambda4\delta_{\max}^2+\frac\lambda4\delta_{\max}^2. \tag{211}$$
Now, if $\delta_{\max}(T)<\lambda/(8C_1)$ then there is $\varepsilon>0$ such that $\delta_{\max}(t)\le\lambda/(4C_1)$ holds for $t\in[0,T+\varepsilon]$. Thus, we assume that $\delta_{\max}(T)\ge\lambda/(8C_1)$. Let $i\in[I]$, $n\in[N]$ be any index such that $|\delta(i,n,T)|=\delta_{\max}(T)$.
Then we conclude that
$$\frac{d}{dt}\frac12|\delta(i,n,T)|^2\le-\frac\lambda2\delta_{\max}^2+\frac{\lambda^2}{16C_1}\delta_{\max}\le-\frac\lambda2\frac{\lambda}{8C_1}\delta_{\max}+\frac{\lambda^2}{16C_1}\delta_{\max}=0. \tag{212}$$
This implies that
$$\frac{d}{dt^+}\frac12\delta_{\max}(T)^2\le0 \tag{213}$$
from which we conclude that there is $\varepsilon>0$ such that $\delta_{\max}(t)\le\lambda/(4C_1)$ holds for $0\le t\le T+\varepsilon$. In either case, we get a contradiction and thus $T=\infty$. In particular, we can apply Lemma 1 for all times $t\ge0$. Suppose now that for some $t$
$$\delta_{\max}(t)\ge\frac{4C_1K^{(I-1)/2}}{\lambda}\sqrt{\frac{N\ln(S)}{S}}; \tag{214}$$
then we conclude from Lemma 1 that
$$\frac{d}{dt^+}\frac12\delta_{\max}^2(t)\le-\frac\lambda2\delta_{\max}(t)^2+C_1K^{(I-1)/2}\sqrt{\frac{N\ln(S)}{S}}\,\delta_{\max}(t)\le-\frac\lambda2\delta_{\max}^2(t)+\frac\lambda4\delta_{\max}^2(t)=-\frac\lambda4\delta_{\max}^2(t). \tag{215}$$
Thus we conclude from Gronwall's inequality that
$$\delta_{\max}(t)^2\le\max\Big(\delta_{\max}(0)^2e^{-\lambda t/4},\ \Big(\frac{4C_1K^{(I-1)/2}}{\lambda}\Big)^2\frac{N\ln(S)}{S}\Big). \tag{216}$$
This ends the proof.

Proof of Theorem 5. The proof is the same as the proof of Theorem 8 above, with the only exception that we do not need to show that $v(i,n,t)\in\Omega$, since this holds by assumption, which makes the proof strictly simpler.

G Concentration of Datapoint Statistics Matrices

In this section, we prove high-probability concentration bounds for certain matrices and vectors capturing the statistics of the dataset. They are used in Appendix E to extract the asymptotically dominating contribution of the loss gradient and therefore of the gradients of the embeddings under gradient flow. All concentration bounds derived in this section are a simple consequence of the general and standard matrix concentration bound stated in Lemma 7. We start with the matrices $A^{i,j}\in\mathbb{R}^{N\times N}$ with entries
$$A^{i,j}_{n_1,n_2}=|\{\mathbf{n}\in D:n_i=n_1,\ n_j=n_2\}|, \tag{217}$$
i.e., the matrix counting the appearances of a pair of tokens in slots $i$ and $j$ (this is closely related to the matrices $A^{\mathbf{k},i,j}$ which we will consider below). Then we have the following result.

Lemma 3. Let $A^{i,j}\in\mathbb{R}^{N\times N}$ be as defined above. Assume that $S\ge12N$. Then the following bound holds for $\eta>0$
$$P\Big(\big\|A^{i,j}-E(A^{i,j})\big\|\ge(1+\eta)\ln(S)\sqrt{\frac SN}\Big)\le S^{-\eta+3/2}. \tag{218}$$
This implies
$$P\Big(\big\|A^{i,j}\big\|\ge\frac SN+(1+\eta)\ln(S)\sqrt{\frac SN}\Big)\le S^{-\eta+3/2}. \tag{219}$$

Proof. As in the proof of Theorem 2, we use Poissonization, i.e., we consider a dataset $\tilde D$ generated by first sampling $\tilde S\sim\mathrm{Poi}(S)$ and then $\tilde D\sim U([N]^I)^{\tilde S}$.
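A uniform sample $\mathbf{n}\sim U([N]^I)$ satisfies $n_i=n_1$ and $n_j=n_2$ with probability
$$p=\frac{1}{N^2}. \tag{220}$$
Therefore the distribution of the Poissonized matrix $\tilde A^{i,j}$ is
$$\tilde A^{i,j}_{n_1,n_2}\sim\mathrm{Poi}\Big(\frac{S}{N^2}\Big) \tag{221}$$
and the entries are independent. We can now apply Lemma 7, where we use $\eta\ln(S)$ as $\eta$ in the statement of the lemma, and find using $S\ge N$ that
$$P_{\tilde D}\Big(\big\|\tilde A^{i,j}-E(A^{i,j})\big\|\ge(1+\eta)\ln(S)\sqrt{\frac SN}\Big)\le4NS^{-\eta}. \tag{222}$$
Let us define
$$r(S,N,\eta)=(1+\eta)\ln(S)\sqrt{\frac SN}. \tag{223}$$
Note that $P(\tilde S=S)=S^Se^{-S}/S!\ge1/(3\sqrt S)$. This implies
$$\begin{aligned}P_D\big(\|A^{i,j}-E(A^{i,j})\|\ge r(S,N,\eta)\big)&=P_{\tilde D}\big(\|\tilde A^{i,j}-E(A^{i,j})\|\ge r(S,N,\eta)\,\big|\,\tilde S=S\big)\\&=\frac{P_{\tilde D}\big(\|\tilde A^{i,j}-E(A^{i,j})\|\ge r(S,N,\eta)\text{ and }\tilde S=S\big)}{P_{\tilde D}(\tilde S=S)}\le4NS^{-\eta}\cdot3\sqrt S\le S^{-\eta+3/2}.\end{aligned} \tag{224}$$
Here we used $S\ge12N$ in the last step. The relation (219) follows because $\|EA^{i,j}\|=NS/N^2=S/N$.

Next we prove a similar but slightly more involved result for the matrices that capture similar statistics as $A^{i,j}$ but in addition only consider $\mathbf{n}$ such that $\Pi(\mathbf{n})=\mathbf{k}$. We thus consider any $\mathbf{k}\in[K]^I$ and any indices $i,j\in[I]$, which we consider fixed for now. Consider the sets $N(i,\mathbf{k})=\Pi_i^{-1}(k_i)$, $N(j,\mathbf{k})=\Pi_j^{-1}(k_j)\subset[N]$. For a given dataset $D=\{\mathbf{n}^1,\ldots,\mathbf{n}^S\}$ we define the matrices $A^{\mathbf{k},i,j}\in\mathbb{R}^{N(i,\mathbf{k})\times N(j,\mathbf{k})}$ by
$$A^{\mathbf{k},i,j}_{n_1,n_2}=|\{\mathbf{n}\in D:\Pi(\mathbf{n})=\mathbf{k},\ n_i=n_1,\ n_j=n_2\}| \tag{225}$$
for $n_1\in N(i,\mathbf{k})$, $n_2\in N(j,\mathbf{k})$. Thus, the entries are the numbers of datapoints that are mapped to $\mathbf{k}$ by $\Pi$ and whose entries $i$ and $j$ are equal to $n_1$ and $n_2$. We emphasize that this essentially agrees with the definition in (115) given in Appendix D (except that the dimensions of the matrices do not agree; we address this below the lemma). Similar reasoning as in the previous result, Lemma 3, gives the following statement.

Lemma 4. Let $A^{\mathbf{k},i,j}\in\mathbb{R}^{N(i,\mathbf{k})\times N(j,\mathbf{k})}$ be as defined above. Assume that for all $i\in[I]$ and $k\in[K]$ the bound $N/(2K)\le|\Pi_i^{-1}(k)|\le2N/K$ holds and let $S\ge24N$. Then the following bound holds for $\eta>0$
$$P\Big(\big\|A^{\mathbf{k},i,j}-E(A^{\mathbf{k},i,j})\big\|\ge(1+\eta)\ln(S)\max\Big(\Big(\frac2K\Big)^{(I-1)/2}\sqrt{\frac SN},\ 1\Big)\Big)\le3S^{-\eta+3/2}. \tag{226}$$
Note that in particular for $K\ge2$ we find for $S\ge24N$ the bound
$$P\Big(\big\|A^{\mathbf{k},i,j}-E(A^{\mathbf{k},i,j})\big\|\ge(1+\eta)\ln(S)\sqrt{\frac SN}\Big)\le S^{-\eta+3/2}. \tag{227}$$

Proof. The proof is essentially the same as the proof of Lemma 3 with some additional notational complications.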
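Again, we consider a dataset $\tilde D$ generated by first sampling $\tilde S\sim\mathrm{Poi}(S)$ and then $\tilde D\sim U([N]^I)^{\tilde S}$. Note that a uniform sample $\mathbf{n}\sim U([N]^I)$ satisfies $\Pi(\mathbf{n})=\mathbf{k}$, $n_i=n_1$, and $n_j=n_2$ for $n_1\in N(i,\mathbf{k})$ and $n_2\in N(j,\mathbf{k})$ with probability
$$p(\mathbf{k},i,j)=\frac{1}{N^I}\frac{|\Pi^{-1}(\mathbf{k})|}{|\Pi_i^{-1}(k_i)|\cdot|\Pi_j^{-1}(k_j)|} \tag{228}$$
and thus the distribution of the Poissonized matrix $\tilde A^{\mathbf{k},i,j}$ is
$$\tilde A^{\mathbf{k},i,j}_{n_1,n_2}\sim\mathrm{Poi}\Big(\frac{S}{N^I}\frac{|\Pi^{-1}(\mathbf{k})|}{|\Pi_i^{-1}(k_i)|\cdot|\Pi_j^{-1}(k_j)|}\Big)=:\mathrm{Poi}(\lambda(\mathbf{k},i,j)) \tag{229}$$
for all $n_1\in N(i,\mathbf{k})$ and $n_2\in N(j,\mathbf{k})$, and the entries are independent. Note that the assumption $N/(2K)\le|\Pi_i^{-1}(k)|\le2N/K$ implies that
$$\frac{|\Pi^{-1}(\mathbf{k})|}{|\Pi_i^{-1}(k_i)|\cdot|\Pi_j^{-1}(k_j)|}=\prod_{l\in[I]\setminus\{i,j\}}|\Pi_l^{-1}(k_l)|\in\Big[\frac{N^{I-2}}{2^{I-2}K^{I-2}},\ \frac{2^{I-2}N^{I-2}}{K^{I-2}}\Big] \tag{230}$$
which implies
$$\frac{1}{(2K)^{I-2}}\frac{S}{N^2}\le\lambda(\mathbf{k},i,j)\le\Big(\frac2K\Big)^{I-2}\frac{S}{N^2}. \tag{231}$$
Now we apply Lemma 7, where we use $\eta\ln(S)$ as $\eta$ in the statement of the lemma, and find that
$$P_{\tilde D}\Big(\big\|\tilde A^{\mathbf{k},i,j}-E(A^{\mathbf{k},i,j})\big\|\ge(1+\eta)\ln(S)\max\Big(\Big(\frac2K\Big)^{(I-1)/2}\sqrt{\frac SN},\ 1\Big)\Big)\le\frac{8N}{K}S^{-\eta}. \tag{232}$$
Let us define
$$r(S,N,K,I,\eta)=(1+\eta)\ln(S)\max\Big(\Big(\frac2K\Big)^{(I-1)/2}\sqrt{\frac SN},\ 1\Big). \tag{233}$$
Again, $P(\tilde S=S)=S^Se^{-S}/S!\ge1/(3\sqrt S)$. This implies
$$\begin{aligned}P_D\big(\|A^{\mathbf{k},i,j}-E(A^{\mathbf{k},i,j})\|\ge r(S,N,K,I,\eta)\big)&=P_{\tilde D}\big(\|\tilde A^{\mathbf{k},i,j}-E(A^{\mathbf{k},i,j})\|\ge r(S,N,K,I,\eta)\,\big|\,\tilde S=S\big)\\&=\frac{P_{\tilde D}\big(\|\tilde A^{\mathbf{k},i,j}-E(A^{\mathbf{k},i,j})\|\ge r(S,N,K,I,\eta)\text{ and }\tilde S=S\big)}{P_{\tilde D}(\tilde S=S)}\le\frac{8N}{K}S^{-\eta}\cdot3\sqrt S\le3S^{-\eta+3/2}.\end{aligned} \tag{234}$$
Here we used $S\ge24N$ in the last step.

We also consider the matrix $\bar A^{\mathbf{k},i,j}\in\mathbb{R}^{IN\times IN}$ corresponding to the definition in (115), which we equivalently obtain by embedding $A^{\mathbf{k},i,j}$ suitably, i.e., we have
$$\bar A^{\mathbf{k},i,j}_{(i_1,n_1),(j_1,n_2)}=\begin{cases}A^{\mathbf{k},i,j}_{n_1,n_2}&\text{if }i_1=i,\ j_1=j\text{ and }n_1\in N(i,\mathbf{k}),\ n_2\in N(j,\mathbf{k}),\\0&\text{otherwise.}\end{cases} \tag{235}$$
Since $A^{\mathbf{k},i,j}$ is a submatrix of $\bar A^{\mathbf{k},i,j}$ and all remaining entries vanish, we find
$$\big\|\bar A^{\mathbf{k},i,j}-E\bar A^{\mathbf{k},i,j}\big\|=\big\|A^{\mathbf{k},i,j}-EA^{\mathbf{k},i,j}\big\| \tag{236}$$
so that Lemma 4 applies to $\bar A^{\mathbf{k},i,j}$.

We also need a simpler concentration statement for the frequency of datapoints $\mathbf{n}^s$ such that $n^s_i=n$. We define the vectors $B^i\in\mathbb{R}^N$ by
$$B^i_n=|\{\mathbf{n}\in D:n_i=n\}|. \tag{237}$$
We need a simple upper bound on the vectors $B^i$.

Lemma 5. The following bound holds for $\eta>0$
$$P\Big(B^i_n\ge\frac SN+\max\big(2\sqrt{\eta\ln(S)SN^{-1}},\ 4/3\,\eta\ln(S)\big)\Big)\le S^{-\eta}. \tag{238}$$

Proof. Note that $B^i_n\sim\mathrm{Bin}(S,N^{-1})$. The variance of a $\mathrm{Ber}(p)$ variable is $p(1-p)\le p$. Bernstein's one-sided inequality then reads
$$P\big(\mathrm{Bin}(S,N^{-1})\ge SN^{-1}+t\big)\le\exp\Big(-\frac{t^2}{2SN^{-1}+2t/3}\Big)\le\exp\Big(-\min\Big(\frac{t^2}{4SN^{-1}},\ 3t/4\Big)\Big). \tag{239}$$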
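Now we apply this bound with $t=\max\big(2\sqrt{\eta\ln(S)SN^{-1}},\ 4/3\,\eta\ln(S)\big)$, which implies
$$P\Big(\mathrm{Bin}(S,N^{-1})\ge\frac SN+\max\big(2\sqrt{\eta\ln(S)SN^{-1}},\ 4/3\,\eta\ln(S)\big)\Big)\le S^{-\eta}. \tag{240}$$

As before for the $A$ matrices, we also need an extension of the previous result to a setting where we in addition require $\Pi(\mathbf{n})=\mathbf{k}$ for a given $\mathbf{k}$ and $n\in[N]$. Recall that $N(i,\mathbf{k})=\Pi_i^{-1}(k_i)$. Then we consider the vector $B^{\mathbf{k},i}\in\mathbb{R}^{N(i,\mathbf{k})}$ given by
$$B^{\mathbf{k},i}_n=|\{\mathbf{n}\in D:\Pi(\mathbf{n})=\mathbf{k},\ n_i=n\}|=|D_{\mathbf{k},n,i}|. \tag{241}$$
Again, this agrees with the definition of $B$ given in (114). The following lemma holds.

Lemma 6. Let $B^{\mathbf{k},i}\in\mathbb{R}^{N(i,\mathbf{k})}$ be as defined above. Assume that for all $i\in[I]$ and $k\in[K]$ the bound $|\Pi_i^{-1}(k)|\le2N/K$ holds. Then the following bound holds for $\eta>0$ and $n\in N(i,\mathbf{k})$
$$P\Big(\big|B^{\mathbf{k},i}_n-E(B^{\mathbf{k},i}_n)\big|\ge\max\Big(2\Big(\frac2K\Big)^{(I-1)/2}\sqrt{\frac{\eta\ln(S)S}{N}},\ 4/3\,\eta\ln(S)\Big)\Big)\le2S^{-\eta}. \tag{242}$$
Note that viewing $I$ and $K$ as constant we get for $S\ge N\ln(N)$ the bound
$$P\Big(\big|B^{\mathbf{k},i}_n-E(B^{\mathbf{k},i}_n)\big|\ge C_{I,K}\max(\eta,\sqrt\eta)\sqrt{\frac{\ln(S)S}{N}}\Big)\le2S^{-\eta}. \tag{243}$$

Proof. The proof is along the lines of Lemma 5 but slightly more technical. Note that the entries of $B^{\mathbf{k},i}$ are distributed according to $\mathrm{Bin}(S,p(\mathbf{k},i))$ where
$$p(\mathbf{k},i)=\frac{|\Pi^{-1}(\mathbf{k})|}{N^I\,|\Pi_i^{-1}(k_i)|}. \tag{244}$$
Note that by assumption
$$p(\mathbf{k},i)\le\Big(\frac{2N}{K}\Big)^{I-1}\frac{1}{N^I}=\frac1N\Big(\frac2K\Big)^{I-1}. \tag{245}$$
Since the variance of a $\mathrm{Ber}(p)$ variable is $p(1-p)\le p$, Bernstein's inequality then implies
$$P\big(|\mathrm{Bin}(S,p(\mathbf{k},i))-S\cdot p(\mathbf{k},i)|\ge t\big)\le2\exp\Big(-\frac{t^2}{2Sp(\mathbf{k},i)+2t/3}\Big)\le2\exp\Big(-\min\Big(\frac{t^2}{4Sp(\mathbf{k},i)},\ 3t/4\Big)\Big). \tag{246}$$
Now we apply this bound with $t=\max\big(2\sqrt{\eta\ln(S)Sp(\mathbf{k},i)},\ 4/3\,\eta\ln(S)\big)$, which implies
$$P\Big(|\mathrm{Bin}(S,p(\mathbf{k},i))-S\cdot p(\mathbf{k},i)|\ge\max\big(2\sqrt{\eta\ln(S)Sp(\mathbf{k},i)},\ 4/3\,\eta\ln(S)\big)\Big)\le2S^{-\eta}. \tag{247}$$
Applying (245) we find
$$P\Big(|\mathrm{Bin}(S,p(\mathbf{k},i))-S\cdot p(\mathbf{k},i)|\ge\max\Big(2\Big(\frac2K\Big)^{(I-1)/2}\sqrt{\frac{\eta\ln(S)S}{N}},\ 4/3\,\eta\ln(S)\Big)\Big)\le2S^{-\eta}. \tag{248}$$
To conclude that (243) holds, we just need to show that for $S\ge N\ln(N)$ the expression $S/(N\ln(S))$ is bounded from below, which follows by monotonicity of $S/\ln(S)$ in $S$.

Recall that we introduced the vector indexed by all of $[N]$ in (114); denoting it by $\bar B^{\mathbf{k},i}\in\mathbb{R}^{[N]}$ we have
$$\bar B^{\mathbf{k},i}_n=\begin{cases}B^{\mathbf{k},i}_n&\text{if }n\in N(i,\mathbf{k}),\\0&\text{otherwise.}\end{cases} \tag{249}$$
For our analysis we want to summarize all those concentration bounds in one result. To simplify the statement, we only consider our main regime of interest.

Theorem 9. Assume that the bound $N/(2K)\le|\Pi_i^{-1}(k)|\le2N/K$ holds for all $i\in[I]$ and $k\in[K]$.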
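Assume that
$$S\ge N\max\big(24,\ 3(K/2)^{I-1}\ln(N),\ 4IK^{I/2}\big). \tag{250}$$
Then with probability at least $1-S^{-1}$ the following bounds hold simultaneously for all $\mathbf{k}\in[K]^I$, $i,j\in[I]$, and $n$:
$$\big\|A^{i,j}-E(A^{i,j})\big\|\le5\ln(S)\sqrt{\frac SN}, \tag{251}$$
$$\big\|A^{\mathbf{k},i,j}-E(A^{\mathbf{k},i,j})\big\|\le6\ln(S)\Big(\frac2K\Big)^{(I-1)/2}\sqrt{\frac SN}, \tag{252}$$
$$B^i_n\le\frac SN+3\sqrt{\frac{S\ln(S)}{N}}, \tag{253}$$
$$\big|B^{\mathbf{k},i}_n-E(B^{\mathbf{k},i}_n)\big|\le4\Big(\frac2K\Big)^{(I-1)/2}\sqrt{\frac{S\ln(S)}{N}}. \tag{254}$$

Proof. Applying Lemma 3 with $\eta=4$ we obtain that the first bound holds with probability at least $1-S^{-2}$ (here we used $S\ge12N$). Similarly, we obtain for $S\ge(K/2)^{I-1}N$
$$\max\Big(\Big(\frac2K\Big)^{(I-1)/2}\sqrt{\frac SN},\ 1\Big)=\Big(\frac2K\Big)^{(I-1)/2}\sqrt{\frac SN}. \tag{255}$$
Setting $\eta=5$ in Lemma 4 we get that the second bound holds (by the union bound) with probability at least $1-3I^2K^IS^{-3}$ for all $i,j\in[I]$ and $\mathbf{k}\in[K]^I$. For $S\ge\alpha N\ln(N)$ and $N\ge\alpha$ we find
$$\frac{S}{\ln(S)}\ge\frac{\alpha N\ln(N)}{\ln(\alpha\ln(N)N)}\ge\frac{\alpha N\ln(N)}{3\ln(N)}=\frac{\alpha N}{3} \tag{256}$$
which implies for $\eta=2$ and $S\ge3N\ln(N)$
$$\max\big(2\sqrt{\eta\ln(S)SN^{-1}},\ 4/3\,\eta\ln(S)\big)\le3\sqrt{\frac{S\ln(S)}{N}} \tag{257}$$
and thus applying Lemma 5 with $\eta=2$ implies that the third bound holds with probability at least $1-S^{-2}$. Finally, we find for $S\ge3(K/2)^{I-1}N\ln(N)$ and $\eta=3$ the bound
$$\max\Big(2\Big(\frac2K\Big)^{(I-1)/2}\sqrt{\frac{\eta\ln(S)S}{N}},\ 4/3\,\eta\ln(S)\Big)\le4\Big(\frac2K\Big)^{(I-1)/2}\sqrt{\frac{S\ln(S)}{N}}. \tag{258}$$
Lemma 6 implies that the last bound holds with probability at least $1-2IK^IS^{-3}$. All bounds then hold simultaneously with probability at least
$$1-S^{-2}-3I^2K^IS^{-3}-S^{-2}-2IK^IS^{-3}\ge1-2S^{-2}-2\cdot\frac14S^{-1}\ge1-S^{-1}. \tag{259}$$

H Spectral bounds for Random Matrices

In the derivation of the concentration bounds in the previous section, we needed concentration bounds for random matrices whose entries follow a Poisson distribution. In this section we provide the required result. Lemma 7 below should be folklore, but we did not find an exact reference, so we provide a proof based on standard concentration results for random matrices. Let us first state the general result.

Theorem 10 (Corollary 3.7 in [26]). Let $X_k$ be a sequence of independent random symmetric matrices with dimension $d$. We assume that there is a function $g:(0,\infty)\to[0,\infty]$ and symmetric matrices $A_k$ such that
$$Ee^{\theta X_k}\preceq e^{g(\theta)A_k}\quad\text{for }\theta>0. \tag{260}$$
Define
$$\rho:=\lambda_{\max}\Big(\sum_kA_k\Big). \tag{261}$$
Then for all $t\in\mathbb{R}$ the following bound holds
$$P\Big(\lambda_{\max}\Big(\sum_kX_k\Big)\ge t\Big)\le d\inf_{\theta>0}e^{-\theta t+g(\theta)\rho}. \tag{262}$$
We now apply the previous concentration bound to our specific setting of interest.

Lemma 7.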
Consider a random matrix $A\in\mathbb{R}^{d_1\times d_2}$ whose entries are independent random variables with $A_{ij}\sim\mathrm{Poi}(\lambda)$ for some $\lambda>0$. Then the following bound holds for all $\eta\ge0$
$$P\Big(\|A-E(A)\|\ge(\eta+1)\max\big(\sqrt{\max(d_1,d_2)\lambda},\ 1\big)\Big)\le2(d_1+d_2)e^{-\eta}. \tag{263}$$

Proof. To bound the norm of this non-symmetric matrix we use the usual approach and consider
$$Q=\begin{pmatrix}0_{d_1\times d_1}&A-E(A)\\(A-E(A))^\top&0_{d_2\times d_2}\end{pmatrix}. \tag{264}$$
Consider the index set
$$\mathcal{I}=\{(i,j):i\in[d_1],\ j\in\{d_1+1,\ldots,d_1+d_2\}\}; \tag{265}$$
then we can write
$$Q=\sum_{(i,j)\in\mathcal{I}}(N_{i,j}-\lambda)(E_{i,j}+E_{j,i}) \tag{266}$$
where $E_{i,j}$ is the matrix whose entry $(i,j)$ is 1 and all other entries vanish, and $N_{i,j}\sim\mathrm{Poi}(\lambda)$. Let $X_{ij}=(N_{i,j}-\lambda)(E_{i,j}+E_{j,i})$. Note that by induction one directly finds
$$(E_{i,j}+E_{j,i})^k=\begin{cases}E_{i,i}+E_{j,j}&\text{if }k\text{ is even,}\\E_{i,j}+E_{j,i}&\text{if }k\text{ is odd.}\end{cases} \tag{267}$$
In any case we get $(E_{i,j}+E_{j,i})^k\preceq E_{i,i}+E_{j,j}$ in the sense of symmetric matrices, and we set $A_{ij}=E_{i,i}+E_{j,j}$. Now we obtain for any $\theta\in\mathbb{R}$ the relation
$$\begin{aligned}Ee^{\theta X_{ij}}&=A_{ij}\sum_{k=0}^\infty\frac{\theta^{2k}E(\mathrm{Poi}(\lambda)-\lambda)^{2k}}{(2k)!}+(E_{i,j}+E_{j,i})\sum_{k=0}^\infty\frac{\theta^{2k+1}E(\mathrm{Poi}(\lambda)-\lambda)^{2k+1}}{(2k+1)!}\\&\preceq A_{ij}\Big(\sum_{k=0}^\infty\frac{\theta^{2k}E(\mathrm{Poi}(\lambda)-\lambda)^{2k}}{(2k)!}+\Big|\sum_{k=0}^\infty\frac{\theta^{2k+1}E(\mathrm{Poi}(\lambda)-\lambda)^{2k+1}}{(2k+1)!}\Big|\Big).\end{aligned} \tag{268}$$
Note that the first summand is invariant under $\theta\to-\theta$ while the term in absolute values changes its sign. This implies that
$$\begin{aligned}\sum_{k=0}^\infty\frac{\theta^{2k}E(\mathrm{Poi}(\lambda)-\lambda)^{2k}}{(2k)!}+\Big|\sum_{k=0}^\infty\frac{\theta^{2k+1}E(\mathrm{Poi}(\lambda)-\lambda)^{2k+1}}{(2k+1)!}\Big|&=\max\Big(\sum_{k=0}^\infty\frac{\theta^kE(\mathrm{Poi}(\lambda)-\lambda)^k}{k!},\ \sum_{k=0}^\infty\frac{(-\theta)^kE(\mathrm{Poi}(\lambda)-\lambda)^k}{k!}\Big)\\&=\max\big(Ee^{\theta(\mathrm{Poi}(\lambda)-\lambda)},\ Ee^{-\theta(\mathrm{Poi}(\lambda)-\lambda)}\big)\\&=\max\big(e^{\lambda(e^\theta-1)-\lambda\theta},\ e^{\lambda(e^{-\theta}-1)+\lambda\theta}\big)=e^{\lambda(e^{|\theta|}-1)-\lambda|\theta|}.\end{aligned} \tag{269}$$
Here we used the moment generating function of the Poisson distribution, and in the last step that $\lambda\ge0$ and that for $\theta\ge0$
$$e^\theta-e^{-\theta}=2\sinh(\theta)\ge2\theta. \tag{270}$$
Thus we infer from (268) and (269) that
$$Ee^{\theta X_{ij}}\preceq e^{\lambda(e^{|\theta|}-1)-\lambda|\theta|}A_{ij}\preceq e^{(\lambda(e^{|\theta|}-1)-\lambda|\theta|)A_{ij}}=:e^{g_\lambda(|\theta|)A_{ij}} \tag{271}$$
using that $A_{ij}^k=A_{ij}$ in the second step. Note that here
$$\sum_{(i,j)\in\mathcal{I}}A_{ij}=\sum_{(i,j)\in\mathcal{I}}(E_{i,i}+E_{j,j})=\sum_{i=1}^{d_1}d_2E_{i,i}+\sum_{i=d_1+1}^{d_1+d_2}d_1E_{i,i} \tag{272}$$
which implies
$$\rho=\lambda_{\max}\Big(\sum_{(i,j)\in\mathcal{I}}A_{ij}\Big)=\max(d_1,d_2). \tag{273}$$
Thus we can apply Theorem 10 and get
$$P\Big(\lambda_{\max}\Big(\sum_{(i,j)\in\mathcal{I}}X_{ij}\Big)\ge t\Big)\le(d_1+d_2)\inf_{\theta>0}e^{-\theta t+g_\lambda(\theta)\rho}=(d_1+d_2)\inf_{\theta>0}e^{-\theta t+\lambda\max(d_1,d_2)(e^\theta-1-\theta)}. \tag{274}$$
If $\max(d_1,d_2)\lambda\le1$ we apply this bound with $t=\eta+1$ and $\theta=1$ and find
$$P\Big(\lambda_{\max}\Big(\sum_{(i,j)\in\mathcal{I}}X_{ij}\Big)\ge\eta+1\Big)\le(d_1+d_2)\exp\big(-(1+\eta)+\lambda\max(d_1,d_2)(e^1-2)\big)\le(d_1+d_2)e^{-\eta}. \tag{275}$$
If $\max(d_1,d_2)\lambda\ge1$ we set $\theta=\big(\sqrt{\max(d_1,d_2)\lambda}\big)^{-1}\le1$ and $t=(\eta+1)\sqrt{\max(d_1,d_2)\lambda}$ and find, using that for $0\le\theta\le1$ the bound $e^\theta-1-\theta\le\theta^2$ holds,
$$\begin{aligned}P\Big(\lambda_{\max}\Big(\sum_{(i,j)\in\mathcal{I}}X_{ij}\Big)\ge(\eta+1)\sqrt{\max(d_1,d_2)\lambda}\Big)&\le(d_1+d_2)\exp\big(-\theta t+\lambda\max(d_1,d_2)\theta^2\big)\\&=(d_1+d_2)\exp\Big(-(\eta+1)\frac{\sqrt{\max(d_1,d_2)\lambda}}{\sqrt{\max(d_1,d_2)\lambda}}+\frac{\lambda\max(d_1,d_2)}{\big(\sqrt{\max(d_1,d_2)\lambda}\big)^2}\Big)=(d_1+d_2)e^{-\eta}.\end{aligned} \tag{276}$$
Thus we can combine both cases and obtain the relation
$$P\Big(\lambda_{\max}(Q)\ge(\eta+1)\max\big(\sqrt{\max(d_1,d_2)\lambda},\ 1\big)\Big)\le(d_1+d_2)e^{-\eta}. \tag{277}$$
We observe that by (271)
$$Ee^{\theta(-X_{ij})}=Ee^{-\theta X_{ij}}\preceq e^{g_\lambda(|\theta|)A_{ij}}. \tag{278}$$
Thus the same reasoning shows that
$$P\Big(\lambda_{\min}(Q)\le-(\eta+1)\max\big(\sqrt{\max(d_1,d_2)\lambda},\ 1\big)\Big)\le(d_1+d_2)e^{-\eta} \tag{279}$$
and since $\|A-EA\|=\|Q\|=\max_i|\lambda_i(Q)|$ the claim follows by applying the union bound over (277) and (279).

I Kronecker Products

Let us here state some basic properties of the Kronecker product of matrices. For two matrices $A\in\mathbb{R}^{n\times m}$ and $B\in\mathbb{R}^{p\times q}$ the Kronecker product is defined by
$$A\otimes B=\begin{pmatrix}a_{11}B&\cdots&a_{1m}B\\\vdots&\ddots&\vdots\\a_{n1}B&\cdots&a_{nm}B\end{pmatrix}\in\mathbb{R}^{np\times mq}. \tag{280}$$
We will need the property that the operator norm satisfies
$$\|A\otimes B\|\le\|A\|\cdot\|B\|. \tag{281}$$
Note that we use the notation $v\otimes w=vw^\top$ to denote the outer product. Formally we have $vw^\top=v\otimes w^\top$, but it is convenient and common to drop the transpose here.

NeurIPS Paper Checklist

Question: Do the main claims made in the abstract and introduction accurately reflect the paper's contributions and scope?
Answer: [Yes]
Justification: See main results and Section 3.

2.
Limitations
Question: Does the paper discuss the limitations of the work performed by the authors?
Answer: [Yes]
Justification: See discussions in Section 2, Section 4, and Section 5.
Guidelines:
The answer NA means that the paper has no limitation while the answer No means that the paper has limitations, but those are not discussed in the paper.
The authors are encouraged to create a separate "Limitations" section in their paper.
The paper should point out any strong assumptions and how robust the results are to violations of these assumptions (e.g., independence assumptions, noiseless settings, model well-specification, asymptotic approximations only holding locally). The authors should reflect on how these assumptions might be violated in practice and what the implications would be.
The authors should reflect on the scope of the claims made, e.g., if the approach was only tested on a few datasets or with a few runs. In general, empirical results often depend on implicit assumptions, which should be articulated.
The authors should reflect on the factors that influence the performance of the approach. For example, a facial recognition algorithm may perform poorly when image resolution is low or images are taken in low lighting. Or a speech-to-text system might not be used reliably to provide closed captions for online lectures because it fails to handle technical jargon.
The authors should discuss the computational efficiency of the proposed algorithms and how they scale with dataset size.
If applicable, the authors should discuss possible limitations of their approach to address problems of privacy and fairness.
While the authors might fear that complete honesty about limitations might be used by reviewers as grounds for rejection, a worse outcome might be that reviewers discover limitations that aren't acknowledged in the paper.
The authors should use their best judgment and recognize that individual actions in favor of transparency play an important role in developing norms that preserve the integrity of the community. Reviewers will be specifically instructed to not penalize honesty concerning limitations.

3. Theory Assumptions and Proofs
Question: For each theoretical result, does the paper provide the full set of assumptions and a complete (and correct) proof?
Answer: [Yes]
Justification: All assumptions are stated and complete proofs can be found in the supplementary material.
Guidelines:
The answer NA means that the paper does not include theoretical results.
All the theorems, formulas, and proofs in the paper should be numbered and cross-referenced.
All assumptions should be clearly stated or referenced in the statement of any theorems.
The proofs can either appear in the main paper or the supplemental material, but if they appear in the supplemental material, the authors are encouraged to provide a short proof sketch to provide intuition.
Inversely, any informal proof provided in the core of the paper should be complemented by formal proofs provided in appendix or supplemental material.
Theorems and Lemmas that the proof relies upon should be properly referenced.

4. Experimental Result Reproducibility
Question: Does the paper fully disclose all the information needed to reproduce the main experimental results of the paper to the extent that it affects the main claims and/or conclusions of the paper (regardless of whether the code and data are provided or not)?
Answer: [NA]
Justification:
Guidelines:
The answer NA means that the paper does not include experiments.
If the paper includes experiments, a No answer to this question will not be perceived well by the reviewers: Making the paper reproducible is important, regardless of whether the code and data are provided or not.
If the contribution is a dataset and/or model, the authors should describe the steps taken to make their results reproducible or verifiable.
Depending on the contribution, reproducibility can be accomplished in various ways. For example, if the contribution is a novel architecture, describing the architecture fully might suffice, or if the contribution is a specific model and empirical evaluation, it may be necessary to either make it possible for others to replicate the model with the same dataset, or provide access to the model. In general, releasing code and data is often one good way to accomplish this, but reproducibility can also be provided via detailed instructions for how to replicate the results, access to a hosted model (e.g., in the case of a large language model), releasing of a model checkpoint, or other means that are appropriate to the research performed.
While NeurIPS does not require releasing code, the conference does require all submissions to provide some reasonable avenue for reproducibility, which may depend on the nature of the contribution. For example:
(a) If the contribution is primarily a new algorithm, the paper should make it clear how to reproduce that algorithm.
(b) If the contribution is primarily a new model architecture, the paper should describe the architecture clearly and fully.
(c) If the contribution is a new model (e.g., a large language model), then there should either be a way to access this model for reproducing the results or a way to reproduce the model (e.g., with an open-source dataset or instructions for how to construct the dataset).
(d) We recognize that reproducibility may be tricky in some cases, in which case authors are welcome to describe the particular way they provide for reproducibility.
In the case of closed-source models, it may be that access to the model is limited in some way (e.g., to registered users), but it should be possible for other researchers to have some path to reproducing or verifying the results.

5. Open access to data and code
Question: Does the paper provide open access to the data and code, with sufficient instructions to faithfully reproduce the main experimental results, as described in supplemental material?
Answer: [NA]
Justification:
Guidelines:
The answer NA means that paper does not include experiments requiring code.
Please see the NeurIPS code and data submission guidelines (https://nips.cc/public/guides/CodeSubmissionPolicy) for more details.
While we encourage the release of code and data, we understand that this might not be possible, so No is an acceptable answer. Papers cannot be rejected simply for not including code, unless this is central to the contribution (e.g., for a new open-source benchmark).
The instructions should contain the exact command and environment needed to run to reproduce the results. See the NeurIPS code and data submission guidelines (https://nips.cc/public/guides/CodeSubmissionPolicy) for more details.
The authors should provide instructions on data access and preparation, including how to access the raw data, preprocessed data, intermediate data, and generated data, etc.
The authors should provide scripts to reproduce all experimental results for the new proposed method and baselines. If only a subset of experiments are reproducible, they should state which ones are omitted from the script and why.
At submission time, to preserve anonymity, the authors should release anonymized versions (if applicable).
Providing as much information as possible in supplemental material (appended to the paper) is recommended, but including URLs to data and code is permitted.

6.
Experimental Setting/Details
Question: Does the paper specify all the training and test details (e.g., data splits, hyperparameters, how they were chosen, type of optimizer, etc.) necessary to understand the results?
Answer: [NA]
Justification:
Guidelines:
The answer NA means that the paper does not include experiments.
The experimental setting should be presented in the core of the paper to a level of detail that is necessary to appreciate the results and make sense of them.
The full details can be provided either with the code, in appendix, or as supplemental material.

7. Experiment Statistical Significance
Question: Does the paper report error bars suitably and correctly defined or other appropriate information about the statistical significance of the experiments?
Answer: [NA]
Justification:
Guidelines:
The answer NA means that the paper does not include experiments.
The authors should answer "Yes" if the results are accompanied by error bars, confidence intervals, or statistical significance tests, at least for the experiments that support the main claims of the paper.
The factors of variability that the error bars are capturing should be clearly stated (for example, train/test split, initialization, random drawing of some parameter, or overall run with given experimental conditions).
The method for calculating the error bars should be explained (closed form formula, call to a library function, bootstrap, etc.).
The assumptions made should be given (e.g., Normally distributed errors).
It should be clear whether the error bar is the standard deviation or the standard error of the mean.
It is OK to report 1-sigma error bars, but one should state it. The authors should preferably report a 2-sigma error bar than state that they have a 96% CI, if the hypothesis of Normality of errors is not verified.
For asymmetric distributions, the authors should be careful not to show in tables or figures symmetric error bars that would yield results that are out of range (e.g.
negative error rates).
If error bars are reported in tables or plots, the authors should explain in the text how they were calculated and reference the corresponding figures or tables in the text.

8. Experiments Compute Resources
Question: For each experiment, does the paper provide sufficient information on the computer resources (type of compute workers, memory, time of execution) needed to reproduce the experiments?
Answer: [NA]
Justification:
Guidelines:
The answer NA means that the paper does not include experiments.
The paper should indicate the type of compute workers CPU or GPU, internal cluster, or cloud provider, including relevant memory and storage.
The paper should provide the amount of compute required for each of the individual experimental runs as well as estimate the total compute.
The paper should disclose whether the full research project required more compute than the experiments reported in the paper (e.g., preliminary or failed experiments that didn't make it into the paper).

9. Code Of Ethics
Question: Does the research conducted in the paper conform, in every respect, with the NeurIPS Code of Ethics https://neurips.cc/public/EthicsGuidelines?
Answer: [Yes]
Justification:
Guidelines:
The answer NA means that the authors have not reviewed the NeurIPS Code of Ethics.
If the authors answer No, they should explain the special circumstances that require a deviation from the Code of Ethics.
The authors should make sure to preserve anonymity (e.g., if there is a special consideration due to laws or regulations in their jurisdiction).

10. Broader Impacts
Question: Does the paper discuss both potential positive societal impacts and negative societal impacts of the work performed?
Answer: [NA]
Justification:
Guidelines:
The answer NA means that there is no societal impact of the work performed.
If the authors answer NA or No, they should explain why their work has no societal impact or why the paper does not address societal impact.
Examples of negative societal impacts include potential malicious or unintended uses (e.g., disinformation, generating fake profiles, surveillance), fairness considerations (e.g., deployment of technologies that could make decisions that unfairly impact specific groups), privacy considerations, and security considerations.
The conference expects that many papers will be foundational research and not tied to particular applications, let alone deployments. However, if there is a direct path to any negative applications, the authors should point it out. For example, it is legitimate to point out that an improvement in the quality of generative models could be used to generate deepfakes for disinformation. On the other hand, it is not needed to point out that a generic algorithm for optimizing neural networks could enable people to train models that generate Deepfakes faster.
The authors should consider possible harms that could arise when the technology is being used as intended and functioning correctly, harms that could arise when the technology is being used as intended but gives incorrect results, and harms following from (intentional or unintentional) misuse of the technology.
If there are negative societal impacts, the authors could also discuss possible mitigation strategies (e.g., gated release of models, providing defenses in addition to attacks, mechanisms for monitoring misuse, mechanisms to monitor how a system learns from feedback over time, improving the efficiency and accessibility of ML).

11. Safeguards
Question: Does the paper describe safeguards that have been put in place for responsible release of data or models that have a high risk for misuse (e.g., pretrained language models, image generators, or scraped datasets)?
Answer: [NA]
Justification:
Guidelines:
The answer NA means that the paper poses no such risks.
Released models that have a high risk for misuse or dual-use should be released with necessary safeguards to allow for controlled use of the model, for example by requiring that users adhere to usage guidelines or restrictions to access the model or implementing safety filters.
Datasets that have been scraped from the Internet could pose safety risks. The authors should describe how they avoided releasing unsafe images.
We recognize that providing effective safeguards is challenging, and many papers do not require this, but we encourage authors to take this into account and make a best faith effort.

12. Licenses for existing assets
Question: Are the creators or original owners of assets (e.g., code, data, models), used in the paper, properly credited and are the license and terms of use explicitly mentioned and properly respected?
Answer: [NA]
Justification:
Guidelines:
The answer NA means that the paper does not use existing assets.
The authors should cite the original paper that produced the code package or dataset.
The authors should state which version of the asset is used and, if possible, include a URL.
The name of the license (e.g., CC-BY 4.0) should be included for each asset.
For scraped data from a particular source (e.g., website), the copyright and terms of service of that source should be provided.
If assets are released, the license, copyright information, and terms of use in the package should be provided. For popular datasets, paperswithcode.com/datasets has curated licenses for some datasets. Their licensing guide can help determine the license of a dataset.
For existing datasets that are re-packaged, both the original license and the license of the derived asset (if it has changed) should be provided.
If this information is not available online, the authors are encouraged to reach out to the asset's creators.

13. New Assets
Question: Are new assets introduced in the paper well documented and is the documentation provided alongside the assets?
Answer: [NA]
Justification:
Guidelines:
The answer NA means that the paper does not release new assets.
Researchers should communicate the details of the dataset/code/model as part of their submissions via structured templates. This includes details about training, license, limitations, etc.
The paper should discuss whether and how consent was obtained from people whose asset is used.
At submission time, remember to anonymize your assets (if applicable). You can either create an anonymized URL or include an anonymized zip file.

14. Crowdsourcing and Research with Human Subjects
Question: For crowdsourcing experiments and research with human subjects, does the paper include the full text of instructions given to participants and screenshots, if applicable, as well as details about compensation (if any)?
Answer: [NA]
Justification:
Guidelines:
The answer NA means that the paper does not involve crowdsourcing nor research with human subjects.
Including this information in the supplemental material is fine, but if the main contribution of the paper involves human subjects, then as much detail as possible should be included in the main paper.
According to the NeurIPS Code of Ethics, workers involved in data collection, curation, or other labor should be paid at least the minimum wage in the country of the data collector.

15. Institutional Review Board (IRB) Approvals or Equivalent for Research with Human Subjects
Question: Does the paper describe potential risks incurred by study participants, whether such risks were disclosed to the subjects, and whether Institutional Review Board (IRB) approvals (or an equivalent approval/review based on the requirements of your country or institution) were obtained?
Answer: [NA]
Justification:
Guidelines:
The answer NA means that the paper does not involve crowdsourcing nor research with human subjects.
Depending on the country in which research is conducted, IRB approval (or equivalent) may be required for any human subjects research. If you obtained IRB approval, you should clearly state this in the paper.
We recognize that the procedures for this may vary significantly between institutions and locations, and we expect authors to adhere to the NeurIPS Code of Ethics and the guidelines for their institution.
For initial submissions, do not include any information that would break anonymity (if applicable), such as the institution conducting the review.
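
Supplementary illustration for the proof of Lemma 7 (Appendix H). The two-case choice of $\theta$ and $t$ in that proof can be checked numerically. The following is a minimal sketch, not part of the original paper: the function names `g`, `chernoff_exponent`, and `lemma7_tail_bound` are hypothetical and introduced only for illustration. It evaluates the exponent $-\theta t + g(\theta)\rho$ with $g(\theta) = \lambda(e^{\theta} - 1 - \theta)$ and $\rho = \max(d_1, d_2)$ for the proof's choices of $\theta$ and $t$, so that one can confirm it never exceeds $-\eta$.

```python
import math


def g(lam: float, theta: float) -> float:
    """g_lambda(theta) = lam * (exp(theta) - 1 - theta), for theta >= 0."""
    return lam * (math.exp(theta) - 1.0 - theta)


def chernoff_exponent(d1: int, d2: int, lam: float, eta: float) -> float:
    """Exponent -theta*t + g(theta)*rho for the (theta, t) chosen in the proof."""
    rho = max(d1, d2)   # rho = lambda_max of the sum of the A_ij
    m = rho * lam       # the quantity max(d1, d2) * lambda
    if m <= 1.0:
        # First case of the proof: theta = 1 and t = eta + 1.
        theta, t = 1.0, eta + 1.0
    else:
        # Second case: theta = m^(-1/2) <= 1 and t = (eta + 1) * sqrt(m).
        theta, t = 1.0 / math.sqrt(m), (eta + 1.0) * math.sqrt(m)
    return -theta * t + g(lam, theta) * rho


def lemma7_tail_bound(d1: int, d2: int, eta: float) -> float:
    """Right-hand side 2 * (d1 + d2) * exp(-eta) of the lemma."""
    return 2 * (d1 + d2) * math.exp(-eta)


if __name__ == "__main__":
    # Arbitrary illustrative parameters (assumptions, not from the paper).
    cases = [(50, 80, 0.001, 3.0), (50, 80, 0.5, 3.0), (200, 10, 4.0, 6.0)]
    for d1, d2, lam, eta in cases:
        exponent = chernoff_exponent(d1, d2, lam, eta)
        print(d1, d2, lam, eta, exponent <= -eta, lemma7_tail_bound(d1, d2, eta))
```

In both regimes the exponent stays below $-\eta$ because $(e^{\theta} - 1 - \theta)/\theta^2 \le e - 2 < 1$ on $(0, 1]$, which is the scalar inequality the proof relies on.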