Published as a conference paper at ICLR 2024

FEATURE COLLAPSE

Thomas Laurent1, James H. von Brecht, Xavier Bresson2
1 Loyola Marymount University, tlaurent@lmu.edu
2 National University of Singapore, xaviercs@nus.edu.sg

ABSTRACT

We formalize and study a phenomenon called feature collapse that makes precise the intuitive idea that entities playing a similar role in a learning task receive similar representations. As feature collapse requires a notion of task, we leverage a synthetic task in which a learner must classify sentences composed of L tokens. We start by showing experimentally that feature collapse goes hand in hand with generalization. We then prove that, in the large sample limit, distinct tokens that play identical roles in the task receive identical local features in the first layer of the network. This analysis shows that a neural network trained on this task provably learns interpretable and meaningful representations in its first layer.

1 INTRODUCTION

Many machine learning practices implicitly rely on the belief that good generalization and transfer learning require good features. Despite this, the notion of good features remains vague and carries many potential meanings. One definition is that features/representations should only encode the information necessary to do the task at hand, and discard any unnecessary information as noise. For example, two distinct patches of grass should map to essentially identical representations even if these patches differ in pixel space. Intuitively, a network that gives the same representation to many distinct patches of grass has learned the grass concept. We call this phenomenon feature collapse, meaning that a learner gives the same feature to entities that play similar roles for the task at hand. This phenomenon captures the common intuition, often confirmed by practice, that the ability to learn good representations in early layers is essential for the empirical success of neural networks.

Broadly speaking, we conduct a theoretical investigation into how such representations are learned in the early layers of a neural network. To make progress we adopt a common approach in the theoretical deep-learning community: one starts with a synthetic data model exhibiting a clear latent structure, and then proves that a specific neural network architecture successfully uncovers this latent structure during training. For example, recent work in representation learning (Damian et al., 2022; Mousavi-Hosseini et al., 2022) leverages a data model (the multiple-index model) whose latent structure is defined by a low-dimensional subspace. In this specific setting, the first layer of a fully connected neural network provably learns this low-dimensional subspace.

Our work follows along the same lines. Specifically, we start with a data model that asks a learner to classify sentences comprised of L tokens. Some of these tokens play identical roles, in the sense that replacing one with another does not change the label of the sentence, and this equivalence of tokens defines the latent structure in the data. We then consider a neural network containing a shared embedding module, optionally followed by a Layer Norm module, and then a final linear classification head. We show that this network, when equipped with the Layer Norm module, provably learns the equivalence of these tokens in its first layer. To do so, we consider the large sample limit and derive analytical formulas for the weights of the network trained on the data model.
Under some symmetry assumptions on the task, these analytical formulas reveal:
(i) If the network includes Layer Norm then feature collapse takes place: the neural network learns to give the same embedding to tokens that play identical roles.
(ii) If the network does not include Layer Norm then feature collapse fails: the network does not give the same embedding to tokens that play identical roles. Moreover, this failure stems from the fact that common tokens receive large embeddings and rare tokens receive small embeddings.

Finally, we conduct experiments that show feature collapse and generalization go hand in hand. These experiments demonstrate that for the network to generalize well it is essential for tokens playing identical roles to receive the same embeddings. In summary, our main contributions are as follows:
- We study how a network learns representations in its first layer. To do so, we make the notion of good features mathematically rigorous via a synthetic data model with latent structure.
- We derive analytical formulas for the weights of a two-layer network trained on this data model. These analytical formulas show that, when equipped with a Layer Norm module, the network provably learns interpretable and meaningful representations in its first layer.

The remainder of the paper proceeds as follows. In subsection 1.1 we discuss related works. In section 2 we describe the data model and present a set of visual experiments to illustrate the main ideas used in the paper. In section 3 we present the three theorems that constitute the main results of the paper. Finally, section 4 outlines a set of additional experiments performed in the appendix.

1.1 RELATED WORKS

Our work most closely resembles recent work on theoretical representation learning and, to some extent, the recent literature on neural collapse. The works by Damian et al. (2022) and Mousavi-Hosseini et al. (2022) consider a synthetic data model, the multiple-index model, to investigate representation learning. In this model, the learner must solve a regression task with normally distributed inputs $x \in \mathbb{R}^d$ and targets $y = g(\langle u_1, x \rangle, \ldots, \langle u_r, x \rangle)$ for some function $g : \mathbb{R}^r \to \mathbb{R}$, some vectors $u_1, \ldots, u_r \in \mathbb{R}^d$, and with $r \ll d$. The target $y$ therefore depends solely on the projection of $x$ onto the low-dimensional subspace spanned by $u_1, \ldots, u_r$. These works prove that a fully connected, two-layer neural network learns, in its first layer, this low-dimensional subspace. The behavior of fully connected networks trained on the multiple-index model, or on related data models, has also been studied in various other works, including Ba et al. (2022), Bietti et al. (2022), Abbe et al. (2022) and Parkinson et al. (2023). Both this line of investigation and our work prove, in the appropriate sense, that a network uncovers latent structure. Nevertheless, our data model differs quite substantially from the multiple-index model. Moreover, we do not study a fully connected network: our network has a shared embedding module, and this allows us to tie good features to a notion of semantic equivalence.

The phenomenon we study in this work, feature collapse, has superficial similarity to neural collapse but in detail is quite different. In a pioneering work, Papyan et al. (2020) conducted a series of experiments revealing that a well-trained network gives identical representations, in its last layer, to training points that belong to the same class.
In a K-class classification task we therefore see the emergence, in the last layer, of K vectors coding for the K classes. Additionally, these K vectors point in maximally opposed directions. This phenomenon, coined neural collapse, has been studied extensively since its discovery (e.g. Mixon et al. (2020); Lu & Steinerberger (2020); Wojtowytsch et al. (2020); Fang et al. (2021); Zhu et al. (2021); Ji et al. (2021); Tirer & Bruna (2022); Zhou et al. (2022)). To emphasize the difference with feature collapse, note that neural collapse refers to a phenomenon where all training points from the same class receive the same representation at the end of the network. Unlike feature collapse, this does not provide any indication that the network has learned good representations in its early layers, or that the neural network has uncovered some latent structure beyond the one encoded in the training labels.

2 A TALE OF FEATURE COLLAPSE

We begin by more fully telling the empirical tale that motivates our theoretical investigation. To make progress we adopt a common approach in theoretical deep learning and leverage a synthetic data model exhibiting a clear latent structure. The model generates sequences of length L from an underlying set of latent variables that encode the K classes of a classification task. To make it concrete we describe this generative process using NLP terminology like "sentences" and "words", but of course do not intend it as a realistic language model.

Figure 1 illustrates the basic idea. The left side of the figure depicts a vocabulary of $n_w = 12$ word tokens and $n_c = 3$ concept tokens,
$$V = \{\text{potato}, \text{cheese}, \text{carrot}, \text{chicken}, \ldots\} \quad\text{and}\quad C = \{\text{vegetable}, \text{dairy}, \text{meat}\},$$
with the 12 words partitioned into the 3 equally sized concepts. A sentence $x \in V^L$ is a sequence of L words (L = 5 in the figure), and a latent variable $z \in C^L$ is a sequence of L concepts. The latent variables generate sentences. For example
$$z = [\text{dairy}, \text{veggie}, \text{meat}, \text{veggie}, \text{dairy}] \quad\text{generates}\quad x = [\text{cheese}, \text{carrot}, \text{pork}, \text{potato}, \text{butter}],$$
with the sentence on the right obtained by sampling each word at random from the corresponding concept. The first word is a random sample from the dairy concept (butter, cheese, cream, yogurt) according to the dairy distribution (square box at left), the second word is a random sample from the vegetable concept (potato, carrot, leek, lettuce) according to the vegetable distribution, and so forth.

Figure 1: Data model with parameters set to L = 5, nw = 12, nc = 3, K = 3. The left panel shows the 12 words partitioned into the 3 concepts together with the word distribution of each concept; the right panel shows the three latent variables z1 = [dairy, veggie, meat, veggie, dairy], z2 = [meat, dairy, dairy, veggie, meat], z3 = [veggie, meat, dairy, meat, dairy], and example sentences generated from each of them.
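To make the generative process fully concrete, here is a minimal sketch of it in Python. The vocabulary, concept names, and latent variables are the illustrative ones from figure 1, with uniform within-concept sampling; the function and variable names are ours, not taken from the paper's released code.

```python
import numpy as np

rng = np.random.default_rng(0)

# Concrete instantiation of the data model of figure 1 (nc = 3, nw = 12, L = 5, K = 3).
concepts = {
    "veggie": ["potato", "carrot", "leek", "lettuce"],
    "dairy":  ["butter", "cheese", "cream", "yogurt"],
    "meat":   ["chicken", "pork", "beef", "lamb"],
}

# Latent variables: one sequence of L concepts per class.
latents = [
    ["dairy", "veggie", "meat", "veggie", "dairy"],   # class 1
    ["meat", "dairy", "dairy", "veggie", "meat"],     # class 2
    ["veggie", "meat", "dairy", "meat", "dairy"],     # class 3
]

def sample_sentence(z, word_probs=None):
    """Generate one sentence from a latent variable z: at each position, sample a
    word from the corresponding concept (uniformly unless word_probs overrides it)."""
    sentence = []
    for concept in z:
        words = concepts[concept]
        p = None if word_probs is None else word_probs[concept]
        sentence.append(rng.choice(words, p=p))
    return sentence

# One sample per class, as in the right panel of figure 1.
for k, z in enumerate(latents, start=1):
    print(k, sample_sentence(z))
```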
At right, figure 1 depicts a classification task with K = 3 categories prescribed by the three latent variables $z_1, z_2, z_3 \in C^L$. Sentences generated by the latent variable $z_k$ share the same label k, yielding a classification problem that requires a learner to classify sentences among K categories.

Figure 2: Networks. The top panel depicts the network $h_{W,U}$, which embeds each word of the input sentence, concatenates the embeddings, and applies a linear classification head. The bottom panel depicts the network $h^\circ_{W,U}$, which additionally applies a Layer Norm module to each word embedding before the concatenation.

We use two similar networks to empirically study if and when the feature collapse phenomenon occurs. The first network $x \mapsto h_{W,U}(x)$, depicted on the top panel of figure 2, starts by embedding each word in a sentence by applying a $d \times n_w$ matrix W to the one-hot representation of each word. It then concatenates these d-dimensional embeddings of each word into a single vector. Finally, it applies a linear transformation U to produce a K-dimensional score vector $y = h_{W,U}(x)$ with one entry for each of the K classes. The $d \times n_w$ embedding matrix W and the $K \times Ld$ matrix U of linear weights are the only learnable parameters, and the network has no nonlinearities. The second network $x \mapsto h^\circ_{W,U}(x)$, depicted at bottom, differs only by the application of a Layer Norm module to the word embeddings prior to the concatenation. For simplicity we use a Layer Norm module which does not contain any learnable parameters; the module simply removes the mean and divides by the standard deviation of its input vector. As for the first network, the only learnable weights are W and U.

The task depicted in figure 1, and the networks depicted in figure 2, provide a clear way of studying how interpretable and meaningful representations are learned in the first layer of a network. For example, the four words butter, cheese, cream and yogurt clearly play identical roles for the task at hand (replacing one with another does not change the label of the sentence). As a consequence we would expect the embedding layer of the network to map them to the same representation. Similarly, the words belonging to the vegetable concept should receive the same representation, and the words belonging to the meat concept should receive the same representation. If this takes place, we say that feature collapse has occurred.

If feature collapse occurs, this will also be reflected in the second layer of the network. To see this, partition the linear transformation
$$U = \begin{bmatrix} u_{1,1} & u_{1,2} & \cdots & u_{1,L} \\ u_{2,1} & u_{2,2} & \cdots & u_{2,L} \\ \vdots & \vdots & & \vdots \\ u_{K,1} & u_{K,2} & \cdots & u_{K,L} \end{bmatrix} \qquad (1)$$
into its components $u_{k,\ell} \in \mathbb{R}^d$. Suppose for example that $z_{k,\ell} = \text{veggie}$, meaning that the latent variable $z_k$ contains the veggie concept in the $\ell$-th position. If W properly encodes concepts then we expect the vector $u_{k,\ell}$ to give a strong response when presented with the embedding of a word that belongs to the veggie concept. So we would expect $u_{k,\ell}$ to align with the embeddings of the words that belong to the veggie concept, and so feature collapse would occur in this manner as well.

In the remainder of this section we conduct experiments that visually illustrate the feature collapse phenomenon and the formation of interpretable representations in the networks $h$ and $h^\circ$. These experiments also show how layer normalization plays a key role in the feature collapse phenomenon.
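The two architectures are simple enough to state in a few lines. Below is a minimal PyTorch sketch matching the description above; the class names are ours and the released code may organize things differently. The only learnable parameters are the embedding matrix W and the linear head U, and the Layer Norm in the second network carries no learnable weights.

```python
import torch
import torch.nn as nn

class PlainNet(nn.Module):
    """Sketch of h_{W,U}: embed each word, concatenate, apply a linear head.
    No nonlinearity; W and U are the only learnable parameters."""
    def __init__(self, nw, d, L, K):
        super().__init__()
        self.W = nn.Embedding(nw, d)           # columns of W = word embeddings
        self.U = nn.Linear(L * d, K, bias=False)
    def forward(self, x):                       # x: (batch, L) integer word indices
        e = self.W(x)                           # (batch, L, d)
        return self.U(e.flatten(1))             # (batch, K) class scores

class LayerNormNet(nn.Module):
    """Sketch of h°_{W,U}: identical, except each word embedding passes through a
    parameter-free Layer Norm before concatenation."""
    def __init__(self, nw, d, L, K):
        super().__init__()
        self.W = nn.Embedding(nw, d)
        self.norm = nn.LayerNorm(d, elementwise_affine=False)
        self.U = nn.Linear(L * d, K, bias=False)
    def forward(self, x):
        e = self.norm(self.W(x))                # per-word mean removal and rescaling
        return self.U(e.flatten(1))

# Settings used in section 2: d = 100, nw = 1200, L = 15, K = 1000.
h = PlainNet(nw=1200, d=100, L=15, K=1000)
h_ln = LayerNormNet(nw=1200, d=100, L=15, K=1000)
```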
In our experiments we use the standard cross-entropy loss $\ell(y, k)$, with $y \in \mathbb{R}^K$ and $1 \le k \le K$, and minimize the corresponding regularized empirical risks
$$R_{\mathrm{emp}}(W, U) = \frac{1}{K\, n_{\mathrm{spl}}} \sum_{k=1}^{K}\sum_{i=1}^{n_{\mathrm{spl}}} \ell\big( h_{W,U}(x_{k,i}),\, k \big) + \frac{\lambda}{2}\|U\|_F^2 + \frac{\lambda}{2}\|W\|_F^2 \qquad (2)$$
$$R^\circ_{\mathrm{emp}}(W, U) = \frac{1}{K\, n_{\mathrm{spl}}} \sum_{k=1}^{K}\sum_{i=1}^{n_{\mathrm{spl}}} \ell\big( h^\circ_{W,U}(x_{k,i}),\, k \big) + \frac{\lambda}{2}\|U\|_F^2 \qquad (3)$$
of each network via stochastic gradient descent. Here $x_{k,i}$ denotes the i-th sentence of the k-th category in the training set, so each of the K categories has $n_{\mathrm{spl}}$ representatives. For the parameters of the architecture, loss, and training procedure, we use an embedding dimension of d = 100, a weight decay of λ = 0.001, a mini-batch size of 100 and a constant learning rate of 0.1 for all experiments. The Layer Norm module implicitly regularizes the matrix W, so we do not penalize it in equation (3). The code for our experiments is available at https://github.com/xbresson/feature_collapse.

Remark: Without weight decay (i.e. λ = 0), the above objectives typically do not have global minima. We therefore focus our theoretical investigation on the case λ > 0, which is analytically more tractable. In appendix A.1 we provide an empirical investigation of the case without weight decay to show that both cases (i.e. λ > 0 and λ = 0) exhibit the same behavior in practice.

2.1 THE UNIFORM CASE

We start with an instance of the task from figure 1 with parameters $n_c = 3$, $n_w = 1200$, L = 15, K = 1000, and with uniform word distributions. So each of the 3 concepts (say veggie, dairy, and meat) contains 400 words and the corresponding distributions (the veggie distribution, the dairy distribution, and the meat distribution) are uniform. We form K = 1000 latent variables $z_1, \ldots, z_{1000}$ by selecting them uniformly at random from the set $C^L$, which simply means that any concept sequence $z = [z_1, \ldots, z_L]$ has an equal probability of occurrence. We then construct a training set by generating $n_{\mathrm{spl}} = 5$ data points from each latent variable. We then train both networks $h$, $h^\circ$ and evaluate their generalization performance; both achieve 100% accuracy on test points.

Figure 3: Visualization of the word embeddings W (top) and the linear weights U (bottom) of the trained network $h_{W,U}$.

Since both networks generalize perfectly, we expect them to have learned good representations. To confirm this, we start by visualizing in figure 3 the learnable parameters W, U of the network $h_{W,U}$ after training. The embedding matrix W contains $n_w = 1200$ columns. Each column is a vector in $\mathbb{R}^{100}$ and corresponds to a word embedding. The top panel of figure 3 depicts these 1200 word embeddings after dimensionality reduction via PCA. The top singular values σ1 = 34.9, σ2 = 34.7 and σ3 = 0.001 associated with the PCA indicate that the word embeddings essentially live in a 2-dimensional subspace of $\mathbb{R}^{100}$, so the PCA paints an accurate picture of the distribution of word embeddings. We then color-code each word embedding according to its concept, so that all embeddings of words within a concept receive the same color (say all veggie words in green, all dairy words in blue, and so forth). As the figure illustrates, words from the same concept receive nearly identical embeddings, and these embeddings form an equilateral triangle, or two-dimensional simplex. We therefore observe collapse of features into a set of $n_c = 3$ equiangular vectors at the level of word embeddings. The bottom panel of figure 3 illustrates collapse for the parameters U of the linear layer. We partition the matrix U into vectors $u_{k,\ell} \in \mathbb{R}^{100}$ via (1) and visualize them once again with PCA.
As for the word embeddings, the singular values of the PCA (σ1 = 34.9, σ2 = 34.6 and σ3 = 0.0003) reveal that the vectors $u_{k,\ell}$ essentially live in a two-dimensional subspace of $\mathbb{R}^{100}$. We color-code each $u_{k,\ell}$ according to the concepts contained in the corresponding latent variable (say $u_{k,\ell}$ is green if $z_{k,\ell} = \text{veggie}$, and so forth). The figure indicates that vectors $u_{k,\ell}$ that correspond to the same concept collapse around a single vector. A similar analysis applied to the weights of the network $h^\circ_{W,U}$ tells the same story, provided we examine the actual word features (i.e. the embeddings after the Layer Norm) rather than the weights W themselves.

In theorems 1 and 3 (see section 3) we prove the correctness of this empirical picture. We show that the weights of $h$ and $h^\circ$ collapse into the configurations illustrated in figure 3 in the large sample limit. Moreover, this limit captures the empirical solution very well. For example, the word embeddings in figure 3 have a norm equal to 1.41 ± 0.13, while we predict a norm of 1.42214 theoretically.

2.2 THE LONG-TAILED CASE

At a superficial glance it appears as if the Layer Norm module plays no essential role, as both networks $h$ and $h^\circ$, in the previous experiment, exhibit feature collapse and generalize perfectly. To probe this issue further, we continue our investigation by conducting a similar experiment (keeping $n_c = 3$, $n_w = 1200$, L = 15, and K = 1000) but with non-uniform, long-tailed word distributions within each of the $n_c = 3$ concepts. For concreteness, say the veggie concept contains the 400 words
$$\text{potato}, \ \text{lettuce}, \ \ldots, \ \text{arugula}, \ \text{parsnip}, \ \ldots, \ \text{achojcha}$$
where achojcha is a rare vegetable that grows in the Andes mountains. We form the veggie distribution by sampling potato with probability C/1, sampling lettuce with probability C/2, and so forth down to achojcha, which has probability C/400 of being sampled (C is chosen so that all the probabilities sum to 1). This 1/i power law distribution has a long tail, meaning that relatively infrequent words such as arugula or parsnip collectively capture a significant portion of the mass. Natural data in the form of text or images typically exhibit long-tailed distributions (Salakhutdinov et al., 2011; Zhu et al., 2014; Liu et al., 2019; Feldman, 2020; Feldman & Zhang, 2020). For instance, the frequencies of words in natural text approximately conform to the 1/i power law distribution, also known as Zipf's law (Zipf, 1935), which motivates the specific choice made in this experiment. Many datasets of interest display some form of long-tail behavior, whether at the level of object occurrences in computer vision or the frequency of words or topics in NLP, and effectively addressing these long-tail behaviors is frequently a challenge for the learner.

To investigate the impact of a long-tailed word distribution, we first select the latent variables $z_1, \ldots, z_{1000}$ uniformly at random as before. We then use them to build two distinct training sets. We build a large training set by generating $n_{\mathrm{spl}} = 500$ training points per latent variable and a small training set by generating $n_{\mathrm{spl}} = 5$ training points per latent variable. We use the 1/i power law distribution when sampling words from concepts in both cases. We then train $h$ and $h^\circ$ on both training sets and evaluate their generalization performance. When trained on the large training set, both are 100% accurate at test time (as they should be: the large training set has 500,000 total samples).
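For concreteness, the long-tailed word distribution and the two training sets described above can be sketched as follows. This is a hypothetical re-implementation with integer-coded words, not the authors' code.

```python
import numpy as np

rng = np.random.default_rng(0)

nc, nw, L, K = 3, 1200, 15, 1000
sc = nw // nc                                   # 400 words per concept

# 1/i power-law ("Zipf") distribution over the words of a concept.
zipf = 1.0 / np.arange(1, sc + 1)
zipf /= zipf.sum()

# Latent variables: K random concept sequences of length L.
latents = rng.integers(0, nc, size=(K, L))

def make_dataset(n_spl, word_probs):
    """n_spl sentences per latent variable; word (alpha, beta) is encoded as the
    integer alpha * sc + beta, giving a vocabulary of nw = nc * sc word indices."""
    xs, ys = [], []
    for k, z in enumerate(latents):
        for _ in range(n_spl):
            betas = rng.choice(sc, size=L, p=word_probs)
            xs.append(z * sc + betas)
            ys.append(k)
    return np.stack(xs), np.array(ys)

X_large, y_large = make_dataset(500, zipf)      # 500,000 sentences
X_small, y_small = make_dataset(5, zipf)        # 5,000 sentences
```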
A significant difference emerges between $h$ and $h^\circ$ when trained on the small training set: the network $h$ achieves a test accuracy of 45% while $h^\circ$ remains 100% accurate. We once again visualize the weights of each network to study the relationship between generalization and collapse.

Figure 4: Visualization of the matrices W (left in each subfigure) and U (right in each subfigure). (a) $h$ trained on the large training set, test acc. = 100%. (b) $h$ trained on the small training set, test acc. = 45%. (c) $h^\circ$ trained on the small training set, test acc. = 100%.

Figure 4(a) depicts the weights of $h_{W,U}$ (via dimensionality reduction and color coding) after training on the large training set. The word embeddings are on the left sub-panel and the linear weights $u_{k,\ell}$ on the right sub-panel. Words that belong to the same concept still receive embeddings that are aligned; however, the magnitude of these embeddings depends upon word frequency. The most frequent words in a concept (e.g. potato) have the largest embeddings while the least frequent words (e.g. achojcha) have the smallest embeddings. In other words, we observe directional collapse of the embeddings, but the magnitudes do not collapse. In contrast, the linear weights $u_{k,\ell}$ mostly concentrate around three well-defined, equiangular locations; they collapse in both direction and magnitude. A major contribution of our work (c.f. theorem 2 in the next section) is a theoretical insight that explains the configurations observed in figure 4(a), and in particular, explains why the magnitudes of word embeddings depend on their frequencies.

Figure 4(b) illustrates the weights of $h_{W,U}$ after training on the small training set. While the word embeddings exhibit a similar pattern as in figure 4(a), the linear weights $u_{k,\ell}$ remain dispersed and fail to collapse. This leads to poor generalization performance (45% accuracy at test time). To summarize, when the training set is large, the linear weights $u_{k,\ell}$ collapse correctly and the network $h_{W,U}$ generalizes well. When the training set is small the linear weights fail to collapse, and the network fails to generalize.

This phenomenon can be attributed to the long-tailed nature of the word distribution. To see this, say that $z_k = [\text{veggie}, \text{dairy}, \text{veggie}, \ldots, \text{meat}, \text{dairy}]$ represents the k-th latent variable for the sake of concreteness. With only $n_{\mathrm{spl}} = 5$ samples for this latent variable, we might end up in a situation where the 5 words selected to represent the first occurrence of the veggie concept have very different frequencies than the 5 words selected to represent the third occurrence of the veggie concept. Since word embeddings have magnitudes that depend on their frequencies, this results in a serious imbalance between the vectors $u_{k,1}$ and $u_{k,3}$ that code for the first and third occurrences of the veggie concept. This leads to two vectors $u_{k,1}$, $u_{k,3}$ that code for the same concept but have different magnitudes (as seen in figure 4(b)), so features do not properly collapse. This imbalance results from the noise introduced by sampling only 5 training points per latent variable. Indeed, if $n_{\mathrm{spl}} = 500$ then each occurrence of the veggie concept will exhibit a similar mix of frequent and rare words, $u_{k,1}$ and $u_{k,3}$ will have roughly the same magnitude, and full collapse will take place (c.f. figure 4(a)). Finally, the poor generalization ability of $h_{W,U}$ when the training set is small really stems from the long-tailed nature of the word distribution.
The failure mechanism occurs due to the relatively balanced mix of rare and frequent words that arises with long-tailed data. If the data were dominated by a few very frequent words, then all rare words combined would contribute only small perturbations and would not adversely affect performance.

We conclude this section by examining the weights of the network $h^\circ_{W,U}$ after training on the small training set. The left panel of figure 4(c) provides a visualization of the word embeddings after the Layer Norm module. These word representations collapse both in direction and magnitude; they do not depend on word frequency since the Layer Norm forces vectors to have identical magnitude. The right panel of figure 4(c) depicts the linear weights $u_{k,\ell}$ and shows that they properly collapse. As a consequence, $h^\circ_{W,U}$ generalizes perfectly (100% accurate) even with only $n_{\mathrm{spl}} = 5$ samples per class. Normalization plays a crucial role by ensuring that word representations do not depend upon word frequency. In turn, this prevents the undesired mechanism that causes $h_{W,U}$ to have uncollapsed linear weights $u_{k,\ell}$ when trained on the small training set. Theorem 3 in the next section proves the correctness of this picture: the weights of the network $h^\circ$ collapse to the frequency-independent configuration of figure 4(c) in the large sample limit.

Our main contributions consist of three theorems. In theorem 1 we prove that the weights of the network $h_{W,U}$ collapse into the configuration depicted in figure 3 when words have identical frequencies. In theorem 2 we provide theoretical justification of the fact that, when words have distinct frequencies, the word embeddings of $h_{W,U}$ must depend on frequency in the manner that figure 4(a) illustrates. Finally, in theorem 3 we show that the weights of the network $h^\circ_{W,U}$ exhibit full collapse even when words have distinct frequencies. Each of these theorems holds in the large $n_{\mathrm{spl}}$ limit and under some symmetry assumptions on the latent variables. All proofs are in the appendix.

Notation. The set of concepts, which up to now was C = {veggie, dairy, meat}, will be represented in this section by the more abstract $C = \{1, \ldots, n_c\}$. We let $s_c := n_w / n_c$ denote the number of words per concept, and represent the vocabulary by
$$V = \big\{ (\alpha, \beta) \in \mathbb{N}^2 : 1 \le \alpha \le n_c \text{ and } 1 \le \beta \le s_c \big\}.$$
So elements of V are tuples of the form (α, β) with $1 \le \alpha \le n_c$ and $1 \le \beta \le s_c$, and we think of the tuple (α, β) as representing the β-th word of the α-th concept. Each concept $\alpha \in C$ comes equipped with a probability distribution $p_\alpha : \{1, \ldots, s_c\} \to [0, 1]$ over the words within it, so that $p_\alpha(\beta)$ is the probability of selecting the β-th word when sampling out of the α-th concept. For simplicity we assume that the word distributions within each concept follow identical laws, so that $p_\alpha(\beta) = \mu_\beta$ for all $(\alpha, \beta) \in V$, for some positive scalars $\mu_\beta > 0$ that sum to 1. We think of $\mu_\beta$ as the frequency of word (α, β) in the vocabulary. For example, choosing $\mu_\beta = 1/s_c$ gives uniform word distributions while $\mu_\beta \propto 1/\beta$ corresponds to Zipf's law. We use $X := V^L$ to denote the data space and $Z := C^L$ to denote the latent space. The elements of the data space X correspond to sequences $x = [(\alpha_1, \beta_1), \ldots, (\alpha_L, \beta_L)]$ of L words, while elements of the latent space Z correspond to sequences $z = [\alpha_1, \ldots, \alpha_L]$ of L concepts. For a given latent variable z we write $x \sim D_z$ to indicate that the data point x was generated by z (formally $D_z : X \to [0, 1]$ is a probability distribution).
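In this abstract notation a data point is a list of (α, β) pairs, and $D_z$ is supported on sentences whose concepts match z, with probability given by the product of the sampled word frequencies (the explicit formula appears in appendix B.3). A small sketch, with names of our choosing and 0-based indexing for convenience:

```python
import numpy as np

def prob_x_given_z(x, z, mu):
    """P[x | z] under D_z: zero unless every word's concept matches z,
    otherwise the product of the word frequencies mu[beta].
    x is a list of (alpha, beta) pairs and z a list of concepts alpha."""
    p = 1.0
    for (alpha, beta), alpha_z in zip(x, z):
        if alpha != alpha_z:
            return 0.0
        p *= mu[beta]
    return p

# Tiny example with nc = 2 concepts, sc = 3 words per concept, L = 2.
mu = np.array([0.5, 0.3, 0.2])                 # same law within every concept
z = [0, 1]
x = [(0, 2), (1, 0)]                           # words (0, 2) and (1, 0)
print(prob_x_given_z(x, z, mu))                # 0.2 * 0.5 = 0.1
```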
Word embeddings, Layer Norm, and word representations. We use $w_{(\alpha,\beta)} \in \mathbb{R}^d$ to denote the embedding of word $(\alpha, \beta) \in V$. The collection of all $w_{(\alpha,\beta)}$ determines the columns of the matrix $W \in \mathbb{R}^{d \times n_w}$. These embeddings feed into a Layer Norm module without learnable parameters:
$$\varphi(v) = \frac{v - \mathrm{mean}(v)\,\mathbf{1}_d}{\sigma(v)} \quad\text{where}\quad \mathrm{mean}(v) = \frac{1}{d}\sum_{i=1}^{d} v_i \quad\text{and}\quad \sigma^2(v) = \frac{1}{d}\sum_{i=1}^{d}\big(v_i - \mathrm{mean}(v)\big)^2.$$
So the Layer Norm module converts a word embedding $w_{(\alpha,\beta)} \in \mathbb{R}^d$ into a vector $\varphi(w_{(\alpha,\beta)}) \in \mathbb{R}^d$, and we call this vector a word representation.

Equiangular vectors. We call a collection of $n_c$ vectors $f_1, \ldots, f_{n_c} \in \mathbb{R}^d$ equiangular if
$$\sum_{\alpha=1}^{n_c} f_\alpha = 0 \qquad\text{and}\qquad \langle f_\alpha, f_{\alpha'} \rangle = \begin{cases} 1 & \text{if } \alpha = \alpha' \\ -1/(n_c - 1) & \text{otherwise} \end{cases} \qquad (4)$$
hold for all possible pairs $\alpha, \alpha' \in [n_c]$ of concepts. For example, three vectors $f_1, f_2, f_3 \in \mathbb{R}^{100}$ are equiangular exactly when they have unit norms, live in a two-dimensional subspace of $\mathbb{R}^{100}$, and form the vertices of an equilateral triangle in this subspace. This example exactly corresponds to the configurations in figures 3 and 4 (up to a scaling factor). Similarly, four vectors $f_1, f_2, f_3, f_4 \in \mathbb{R}^{100}$ are equiangular when they have unit norms and form the vertices of a regular tetrahedron. We will sometimes require $f_1, \ldots, f_{n_c} \in \mathbb{R}^d$ to also satisfy $\langle f_\alpha, \mathbf{1}_d \rangle = 0$ for all $\alpha \in [n_c]$, in which case we say $f_1, \ldots, f_{n_c}$ form a collection of mean-zero equiangular vectors.

Collapse configurations. Our empirical investigations reveal two distinct candidate solutions for the weights (W, U) of the networks $h_{W,U}$ and $h^\circ_{W,U}$. We therefore isolate each of these possible candidates as a definition before turning to the statements of our main theorems. We begin by defining the type of collapse observed when training the network $h_{W,U}$ with uniform word distributions (see figure 3 for a visual illustration of this type of collapse).

Definition 1 (Type-I Collapse). The weights (W, U) of the network $h_{W,U}$ form a type-I collapse configuration if and only if the conditions
i) there exists $c \ge 0$ so that $w_{(\alpha,\beta)} = c\, f_\alpha$ for all $(\alpha, \beta) \in V$;
ii) there exists $c' \ge 0$ so that $u_{k,\ell} = c'\, f_\alpha$ for all $(k, \ell)$ satisfying $z_{k,\ell} = \alpha$ and all $\alpha \in C$;
hold for some collection $f_1, \ldots, f_{n_c} \in \mathbb{R}^d$ of equiangular vectors.

Recall that the network $h^\circ_{W,U}$ exhibits collapse as well, up to the fact that the word representations $\varphi(w_{(\alpha,\beta)})$ collapse rather than the word embeddings themselves. Additionally, the Layer Norm also fixes the magnitude of the word representations. We isolate these differences in the next definition.

Definition 2 (Type-II Collapse). The weights (W, U) of the network $h^\circ_{W,U}$ form a type-II collapse configuration if and only if the conditions
i) $\varphi(w_{(\alpha,\beta)}) = \sqrt{d}\, f_\alpha$ for all $(\alpha, \beta) \in V$;
ii) there exists $c' \ge 0$ so that $u_{k,\ell} = c'\, f_\alpha$ for all $(k, \ell)$ satisfying $z_{k,\ell} = \alpha$ and all $\alpha \in C$;
hold for some collection $f_1, \ldots, f_{n_c} \in \mathbb{R}^d$ of mean-zero equiangular vectors.

Finally, when training the network $h_{W,U}$ with non-uniform word distributions (c.f. figure 4(a)) we observe collapse in the direction of the word embeddings $w_{(\alpha,\beta)}$, but their magnitudes depend upon word frequency. We therefore isolate this final observation as

Definition 3 (Type-III Collapse). The weights (W, U) of the network $h_{W,U}$ form a type-III collapse configuration if and only if the conditions
i) there exist scalars $r_\beta \ge 0$ so that $w_{(\alpha,\beta)} = r_\beta\, f_\alpha$ for all $(\alpha, \beta) \in V$;
ii) there exists $c' \ge 0$ so that $u_{k,\ell} = c'\, f_\alpha$ for all $(k, \ell)$ satisfying $z_{k,\ell} = \alpha$ and all $\alpha \in C$;
hold for some collection $f_1, \ldots, f_{n_c} \in \mathbb{R}^d$ of equiangular vectors.
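The equiangular configurations appearing in these definitions are easy to construct and check numerically. The sketch below builds $n_c$ mean-zero equiangular vectors in $\mathbb{R}^d$ via a scaled-simplex construction (one standard choice; the paper does not prescribe a particular construction), verifies the conditions in (4), and illustrates that the parameter-free Layer Norm maps any embedding of the form $c\, f_\alpha + b\, \mathbf{1}_d$ with $c > 0$ to $\sqrt{d}\, f_\alpha$, as in definition 2.

```python
import numpy as np

rng = np.random.default_rng(0)
d, nc = 100, 3

# Simplex configuration in R^nc: unit norms, pairwise inner product -1/(nc-1), zero sum.
S = np.sqrt(nc / (nc - 1)) * (np.eye(nc) - np.ones((nc, nc)) / nc)

# Embed into R^d with an orthonormal frame whose columns are orthogonal to 1_d,
# so the resulting vectors are also mean-zero.
A = rng.standard_normal((d, nc))
A -= A.mean(axis=0, keepdims=True)          # remove the 1_d component of each column
Q, _ = np.linalg.qr(A)                      # d x nc, orthonormal columns
F = Q @ S                                   # columns f_1, ..., f_nc

print(np.round(F.T @ F, 6))                 # Gram matrix: 1 on diag, -1/(nc-1) off diag
print(np.abs(F.sum(axis=1)).max())          # sum_alpha f_alpha = 0
print(np.abs(F.sum(axis=0)).max())          # <f_alpha, 1_d> = 0 for every alpha

def layer_norm(v):
    """Parameter-free Layer Norm phi from the paper."""
    v = v - v.mean()
    return v / np.sqrt((v ** 2).mean())

w = 3.7 * F[:, 0] + 0.2                     # an embedding aligned with f_1 (plus a mean shift)
print(np.allclose(layer_norm(w), np.sqrt(d) * F[:, 0], atol=1e-6))   # True
```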
In a type-III collapse we allow the word embedding $w_{(\alpha,\beta)}$ to have a frequency-dependent magnitude $r_\beta$, while in a type-I collapse we force all embeddings to have the same magnitude; this makes type-I collapse a special case of type-III collapse, but not vice-versa.

3.1 PROVING COLLAPSE

Our first result proves that the word embeddings $w_{(\alpha,\beta)}$ and linear weights $u_{k,\ell}$ exhibit type-I collapse in an appropriate large-sample limit. When turning from experiment (c.f. figure 3) to theory we study the true risk
$$R(W, U) = \frac{1}{K} \sum_{k=1}^{K} \mathbb{E}_{x \sim D_{z_k}}\Big[ \ell\big(h_{W,U}(x), k\big) \Big] + \frac{\lambda}{2}\Big( \|W\|_F^2 + \|U\|_F^2 \Big) \qquad (5)$$
rather than the empirical risk $R_{\mathrm{emp}}(W, U)$, and place a symmetry assumption on the latent variables.

Assumption 1 (Latent Symmetry). For every $k \in [K]$, $r \in [L]$, $\ell \in [L]$, and $\alpha \in [n_c]$ the identities
$$\#\Big\{ k' \in [K] : \mathrm{dist}(z_k, z_{k'}) = r \ \text{and}\ z_{k',\ell} = \alpha \Big\} = \begin{cases} \frac{K}{|Z|} \binom{L-1}{r} (n_c - 1)^r & \text{if } z_{k,\ell} = \alpha \\[4pt] \frac{K}{|Z|} \binom{L-1}{r-1} (n_c - 1)^{r-1} & \text{if } z_{k,\ell} \neq \alpha \end{cases} \qquad (6)$$
hold, with $\mathrm{dist}(z_k, z_{k'})$ denoting the Hamming distance between a pair $(z_k, z_{k'})$ of latent variables.

With this assumption in hand we may state our first main result.

Theorem 1 (Full Collapse of h). Assume uniform sampling $\mu_\beta = 1/s_c$ for each word distribution. Let $\tau \ge 0$ denote the unique minimizer of the strictly convex function
$$H(t) := \log\left( 1 + \frac{K}{n_c^L}\Big[ \big(1 + (n_c - 1)e^{-\eta t}\big)^L - 1 \Big] \right) + \lambda t \qquad\text{where}\qquad \eta = \frac{n_c}{n_c - 1} \frac{1}{\sqrt{n_w K L}}.$$
Assume $z_1, \ldots, z_K$ are mutually distinct and satisfy the symmetry assumption 1. Then any (W, U) in a type-I collapse configuration with constants $c = \sqrt{\tau / n_w}$ and $c' = \sqrt{\tau / (KL)}$ is a global minimizer of (5).

We also prove two strengthenings of this theorem in the appendix. First, under an additional technical assumption on the latent variables $z_1, \ldots, z_K$ we prove its converse: any (W, U) that minimizes (5) must be in a type-I collapse configuration (with the same constants c, c'). This additional assumption is mild but technical, so we state it in appendix C. We also prove that if $d > n_w$ then R(W, U) does not have spurious local minimizers; all local minimizers are global (see appendix H).

The symmetry assumption, while odd at first glance, is both needed and natural. Indeed, a type-I collapse configuration is highly symmetric and perfectly homogeneous. We therefore expect that such configurations could only solve an analogously symmetric and homogeneous optimization problem. In our case this means using the true risk (5) rather than the empirical risk (2), and imposing that the latent variables satisfy the symmetry assumption. This assumption means that all latent variables play interchangeable roles, or at an intuitive level, that there is no preferred latent variable. To understand this better, consider the extreme case $K = n_c^L$ and $\{z_1, \ldots, z_K\} = Z$, meaning that all latent variables in Z are involved in the task. The identity (6) then holds by simple combinatorics. We may therefore think of (6) as an equality that holds in the large K limit, so it is neither impossible nor unnatural. We refer to appendix C for more discussion of assumption 1.

While theorem 1 proves global optimality of type-I collapse configurations in the limit of large $n_{\mathrm{spl}}$ and large K, these solutions still provide valuable predictions when K and $n_{\mathrm{spl}}$ have small to moderate values. For example, in the setting of figure 3 ($n_{\mathrm{spl}} = 5$ and K = 1000) the theorem predicts that word embeddings should have a norm $c = \sqrt{\tau / n_w} = 1.42214$ (with τ obtained by minimizing H(t) numerically). By experiment we find that, on average, word embeddings have norm 1.41 with standard deviation 0.13.
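As a quick sanity check of the value quoted above, and using the expression for H and η as stated in theorem 1, one can minimize H directly. The script below is an illustrative sketch, not the authors' code.

```python
import numpy as np
from scipy.optimize import minimize_scalar

# Parameters from the experiment of figure 3.
nc, nw, L, K, lam = 3, 1200, 15, 1000, 0.001
eta = (nc / (nc - 1)) / np.sqrt(nw * K * L)

def H(t):
    inner = (1 + (nc - 1) * np.exp(-eta * t)) ** L
    return np.log(1 + (K / nc ** L) * (inner - 1)) + lam * t

tau = minimize_scalar(H, bounds=(0, 1e5), method="bounded").x
print(np.sqrt(tau / nw))        # predicted word-embedding norm, approx. 1.42
print(np.sqrt(tau / (K * L)))   # predicted norm of the vectors u_{k,l} under theorem 1
```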
To take another example, when K = 50 and $n_{\mathrm{spl}} = 100$ (and keeping $n_c = 3$, $n_w = 1200$, L = 15) the theorem predicts that word embeddings should have norm 0.61602. This compares well against the values 0.61 ± 0.06 observed in experiments. The idealized solutions of the theorem capture their empirical counterparts very well.

For non-uniform $\mu_\beta$ we expect $h_{W,U}$ to exhibit type-III collapse rather than type-I collapse. Additionally, in our long-tail experiments, we observe that frequent words (i.e. large $\mu_\beta$) receive large embeddings. We now prove that this is the case in our next theorem. To state it, consider the following system of $s_c + 1$ equations
$$\frac{\lambda\, r_\beta}{L\, c'} \left[ n_c - 1 + \exp\Big( \frac{n_c}{n_c - 1}\, c'\, r_\beta \Big) \right] = \mu_\beta \quad\text{for all } 1 \le \beta \le s_c \qquad (7)$$
$$\sum_{\beta=1}^{s_c} \Big( \frac{r_\beta}{c'} \Big)^2 = L\, n_c^{L-1} \qquad (8)$$
for the unknowns $(c', r_1, \ldots, r_{s_c})$. If the regularization parameter λ is small enough, namely $\lambda^2 < \frac{L}{n_c^{L+1}} \sum_{\beta=1}^{s_c} \mu_\beta^2$, then (7)–(8) has a unique solution. This solution defines the magnitudes of the word embeddings. The left hand side of (7) is an increasing function of $r_\beta$, so $\mu_\beta < \mu_{\beta'}$ implies $r_\beta < r_{\beta'}$, and more frequent words receive larger embeddings.

Theorem 2 (Directional Collapse of h). Assume $\lambda^2 < (L / n_c^{L+1}) \sum_{\beta=1}^{s_c} \mu_\beta^2$, $K = n_c^L$ and $\{z_1, \ldots, z_K\} = Z$. Suppose (W, U) is in a type-III collapse configuration for some constants $(c', r_1, \ldots, r_{s_c})$. Then (W, U) is a critical point of the true risk (5) if and only if $(c', r_1, \ldots, r_{s_c})$ solve the system (7)–(8).

Essentially this theorem shows that word embeddings must depend on word frequency, and so feature collapse fails. Even in the fully-sampled case $K = n_c^L$ and $\{z_1, \ldots, z_K\} = Z$, a network exhibiting type-I collapse is never critical if the word distributions are non-uniform. While we conjecture global optimality of the solutions in theorem 2 under appropriate symmetry assumptions, we have no proof of this yet. The bound on λ is the natural one for theorem 2, for if λ is too large the trivial solution (W, U) = (0, 0) is the only one. In our experiments, λ satisfies this bound.

Our final theorem completes the picture; it shows that normalization restores global optimality of fully-collapsed configurations. For the network $h^\circ_{W,U}$ with Layer Norm, we use the appropriate limit
$$R^\circ(W, U) = \frac{1}{K} \sum_{k=1}^{K} \mathbb{E}_{x \sim D_{z_k}}\Big[ \ell\big(h^\circ_{W,U}(x), k\big) \Big] + \frac{\lambda}{2} \|U\|_F^2 \qquad (9)$$
of the associated empirical risk and place no assumptions on the sampling distribution.

Theorem 3 (Full Collapse of h°). Assume the non-degeneracy condition $\mu_\beta > 0$ holds. Let $\tau \ge 0$ denote the unique minimizer of the strictly convex function
$$H^\circ(t) := \log\left( 1 + \frac{K}{n_c^L}\Big[ \big(1 + (n_c - 1)e^{-\eta^\circ t}\big)^L - 1 \Big] \right) + \frac{\lambda}{2} t^2 \qquad\text{where}\qquad \eta^\circ = \frac{n_c}{n_c - 1} \frac{1}{\sqrt{KL/d}}.$$
Assume $z_1, \ldots, z_K$ are mutually distinct and satisfy assumption 1. Then any (W, U) in a type-II collapse configuration with constant $c' = \tau / \sqrt{KL}$ is a global minimizer of (9).

As for theorem 1, we prove the converse under an additional technical assumption on the latent variables: any (W, U) that minimizes (9) must be in a type-II collapse configuration with $c' = \tau / \sqrt{KL}$. The proof and exact statement can be found in section F of the appendix.

4 ADDITIONAL EXPERIMENTS

Our theoretical investigation of feature collapse uses a simple synthetic data model and a basic network. These simplifications allow us to rigorously prove that feature collapse occurs in this setting. The first section of the appendix provides preliminary evidence that the feature collapse phenomenon also occurs in more complex settings, which are beyond the reach of our current analytical tools.
In particular, we experimentally observe feature collapse in more complex data models that involve a deeper hierarchy of latent structures. We also investigate the feature collapse phenomenon in transformer architectures, in both a classification setup and the usual next-word-prediction setup.

ACKNOWLEDGMENT

Xavier Bresson is supported by NUS Grant ID R-252-000-B97-133. The authors would like to express their gratitude to the reviewers for their feedback, which has improved the clarity and contribution of the paper.

REFERENCES

Emmanuel Abbe, Enric Boix Adsera, and Theodor Misiakiewicz. The merged-staircase property: a necessary and nearly sufficient condition for sgd learning of sparse functions on two-layer neural networks. In Conference on Learning Theory, pp. 4782–4887. PMLR, 2022.

Zeyuan Allen-Zhu and Yuanzhi Li. Physics of language models: Part 1, context-free grammar. arXiv preprint arXiv:2305.13673, 2023.

Jimmy Ba, Murat A Erdogdu, Taiji Suzuki, Zhichao Wang, Denny Wu, and Greg Yang. High-dimensional asymptotics of feature learning: How one gradient step improves the representation. Advances in Neural Information Processing Systems, 35:37932–37946, 2022.

Alberto Bietti, Joan Bruna, Clayton Sanford, and Min Jae Song. Learning single-index models with shallow neural networks. Advances in Neural Information Processing Systems, 35:9768–9783, 2022.

Noam Chomsky. Three models for the description of language. IRE Transactions on Information Theory, 2(3):113–124, 1956.

Alexandru Damian, Jason Lee, and Mahdi Soltanolkotabi. Neural networks can learn representations with gradient descent. In Conference on Learning Theory, pp. 5413–5452. PMLR, 2022.

Cong Fang, Hangfeng He, Qi Long, and Weijie J Su. Exploring deep neural networks via layer-peeled model: Minority collapse in imbalanced training. Proceedings of the National Academy of Sciences, 118(43):e2103091118, 2021.

Vitaly Feldman. Does learning require memorization? A short tale about a long tail. In Proceedings of the 52nd Annual ACM SIGACT Symposium on Theory of Computing, pp. 954–959, 2020.

Vitaly Feldman and Chiyuan Zhang. What neural networks memorize and why: Discovering the long tail via influence estimation. Advances in Neural Information Processing Systems, 33:2881–2891, 2020.

Wenlong Ji, Yiping Lu, Yiliang Zhang, Zhun Deng, and Weijie J Su. An unconstrained layer-peeled perspective on neural collapse. arXiv preprint arXiv:2110.02796, 2021.

Yoon Kim, Chris Dyer, and Alexander M Rush. Compound probabilistic context-free grammars for grammar induction. arXiv preprint arXiv:1906.10225, 2019.

Thomas Laurent and James von Brecht. Deep linear networks with arbitrary loss: All local minima are global. In International Conference on Machine Learning, pp. 2902–2907. PMLR, 2018.

Hong Liu, Sang Michael Xie, Zhiyuan Li, and Tengyu Ma. Same pre-training loss, better downstream: Implicit bias matters for language models. In International Conference on Machine Learning, pp. 22188–22214. PMLR, 2023.

Ziwei Liu, Zhongqi Miao, Xiaohang Zhan, Jiayun Wang, Boqing Gong, and Stella X Yu. Large-scale long-tailed recognition in an open world. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 2537–2546, 2019.

Jianfeng Lu and Stefan Steinerberger. Neural collapse with cross-entropy loss. arXiv preprint arXiv:2012.08465, 2020.

Dustin G Mixon, Hans Parshall, and Jianzong Pi. Neural collapse with unconstrained features. arXiv preprint arXiv:2011.11619, 2020.
Alireza Mousavi-Hosseini, Sejun Park, Manuela Girotti, Ioannis Mitliagkas, and Murat A Erdogdu. Neural networks efficiently learn low-dimensional representations with sgd. arXiv preprint arXiv:2209.14863, 2022.

Vardan Papyan, XY Han, and David L Donoho. Prevalence of neural collapse during the terminal phase of deep learning training. Proceedings of the National Academy of Sciences, 117(40):24652–24663, 2020.

Suzanna Parkinson, Greg Ongie, and Rebecca Willett. Linear neural network layers promote learning single- and multiple-index models. arXiv preprint arXiv:2305.15598, 2023.

Alec Radford, Jeffrey Wu, Rewon Child, David Luan, Dario Amodei, Ilya Sutskever, et al. Language models are unsupervised multitask learners. OpenAI blog, 1(8):9, 2019.

Ruslan Salakhutdinov, Antonio Torralba, and Josh Tenenbaum. Learning to share visual appearance for multiclass object detection. In CVPR 2011, pp. 1481–1488. IEEE, 2011.

Tom Tirer and Joan Bruna. Extended unconstrained features model for exploring deep neural collapse. In International Conference on Machine Learning, pp. 21478–21505. PMLR, 2022.

Stephan Wojtowytsch et al. On the emergence of simplex symmetry in the final and penultimate layers of neural network classifiers. arXiv preprint arXiv:2012.05420, 2020.

Jinxin Zhou, Xiao Li, Tianyu Ding, Chong You, Qing Qu, and Zhihui Zhu. On the optimization landscape of neural collapse under mse loss: Global optimality with unconstrained features. In International Conference on Machine Learning, pp. 27179–27202. PMLR, 2022.

Xiangxin Zhu, Dragomir Anguelov, and Deva Ramanan. Capturing long-tail distributions of object subcategories. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 915–922, 2014.

Zhihui Zhu, Tianyu Ding, Jinxin Zhou, Xiao Li, Chong You, Jeremias Sulam, and Qing Qu. A geometric analysis of neural collapse with unconstrained features. Advances in Neural Information Processing Systems, 34:29820–29834, 2021.

George K Zipf. The psycho-biology of language. 1935.

In section A we conduct an empirical investigation of the feature collapse phenomenon in settings beyond the reach of our current analytical tools. The remaining sections are all devoted to the proofs of the three theorems that constitute the main results of the paper. Section B provides formulas for the networks $h_{W,U}$ and $h^\circ_{W,U}$ depicted in figure 2 of the main paper, and a formula for the distribution $D_{z_k} : X \to [0, 1]$ underlying the data model depicted in figure 1 of the main paper. We also use this section to introduce various notations that our proofs will rely on. Section C is devoted to the symmetry assumptions that we impose on the latent variables. We start with an in-depth discussion of assumption 1 from the main paper, which is required for theorems 1 and 3 to hold. We then present and discuss an additional technical assumption on the latent variables (c.f. assumption B) that we will use to prove the converse of theorems 1 and 3. Whereas sections B and C are essentially devoted to notations and discussions, most of the analysis occurs in sections D, E, F and G. We start by deriving a sharp lower bound for the unregularized risk in section D. Theorem 1 from the main paper, as well as its converse, are proven in section E. Theorem 3 and its converse are proven in section F. Finally we prove theorem 2 in section G.
We conclude this appendix by proving in section H that if d > min(nw, KL), then the risk associated to the network h W,U does not have spurious local minimizers; all local minimizers are global. This proof follows the same strategy that was used in Zhu et al. (2021). A FURTHER EMPIRICAL INVESTIGATIONS In this section, show empirically that the feature collapse phenomenon is not limited to the simple controlled setting where we were able to prove it. In particular, we show that feature collapse occurs in the absence of weight decay and when the Layer Norm has learnable parameters. We also show that feature collapse occurs in transformer architectures, and also when the classification task is replaced with a language modeling task. Finally, we show that feature collapse occurs in data models involving a deeper hierarchy of latent structures, such as a Context Free Grammar (CFG). A.1 EXPERIMENTS WITHOUT WEIGHT DECAY We start by reproducing the experiments depicted in Figure 3 and 4 but without weight decay. Specifically, we set λ = 0 in equations (2) and (3). All other parameters defining the networks and data model remain the same. To train the networks, we perform 5 million iterations of stochastic gradient descent with a batch size of 100 and a learning rate of 0.1. After training, the empirical losses for all networks are below 10 4. The outcomes of these experiments are depicted in Figures 5 and 6. These figures are virtually identical to Figures 3 and 4 in the main paper. In other words, the absence or presence of weight decay does not affect our main qualitative findings. (a) h trained on the large training set. Test acc. = 100% (b) h trained on the small training set. Test acc. = 66% (c) h trained on the small training set. Test acc. = 100% Figure 6: Same experiments as in Figure 4 but with no weight decay. Note that the result are qualitatively similar. Published as a conference paper at ICLR 2024 A.2 LAYERNORM WITH LEARNABLE PARAMETERS Figure 7: Same experiment as in Figure 4(c) but with learnable weight in the Layer Norm When the Layer Norm module has no learnable weights, the word embeddings must lie on a sphere of constant radius. This constraint aids collapse. A natural question is whether collapse still occurs when the Layer Norm has learnable parameters. To answer this question, we reproduce the experiment that corresponds to Figure 4(c) but allow the Layer Norm module to have learnable weights. The result of this experiment is depicted on Figure 7, and one can clearly observe that feature collapse does occur in this setting as well. A.3 CLASSIFICATION EXPERIMENTS WITH TRANSFORMERS In this set of experiments, we train a transformer on the classification task depicted on figure 1. The transformer has 2 layers, 8 heads, 512 dimensions, and we use absolute positional embeddings (as in GPT-2). A classification token is appended to each input sentence, and this classification token is used in the last layer to predict the category. The network is trained with Adam W (constant learning rate of 10 4, weight decay of 0.1, β1 = 0.9 and β2 = 0.95) during 3 epochs on a training sets containing 0.5 million sentences. For the data model, we use nc = 3, nw = 1200, L = 15, K = 1000 as in the main paper. In figure 8(a) and 8 (b) we display the word embeddings via dimensionality reduction and color coding. These are the word embeddings obtained before addition of the positional embeddings, and before going through the first transformer layer. 
Figure 8(a) corresponds to the case in which the words are uniformly distributed, and Figure 8(b) corresponds to the long-tail case. We observe that the word embeddings, in both the uniform and long-tail case, are properly collapsed.

A.4 LANGUAGE MODELING EXPERIMENTS WITH TRANSFORMERS

In this set of experiments, we train a transformer to predict the next token on sentences generated by the data model depicted in Figure 1. We use the GPT-2 architecture (Radford et al., 2019) with 2 layers, 8 heads, and 512 dimensions. The training set contains 1 million sentences generated by our data model with parameters $n_c = 3$, $n_w = 1200$, L = 15, K = 1000, and with uniform word distributions. We perform a single epoch through the training set and use AdamW with the same parameters as above. In Figure 8(c) we display the word embeddings via dimensionality reduction and color coding, and we observe that they are properly collapsed.

Figure 8: Experiments with transformers. (a) Classification task (uniform word distribution). (b) Classification task (long-tail word distribution). (c) Language modeling task (uniform word distribution).

A.5 EXPERIMENTS WITH CONTEXT FREE GRAMMAR

The data model presented in the paper extends to one with a deeper hierarchy of latent structures. Recall that words are partitioned into concepts and that the latent variables are sequences of concepts. We can further partition the latent variables into "meta-concepts" and create deeper latent variables that are sequences of meta-concepts. We can iterate this process to obtain a hierarchy of any depth. Such a data model is a particular instance of a Context Free Grammar (Chomsky, 1956), which generates sentences with probabilistic trees; such grammars are widely used to understand natural language models (e.g. Kim et al. (2019); Allen-Zhu & Li (2023); Liu et al. (2023)). In Figure 9 we provide an illustration of a simple depth-3 context free grammar.

We ran experiments with a context free grammar of depth 4, meaning that we have words, concepts, meta-concepts and meta-meta-concepts. We used a deep neural network with ReLU nonlinearities and a Layer Norm module at each layer. The architecture of the neural network was chosen to match that of the context free grammar, see Figure 10. In Figure 11 we plot the activations after each of the three hidden layers and readily observe the expected feature collapse phenomenon. All segments of the input sentence that correspond to the same concept, meta-concept, or meta-meta-concept receive the same representations in the appropriate layer of the network (layer 1 for concepts, layer 2 for meta-concepts, and layer 3 for meta-meta-concepts). This shows that the feature collapse phenomenon is a general one.

Details of the Context Free Grammar: We used the following context free grammar for our experiment. We choose K = 100 categories. Each category generates a sequence of meta-meta-concepts of length 8 by choosing uniformly at random among 5 possible sequences of meta-meta-concepts. Each meta-meta-concept then generates a sequence of meta-concepts of length 8, again by choosing uniformly at random among 5 possible sequences of meta-concepts. Each meta-concept then generates a sequence of concepts of length 8 by choosing among 5 possible sequences of concepts. Finally, each concept generates a sequence of words of length 8 by choosing among 5 possible sequences of words. At level 0 the sequences of words have an overall length of $8^4 = 4096$.
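A minimal sketch of this depth-4 generative process is given below. The alphabet sizes at the intermediate levels are not specified in the paper, so the values used here are assumptions, and all names are ours.

```python
import numpy as np

rng = np.random.default_rng(0)

# Depth-4 grammar: class -> meta-meta-concepts -> meta-concepts -> concepts -> words.
# Every token expands into branch = 8 tokens of the level below, chosen uniformly
# among n_rules = 5 possible expansions.
K, branch, n_rules = 100, 8, 5
alphabet = [K, 30, 30, 30, 100]   # sizes per level; intermediate sizes are assumed

# rules[lvl][token] = (n_rules, branch) table of possible expansions into level lvl+1.
rules = [
    rng.integers(0, alphabet[lvl + 1], size=(alphabet[lvl], n_rules, branch))
    for lvl in range(4)
]

def expand(tokens, lvl):
    """Expand a sequence of level-lvl tokens one level down."""
    out = []
    for t in tokens:
        choice = rng.integers(n_rules)
        out.extend(rules[lvl][t, choice])
    return out

def sample_sentence(k):
    seq = [k]                      # level 0: the class index
    for lvl in range(4):
        seq = expand(seq, lvl)
    return np.array(seq)           # 8**4 = 4096 word tokens

x = sample_sentence(k=7)
print(x.shape)                     # (4096,)
```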
Figure 9: Probabilistic Context Free Grammar of depth 3. The class index (in {1, 2, ..., K}) generates a sequence of meta-concepts. Each meta-concept further generates a sequence of concepts. Finally, each concept generates a sequence of words. The process by which a token from one level generates a sequence of tokens in the level below is random. For example, the meta-concept γ2 = 5 might generate the sequence of concepts [β4, β5, β6] = [4, 1, 3] with probability 1/3, the sequence [β4, β5, β6] = [2, 5, 3] with probability 1/3, and the sequence [β4, β5, β6] = [3, 5, 5] with probability 1/3.

Figure 10: Neural network architecture matching the context free grammar from Figure 9. Each block consists of an MLP followed by layer normalization. The vectors are concatenated before being fed to a block. If feature collapse occurs, then two sequences of words [α1, α2, α3] generated by the same concept β1 should have almost identical representations b1 in level 1 of the network. Similarly, two sequences of words [α1, α2, ..., α9] generated by the same meta-concept γ1 should have almost identical representations c1 in level 2 of the network.

Figure 11: Results of an experiment run on a context free grammar of depth 4 (so deeper than the one depicted on Figure 1). At every level, each token generates in the level below a sequence of length 8 chosen uniformly at random among 5 possible sequences. After training the network, we generate 1000 test sequences, feed them to the network, and visualize via PCA the representations obtained at each layer (i.e. we plot the vectors bi, ci and di), color-coded according to the concept (level 1), meta-concept (level 2), and meta-meta-concept (level 3) that generated them; we keep only the representations corresponding to the first 3 concepts, meta-concepts, and meta-meta-concepts. We clearly observe the feature collapse phenomenon. Note that since we are dealing with a network of depth 4, the vectors di are not the output, but the last hidden representations.

B PRELIMINARIES AND NOTATIONS

B.1 FORMULA FOR THE NEURAL NETWORKS

Recall that the vocabulary is the set $V = \{(\alpha, \beta) \in \mathbb{N}^2 : 1 \le \alpha \le n_c \text{ and } 1 \le \beta \le s_c\}$, and that we think of the tuple $(\alpha, \beta) \in V$ as representing the β-th word of the α-th concept. The data space is $X = V^L$, and a sentence $x \in X$ is a sequence of L words:
$$x = [(\alpha_1, \beta_1), \ldots, (\alpha_L, \beta_L)] \qquad\text{with } 1 \le \alpha_\ell \le n_c \text{ and } 1 \le \beta_\ell \le s_c.$$
The two neural networks $h$, $h^\circ$ studied in this work process such a sentence $x \in X$ in multiple steps:
1. Each word $(\alpha_\ell, \beta_\ell)$ of the sentence is encoded into a one-hot vector.
2. These one-hot vectors are multiplied by a matrix W to produce word embeddings that live in a d-dimensional space.
3. Optionally (i.e.
in the case of the network $h^\circ$), these word embeddings go through a Layer Norm module without learnable parameters.
4. The word embeddings are concatenated and then go through a linear transformation U.
We now formalize these 4 steps, and in the process we set the notations on which we will rely in all our proofs.

Step 1: One-hot encoding. Without loss of generality, we choose the following one-hot encoding scheme: word $(\alpha, \beta) \in V$ receives the one-hot vector which has a 1 in entry $(\alpha - 1)s_c + \beta$ and 0 everywhere else. To formalize this, we define the one-hot encoding function
$$\zeta(\alpha, \beta) = e_{(\alpha - 1)s_c + \beta} \qquad (10)$$
where $e_i$ denotes the i-th basis vector of $\mathbb{R}^{n_w}$. The one-hot encoding function ζ can also be applied to a sequence of words. Given a sentence $x = [(\alpha_1, \beta_1), \ldots, (\alpha_L, \beta_L)] \in X$ we let
$$\zeta(x) := \big[\, \zeta(\alpha_1, \beta_1) \ \ \zeta(\alpha_2, \beta_2) \ \cdots \ \zeta(\alpha_L, \beta_L) \,\big] \qquad (11)$$
and so ζ maps sentences to $n_w \times L$ matrices.

Step 2: Embedding. The embedding matrix W has $n_w$ columns and each of these columns belongs to $\mathbb{R}^d$. Since ζ(α, β) denotes the one-hot vector associated with word $(\alpha, \beta) \in V$, we define the embedding of word (α, β) by
$$w_{(\alpha,\beta)} := W \zeta(\alpha, \beta) \in \mathbb{R}^d. \qquad (12)$$
Due to (10), this means that $w_{(\alpha,\beta)}$ is the j-th column of W, where $j = (\alpha - 1)s_c + \beta$. The embedding matrix W can therefore be visualized as follows (for concreteness we choose $n_c = 3$ and $n_w = 12$ as in figure 1 of the main paper):
$$W = \big[\, \underbrace{w_{(1,1)} \ w_{(1,2)} \ w_{(1,3)} \ w_{(1,4)}}_{\text{embeddings of the words in the 1st concept}} \ \ \underbrace{w_{(2,1)} \ w_{(2,2)} \ w_{(2,3)} \ w_{(2,4)}}_{\text{2nd concept}} \ \ \underbrace{w_{(3,1)} \ w_{(3,2)} \ w_{(3,3)} \ w_{(3,4)}}_{\text{3rd concept}} \,\big].$$
Given a sentence $x = [(\alpha_1, \beta_1), \ldots, (\alpha_L, \beta_L)] \in X$, appealing to (11) and (12), we find that
$$W\zeta(x) = \big[\, w_{(\alpha_1,\beta_1)} \ \ w_{(\alpha_2,\beta_2)} \ \cdots \ w_{(\alpha_L,\beta_L)} \,\big] \qquad (13)$$
and therefore Wζ(x) is the matrix that contains the d-dimensional embeddings of the words that constitute the sentence $x \in X$.

Step 3: Layer Norm. Recall from the main paper that the Layer Norm function $\varphi : \mathbb{R}^d \to \mathbb{R}^d$ is defined by
$$\varphi(v) = \frac{v - \mathrm{mean}(v)\,\mathbf{1}_d}{\sigma(v)} \quad\text{where}\quad \mathrm{mean}(v) = \frac{1}{d}\sum_{i=1}^{d} v_i \quad\text{and}\quad \sigma^2(v) = \frac{1}{d}\sum_{i=1}^{d}\big(v_i - \mathrm{mean}(v)\big)^2.$$
We will often apply this function column-wise to a matrix: if V is the $d \times m$ matrix $[\, v_1 \ v_2 \ \cdots \ v_m \,]$, then $\varphi(V) = [\, \varphi(v_1) \ \varphi(v_2) \ \cdots \ \varphi(v_m) \,]$. Applying φ to (13) gives
$$\varphi\big(W\zeta(x)\big) = \big[\, \varphi(w_{(\alpha_1,\beta_1)}) \ \ \varphi(w_{(\alpha_2,\beta_2)}) \ \cdots \ \varphi(w_{(\alpha_L,\beta_L)}) \,\big]$$
and so φ(Wζ(x)) contains the word representations of the words from the input sentence (recall that by word representations we mean the word embeddings after the Layer Norm).

Step 4: Linear Transformation. Recall from the main paper that
$$U = \begin{bmatrix} u_{1,1} & u_{1,2} & \cdots & u_{1,L} \\ u_{2,1} & u_{2,2} & \cdots & u_{2,L} \\ \vdots & \vdots & & \vdots \\ u_{K,1} & u_{K,2} & \cdots & u_{K,L} \end{bmatrix}$$
where each vector $u_{k,\ell}$ belongs to $\mathbb{R}^d$. The neural networks $h_{W,U}$ and $h^\circ_{W,U}$ are then given by the formulas
$$h_{W,U}(x) = U\, \mathrm{Vec}\big[ W\zeta(x) \big] \qquad (16)$$
$$h^\circ_{W,U}(x) = U\, \mathrm{Vec}\big[ \varphi\big(W\zeta(x)\big) \big] \qquad (17)$$
where $\mathrm{Vec} : \mathbb{R}^{d \times L} \to \mathbb{R}^{dL}$ is the function that takes as input a $d \times L$ matrix and flattens it out into a vector with dL entries (with the first column filling the first d entries of the vector, the second column filling the next d entries, and so forth). It will prove convenient to gather the L vectors $u_{k,\ell}$ that constitute the k-th row of U into the matrix
$$\hat{U}_k := \big[\, u_{k,1} \ \ u_{k,2} \ \cdots \ u_{k,L} \,\big] \in \mathbb{R}^{d \times L}.$$
With this notation, we have the following alternative expressions for the networks $h_{W,U}$ and $h^\circ_{W,U}$:
$$h_{W,U}(x) = \begin{bmatrix} \langle \hat{U}_1,\, W\zeta(x) \rangle_F \\ \langle \hat{U}_2,\, W\zeta(x) \rangle_F \\ \vdots \\ \langle \hat{U}_K,\, W\zeta(x) \rangle_F \end{bmatrix} \quad\text{and}\quad h^\circ_{W,U}(x) = \begin{bmatrix} \langle \hat{U}_1,\, \varphi(W\zeta(x)) \rangle_F \\ \langle \hat{U}_2,\, \varphi(W\zeta(x)) \rangle_F \\ \vdots \\ \langle \hat{U}_K,\, \varphi(W\zeta(x)) \rangle_F \end{bmatrix}$$
where $\langle \cdot, \cdot \rangle_F$ denotes the Frobenius inner product between matrices (see next subsection for a definition).
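A small numpy sketch that instantiates these definitions and checks that the concatenated form (16) agrees with the Frobenius-inner-product form above; the dimensions are tiny and arbitrary, and all names are ours.

```python
import numpy as np

rng = np.random.default_rng(0)
nc, sc, d, L, K = 3, 4, 6, 5, 2
nw = nc * sc

def zeta(alpha, beta):
    """One-hot encoding of word (alpha, beta): entry (alpha-1)*sc + beta
    in the paper's 1-based convention (0-based here)."""
    e = np.zeros(nw)
    e[(alpha - 1) * sc + (beta - 1)] = 1.0
    return e

W = rng.standard_normal((d, nw))               # embedding matrix
U = rng.standard_normal((K, L * d))            # linear head

x = [(1, 2), (3, 1), (2, 4), (1, 1), (2, 2)]   # a sentence of L = 5 words
Zx = np.stack([zeta(a, b) for a, b in x], axis=1)   # nw x L matrix zeta(x)

# h_{W,U}(x) = U Vec[W zeta(x)]: columns of W zeta(x) stacked into one vector.
scores = U @ (W @ Zx).T.flatten()              # column-major flattening, length L*d

# Equivalent form: the k-th score is <U_hat_k, W zeta(x)>_F.
U_hat = U.reshape(K, L, d).transpose(0, 2, 1)  # U_hat[k] is the d x L matrix [u_k1 ... u_kL]
scores_frob = np.array([np.sum(U_hat[k] * (W @ Zx)) for k in range(K)])
print(np.allclose(scores, scores_frob))        # True
```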
Finally, we use ˆU to denote the matrix obtained by concatenating the matrices ˆU1, . . . , ˆUK, that is ˆU := ˆU1 ˆU2 ˆUK Rd KL (20) The matrix ˆU, which is nothing but a reshaped version of the original weight matrix U RK Ld, will play a crucial role in our analysis. B.2 BASIC PROPERTIES OF THE FROBENIUS INNER PRODUCT We recall that the Frobenius inner product between two matrices A, B Rm n is defined by j=1 Aij Bij and that the Frobenius norm of a matrix A Rm n is given by A F = p A, A F . In the course of our proofs, we will constantly appeal to the following property of the Frobenius inner product, so we state it in a lemma once and for all. Lemma A. Suppose A Rm n, B Rm r and C Rr n. Then A, BC F = BT A, C F and A, BC F = ACT , B Proof. The Frobenius inner product can be expressed as A, B F = Tr(AT B), and so we have A, BC F = Tr(AT BC) = Tr BT A T C = BT A, C Using the cyclic property of the trace, we also get A, BC F = Tr(AT BC) = Tr(CAT B) = Tr ACT T B = ACT , B B.3 THE TASK, THE DATA MODEL, AND THE DISTRIBUTION Dzk Recall that C = {1, . . . , nc} represents the set of concepts, and that Z = CL is the latent space. We aim to study a classification task in which the K classes are defined by K latent variables z1, . . . , zk Z Published as a conference paper at ICLR 2024 We write x Dzk to indicate that the sentence x X is generated by the latent variable zk Z (see figure 1 of the main paper for a visual illustration). Formally, Dzk is a probability distribution on the data space X, and we now give the formula for its p.d.f. First, recall that µβ > 0 stands for the probability of sampling the βth word of the αth concept. Let us denote the kth latent variable by zk = [ zk,1 , zk,2 , . . . , zk,L ] Z where 1 zk,ℓ nc. The probability of sampling the sentence x = [ (α1, β1) , (α2, β2) . . . , (αL, βL) ] X according to Dzk is then given by the formula Dzk ({x}) = ℓ=1 1{αℓ=zk,ℓ} µβℓ Note that Dzk ({x}) > 0 if and only if [zk,1, . . . , zk,L] = [α1, . . . , αL]. So a sentence x has a nonzero probability of being generated by the latent variable zk only if its words match the concepts in zk. If this is the case, then the probability of sampling x according to Dzk is simply given by the product of the frequencies of the words contained in x. We use Xk to denote the support of the distribution Dzk, that is Xk := {x X : Dzk(x) > 0} and we note that if the latent variables z1, . . . , z K are mutually distinct, then Xj Xk = for all j = k. Since the K latent variables define the K classes of our classification problem, we may alternatively define Xk by Xk = {x X : x belongs to the kth category} To each latent variable zk = [ zk,1 , zk,2 , . . . , zk,L ] we associate a matrix " | | | ezk,1 ezk,2 ezk,L | | | In other words, the matrix Zk provides a one-hot representation of the concepts contained in the latent variable zk. Concatenating the matrices Z1, . . . , ZK gives the matrix Z = [Z1 Z2 ZK] Rnc KL (22) which is reminiscent of the matrix ˆU defined by (20). We encode the way words are partitioned into concepts into a partition matrix P Rnc nw. For example, if we have 12 words and 3 concepts, then the partition matrix is "1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 Rnc nw, (23) indicating that the first 4 words belong to concept 1, the next 4 words belongs to concept 2, and so forth. Formally, recalling that ζ(α, β) is the the one-hot encoding of word (α, β) V, the matrix P is defined the relationship P ζ(α, β) = eα for all (α, β) V. 
(24) Importantly, note that the matrix P maps datapoints to their associated latent variables. Indeed, if x = [(α1, β1), . . . , (αL, βL)] is generated by the latent variable zk (meaning that x Xk), then we have that " | | | ζ(α1, β1) ζ(α2, β2) . . . ζ(αL, βL) | | | " | | | eα1 eα2 . . . eαL | | | where the last equality is due to definition (21) of the matrix Zk. Another important matrix for our analysis will be the matrix Q Rnc nw. In the concrete case where we have 12 words and 3 concepts, this matrix takes the form "µ1 µ2 µ3 µ4 0 0 0 0 0 0 0 0 0 0 0 0 µ1 µ2 µ3 µ4 0 0 0 0 0 0 0 0 0 0 0 0 µ1 µ2 µ3 µ4 Rnc nw (26) and, in general, it is defined by the relationship Q ζ(α, β) = µβ eα for all (α, β) V. (27) Published as a conference paper at ICLR 2024 C SYMMETRY ASSUMPTIONS ON THE LATENT VARIABLES In subsection C.1 we provide an in depth discussion of the symmetry assumption required for theorems 1 and 3 to hold. In subsection C.2 we present and discuss the assumption that will be needed to prove the converse of theorems 1 and 3. C.1 SYMMETRY ASSUMPTION NEEDED FOR THEOREM 1 AND 3 To better understand the symmetry assumption 1 from the main paper, let us start by considering the extreme case K = n L c and {z1, z2, . . . , z K} = Z, (28) meaning that z1, . . . , z K are mutually distinct and represent all possible latent variables in Z. In this case, we easily obtain the formula n j [K] : dist(zj, z1) = r and zj,L = z1,L o = L 1 r (nc 1)r (29) where dist(zj, z1) is the Hamming distance between the latent variables zj and z1. To see this, note that the left side of (29) counts the number of latent variables zj that differs from z1 at r locations and agrees with z1 at the last location ℓ= L. This number is clearly equal to the right side of (29) since we need to choose r positions out of the first L 1 positions, and then, for each chosen position ℓ, we need to choose a concept out of the nc 1 concepts that differs from z1,ℓ. A similar reasoning shows that, if z1,L = α, then n j [K] : dist(zj, z1) = r and zj,L = α o = L 1 r 1 (nc 1)r 1 (30) where the term L 1 r 1 arises from the fact that we only need to choose r 1 positions, since z1 and zj differ in their last position ℓ= L. Suppose now that the random variables z1, . . . , z K are selected uniformly at random from Z, and say, for the sake of concreteness, that so that z1, . . . , z K represent 20% of all possible latent variables (note that |Z| = n L c ). Then (29) (30) should be replaced by n j [K] : dist(zj, z1) = r and zj,L = z1,L o 1 (nc 1)r (31) n j [K] : dist(zj, z1) = r and zj,L = α o 1 (nc 1)r 1 for α = z1,L (32) where the equality only holds approximatively due to the random choice of the latent variables. In the above example, we chose z1 as our reference latent variables and we froze the concept appearing in position ℓ= L. These choices were clearly arbitrary. In general, when K is large, we have n j [K] : dist(zj, zk) = r and zj,ℓ= α o K n L c L 1 r (nc 1)r if zk,ℓ= α K n L c L 1 r 1 (nc 1)r 1 if zk,ℓ = α (33) and this approximate equality hold for most k [K], r [L], ℓ [L], and α [nc]. The symmetry assumption 1 from the main paper requires (33) to hold not approximatively, but exactly. For convenience we restate below this symmetry assumption: Assumption A (Latent Symmetry). 
For every k [K], r [L], ℓ [L], and α [nc] the identities n j [K] : dist(zj, zk) = r and zj,ℓ= α o = K n L c L 1 r (nc 1)r if zk,ℓ= α K n L c L 1 r 1 (nc 1)r 1 if zk,ℓ = α (34) Published as a conference paper at ICLR 2024 To be clear, if the latent variables z1, . . . , z K are selected uniformly at random from Z, then they will only approximatively satisfy assumption A. Our analysis, however, is conducted in the idealized case where the latent variables exactly satisfy the symmetry assumption. Specifically, we show that, in the idealized case where assumption A is exactly satisfied, then the weights W and U of the network are given by some explicit analytical formula. Importantly, as it is explained in the main paper, our experiments demonstrate that these idealized analytical formula provide very good approximations for the weights observed in experiments when the latent variables are selected uniformly at random. In the next lemma, we isolate three properties which hold for any latent variables satisfying assumption A. Importantly, when proving collapse, we will only rely on these three properties we will never explicitly need assumption A. We will see shortly that these three properties, in essence, amount to saying that all position ℓ [L] and all concepts α [nc] plays interchangeable roles for the latent variables. There are no preferred ℓor α, and this is exactly what will allow us to derive symmetric analytical solutions. Before stating our lemma, let us define the sphere of radius r centered around the kth latent variable Sr(k) := n j [K] : dist(zj, zk) = r o for r, k [L] (35) With this notation in hand we may now state Lemma B. Suppose the latent variables z1, . . . , z K satisfy the symmetry assumption A. Then z1, . . . , z K satisfies the following properties: (i) |Sr(j)| = |Sr(k)| for all r [L] and all j, k [K]. (ii) The equalities K X nc 1nc1T L and ZZT = KL hold, with Inc denoting the nc nc identity matrix. (iii) There exists θ1, . . . , θL > 0 and matrices A1, . . . , AL Rnc L such that Zk 1 |Sr(k)| j Sr(k) Zj = θr Zk + Ar holds for all r [L], all j [K], and all k [K]. We will prove this lemma shortly, but for now let us start by getting some intuition about properties (i), (ii) and (iii). Property (i) is transparent: it states that all latent variables have the same number of distance-r neighbors . Recalling how matrix Zk was defined (c.f. (21)), we see that the first identity of (ii) is equivalent to |{k [K] : zk,ℓ= α}| = K nc for all ℓ [L] and all α [nc]. (36) This means that the number of latent variables that have concept α in position ℓis equal to K/nc. In other words, each concept is equally represented at each position ℓ. We now turn to the second identity of statement (ii). Recalling the definition (22) of matrix Z, we see that ZZT Rnc nc is a diagonal matrix since each column of Z contains a single nonzero entry. One can also easily see that the αth entry of the diagonal is ZZT α,α = |{(k, ℓ) [K] [L] : zk,ℓ= α}|, which is the total number of times concept α appears in the latent variables. Overall, the identity ZZT = KL nc Inc is therefore equivalent to the statement |{(k, ℓ) [K] [L] : zk,ℓ= α}| = KL nc for all α [nc] and it is therefore a direct consequence of (36). Property (iii) is harder to interpret. Essentially it is a type of mean value property that states that summing over the latent variables which are at distance r of zk gives back zk. We will see that this mean value property plays a key role in our analysis. 
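Properties (i) and (ii) are easy to verify numerically in the extreme case K = nc^L, where the latent variables enumerate all of Z and assumption A holds exactly. The short sketch below does so; the sizes nc = 3, L = 4 and the way the check is organized are our illustrative choices.

```python
import itertools
import numpy as np

# Sanity check of properties (i) and (ii) of lemma B when K = nc**L and the latent
# variables enumerate all of Z = C^L (so assumption A holds exactly). Sizes are illustrative.
nc, L = 3, 4
latents = list(itertools.product(range(nc), repeat=L))   # all of Z, hence K = nc**L
K = len(latents)

def Zmat(z):
    """One-hot matrix Z_k in R^{nc x L}, cf. (21), for a latent variable z (0-indexed)."""
    M = np.zeros((nc, L))
    for pos, alpha in enumerate(z):
        M[alpha, pos] = 1.0
    return M

def hamming(z, w):
    return sum(a != b for a, b in zip(z, w))

Zs = [Zmat(z) for z in latents]
Z = np.hstack(Zs)                                         # Z in R^{nc x KL}, cf. (22)

# property (ii): sum_k Z_k = (K/nc) 1 1^T  and  Z Z^T = (KL/nc) I
assert np.allclose(sum(Zs), (K / nc) * np.ones((nc, L)))
assert np.allclose(Z @ Z.T, (K * L / nc) * np.eye(nc))

# property (i): |S_r(k)| does not depend on k (checked on a few reference k's)
sizes = {k: [sum(1 for j in range(K) if hamming(latents[j], latents[k]) == r)
             for r in range(1, L + 1)]
         for k in range(0, K, max(1, K // 5))}
assert len({tuple(v) for v in sizes.values()}) == 1
print("|S_1|, ..., |S_L| =", next(iter(sizes.values())))
```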
To conclude this subsection, we prove lemma B. Published as a conference paper at ICLR 2024 Proof of lemma B. We start by proving statement (i). Since Sr(k) = {j [K] : dist(zj, zk) = r}, we clearly have that n j [K] : dist(zj, zk) = r and zj,ℓ= α o (37) We then use identity (34) and Pascal s rule to find |Sr(k)| = (nc 1) K (nc 1)r 1 + K |Z|(nc 1)r L 1 r 1 (nc 1)r (38) which clearly implies that |Sr(k)| = |Sr(j)| for all j, k [K] and all r [L]. We now turn to the first identity of t (ii). As previously mentioned, this identity is equivalent to (36). Choose k such that zk,ℓ = α. Then any any latent variable zj with zj,ℓ= α is at least at a distance 1 of zk and we may write |{j [K] : zj,ℓ= α}| = n j [K] : dist(zj, zk) = r and zj,ℓ= α o (39) (nc 1)r 1 (40) which is equal to K/nc according to the binomial theorem. The second identity of (ii), as mentioned earlier, is a direct consequence of the first identity. We finally turn to statement (iii). Appealing to (38), we find that, n j [K] : dist(zj, zk) = r and zj,ℓ= α o K |Z| L 1 r (nc 1)r K |Z| L r (nc 1)r = if zk,ℓ= α. On the other hand, if zk,ℓ = α, we obtain n j [K] : dist(zj, zk) = r and zj,ℓ= α o K |Z| L 1 r 1 (nc 1)r 1 K |Z| L r (nc 1)r = 1 nc 1 L r = 1 nc 1 r L Fix ℓ [L] and assume that zk,ℓ= α . We then have j Sr(k) ezj,ℓ= 1 |Sr(k)| n j Sr(k) : zj,ℓ= eα o eα n j [K] : dist(zj, zk) = r and zj,ℓ= α o L eα + 1 nc 1 r L L eα 1 nc 1 r L eα + 1 nc 1 r L = 1 nc nc 1 r L eα + 1 nc 1 r L1nc Recalling that zk,ℓ= α , the above implies that ezk,ℓ 1 |Sr(k)| j Sr(k) ezj,ℓ= nc nc 1 r L ezk,ℓ 1 nc 1 r L1nc (41) Published as a conference paper at ICLR 2024 Finally, recalling that " | | | ezk,1 ezk,2 ezk,L | | | we see that (41) can be written in matrix format as Zk 1 |Sr(k)| j Sr(k) Zj = nc nc 1 r LZk 1 nc 1 r L1nc1T L and therefore the scalars θr and the matrices Ar appearing in statement (iii) are given by the formula θr = nc nc 1 r L and Ar = 1 nc 1 r L1nc1T L. C.2 SYMMETRY ASSUMPTION NEEDED FOR THE CONVERSE OF THEOREM 1 AND 3 In this subsection we present the symmetry assumption that will be needed to prove the converse of theorem 1 and 3. This assumption, as we will shortly see, is quite mild and is typically satisfied even for small values of K. For each pair of latent variables (zj, zk) we define the matrix Γ(j,k) := Zj(Zj Zk)T Rnc nc. We also define A := n A Rnc nc : There exists a, b R s.t. A = a Inc + b1nc1T nc o (42) which is the set of matrices whose diagonal entries are equal to some constant and whose offdiagonal entries are equal to some possibly different constant. We may now state our symmetry assumption. Assumption B. Any positive semi-definite matrix A Rnc nc that satisfies D A , Γ(j,k) Γ(j ,k )E F = 0 j, k, j , k [K] s.t. dist(zj, zk) = dist(zj , zk ) (43) must belongs to A. Note that (43) can be viewed as a linear system of equations for the unknown A Rnc nc, with one equation for each quadruplet (j, k, j , k ) satisfying dist(zj, zk) = dist(zj , zk ). To put it differently, each quadruplet (j, k, j , k ) satisfying dist(zj, zk) = dist(zj , zk ) adds one equation to the system, and our assumption requires that we have enough of these equations so that all positive semi-definite solutions are constrained to live in the set A. Since a symmetric matrix has (nc + 1)nc/2 distinct entries, we would expect that (nc + 1)nc/2 quadruplets should be enough to fully determine the matrix. This number of quadruplets is easily achieved even for small values of K. So assumption B is quite mild. 
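For intuition about condition (43), the sketch below checks the easy direction numerically: any matrix A = a I + b 1 1^T in the set A automatically satisfies (43), because the inner product with Γ(j,k) depends only on dist(zj, zk) (it works out to a · dist(zj, zk)). Assumption B asks for the converse among positive semi-definite matrices. The sizes and the constants a, b are our illustrative choices.

```python
import itertools
import numpy as np

# Any A = a*I + b*11^T in the set A gives <A, Gamma^{(j,k)}>_F a function of
# dist(z_j, z_k) only; assumption B asks for the converse among PSD matrices.
nc, L = 3, 3
latents = list(itertools.product(range(nc), repeat=L))
a, b = 1.7, -0.4
A = a * np.eye(nc) + b * np.ones((nc, nc))

def Zmat(z):
    M = np.zeros((nc, L))
    for pos, alpha in enumerate(z):
        M[alpha, pos] = 1.0
    return M

def gamma(j, k):
    Zj, Zk = Zmat(latents[j]), Zmat(latents[k])
    return Zj @ (Zj - Zk).T                     # Gamma^{(j,k)} = Z_j (Z_j - Z_k)^T

def dist(j, k):
    return sum(u != v for u, v in zip(latents[j], latents[k]))

vals = {}
for j in range(len(latents)):
    for k in range(len(latents)):
        vals.setdefault(dist(j, k), set()).add(round(float(np.sum(A * gamma(j, k))), 8))
assert all(len(v) == 1 for v in vals.values())  # value depends only on the distance
print({r: vals[r].pop() for r in sorted(vals)}) # value at distance r equals a * r
```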
The next lemma states that assumption B is satisfied when K = n L c . In light of the above discussion this is not surprising, since the choice K = n L c leads to a system with a number of equations much larger than (nc + 1)nc/2. The proof, however, is instructive: it simply handpicks (nc + 1)nc/2 2 quadruplets to determine the entries of the matrix A. The 2 arises from the fact A is a 2 dimensional subspace, and therefore (nc + 1)nc/2 2 equations are enough to constrain A to be in A. Lemma C. Suppose K = n L c and {z1, . . . , z K} = Z. Then z1, . . . , z K satisfy the symmetry assumption B. Proof. Let A = CT C be a positive semi-definite matrix that solve satisfies (43). We use cα to denote the αth column of C. Since {z1, . . . , z K} = Z, we can find i, j, k [K] such that zi = [2, 1, 1, . . . , 1] Z zj = [3, 1, 1, . . . , 1] Z zk = [4, 1, 1, . . . , 1] Z Published as a conference paper at ICLR 2024 Using lemma A and recalling the definition (21) of the matrix Zk, we get D A , Γ(i,j)E F = CT C, Zi(Zi Zj)T F = C(Zi Zj), CZi F = CZi, CZi F CZj, CZi F = c2, c2 + (L 1) c1, c1 c2, c3 + (L 1) c1, c1 = c2, c2 c2, c3 Similarly we obtain that D A , Γ(i,k)E F = c2, c2 c2, c4 Since dist(zi, zj) = dist(zi, zk) = 1, and since A satisfies (43), we must have D A , Γ(i,j)E F = D A , Γ(i,k)E F which in turn implies that A2,3 = c2, c3 = c2, c4 = A2,4 This argument easily generalizes to show that all off-diagonal entries of the matrix A must be equal to some constant b R. We now take care of the diagonal entries. Since {z1, . . . , z K} = Z, we can find i , j , k [K] such that zi = [1, 1, . . . , 1] Z zj = [2, 2, . . . , 2] Z zk = [3, 3, . . . , 3] Z As before, we compute D A , Γ(i ,j )E F = CZi , CZi F CZj , CZi F = L c1, c1 L c1, c2 = L c1, c1 Lb where we have used the fact that the off diagonal entries are all equal to b. Similarly we obtain D A , Γ(j ,k )E F = L c2, c2 Lb Since dist(zi , zj ) = dist(z j, zk ) = L, we must have D A, Γ(i ,j )E F = D A, Γ(j ,k )E F which implies that A1,1 = A2,2. This argument generalizes to show that all diagonal entries of A are equal. D SHARP LOWER BOUND ON THE UNREGULARIZED RISK In this section we derive a sharp lower bound for the unregularized risk associated with the network h W,U, R0(W, U) := 1 k=1 E x Dzk h ℓ(h W,U(x), k) i , (44) where ℓ: RK R is the cross entropy loss ℓ(y, k) = log exp (yk) PK j=1 exp (yj) The kth entry of the output y = h W,U(x) of the neural network, according to formula (19), is given by yk = D ˆUk , W ζ(x) E Published as a conference paper at ICLR 2024 Recalling that Xk is the support of the distribution Dzk : X [0, 1], we find that the unregularized risk can be expressed as R0(W, U) = 1 x Xk ℓ(h W,U(x), k) Dzk(x) e ˆUk,W ζ(x) F PK j=1 e ˆUj,W ζ(x) F j =k e ˆUk ˆUj,W ζ(x) F where we did the slight abuse of notation of writing Dzk(x) instead of Dzk({x}). Note that a data points x that belongs to class k is correctly classified by the the network h W,U if and only if D ˆUk , W ζ(x) E F > D ˆUj , W ζ(x) E F for all j = k With this in mind, we introduce the following definition: Definition A (Margin). Suppose x Xk. Then the margin between data point x and class j is MW,U(x, j) := D ˆUk ˆUj, Wζ(x) E With this definition in hand, the unregularized risk can conveniently be expressed as R0(W, U) = 1 j =k e MW,U(x,j) Dzk(x) (45) and a data point x Xk is correctly classified by the network if and only if the margins MW,U(x, j) are all strictly positive (for j = k). We then introduce a definition that will play crucial role in our analysis. 
Definition B (Equimargin Property). If dist(zk, zj) = dist(zk , zj ) = MW,U(x, j) = MW,U(x , j ) x Xk and x Xk then we say that (W, U) satisfies the equimargin property. To put it simply, (W, U) satisfies the equimargin property if the margin between data point x Xk and class j only depends on dist(zk, zj). We denote by E the set of all the weights that satisfy the equimargin property E = {(W, U) : (W, U) satisfies the equimargin property} (46) and by N the set of weights for which the submatrices ˆUk defined by (18) sum to 0, k=1 ˆUk = 0 We will work under the assumption that the latent variables z1, . . . , z K satisfy the symmetry assumption A. According to lemma B, |Sr(k)| then doesn t depend on k, and so we will simply use |Sr| to denote the size of the set Sr(k). Lemma B also states that Zk 1 |Sr(k)| j Sr(k) Zj = θr Zk + Ar for some matrices A1, . . . , AL and some scalars θ1, . . . , θL > 0. We use these scalars to define g(x) := log r=1 |Sr| eθrx/K ! and we note that g : R R is a strictly increasing function. With these definitions in hand we may state the main theorem of this section. Published as a conference paper at ICLR 2024 Theorem D. If the latent variables satisfy the symmetry assumption A, then R0(W, U) = g D ˆU, WQT Z E for all (W, U) N E (49) R0(W, U) > g D ˆU, WQT Z E for all (W, U) N Ec (50) We recall that the matrices ˆU, Q, and Z where defined in section B (c.f. (20), (26) and (22)). The remainder of this section is devoted to the proof of the above theorem. D.1 PROOF OF THE THEOREM We will use two lemmas to prove the theorem. The first one (lemma D below) simply leverages the strict convexity of the various components defining the unregularized risk R0. Recall that if f : Rd R is strictly convex, and if the strictly positive scalars p1, . . . , pn > 0 sum to 1, then i=1 pif(vi) (51) and that equality holds if and only if v1 = v2 = . . . = vn. For this first lemma, the only property we need on the latent variables is that |Sr(k)| = |Sr(j)| = |Sr| for all j, k [K] and all r [L]. Define the quantity NW,U(r) = 1 x Xk MW,U(x, j) Dzk(x) (52) which should be viewed as the averaged margin between data points and classes which are at a distance r of one another. We then have the following lemma: Lemma D. If |Sr(k)| = |Sr(j)| for all j, k [K] and all r [L], then R0(W, U) = log r=1 |Sr|e NW,U(r) ! for all (W, U) E (53) R0(W, U) > log r=1 |Sr|e NW,U(r) ! for all (W, U) / E (54) Proof. Using the strict convexity of the function f : RK 1 R defined by f(v1, . . . , vk 1, vk+1, . . . , v K) = log 1 + X R0(W, U) = 1 j =k e M(x,j) x Xk M(x,j)Dzk (x) and equality holds if and only if, for all k [K], we have that M(x, j) = M(y, j) for all x, y Xk and all j = k (55) We then let M(k, j) = X x Xk M(x, j)Dzk(x) Published as a conference paper at ICLR 2024 and use the strict convexity of the exponential function to obtain j =k e M(k,j) j Sr(k) e M(k,j) r=1 |Sr| 1 |Sr| j Sr(k) e M(k,j) r=1 |Sr|e 1 |Sr| P j Sr(k) M(k,j) ! Moreover, equality holds if and only if, for all k [K] and all r [L], we have that M(k, i) = M(k, j) for all i, j Sr(k) (56) We finally set M(k, r) = 1 |Sr| and use the strict convexity of the function f(v1, . . . , v L) = log 1 + PL r=1 |Sr|evr to get r=1 |Sr|e M(k,r) ! r=1 |Sr|e 1 K PK k=1 M(k,r) ! Moreover equality holds if and only if, for all k [K] and all r [L], we have that M(k, r) = M(k , r) for all k, k [K] and all r [L] (57) Importantly, note that M(k, r) = 1 x Xk MW,U(x, j) Dzk(x) which is precisely how NW,U(r) was defined (c.f. (52)). 
To conclude the proof, we remark that conditions (55), (56) and (57) are all satisfied if and only if (W, U) satisfies the equi-margin property. We now show that, if assumption A holds, NW,U(r) can be expressed in a simple way. Lemma E. Assume that the latent variables satisfy the symmetry assumption A. Then NW,U(r) = θr D ˆU, WQT Z E F for all (W, U) N (58) Proof. We let x Xk ζ(x)Dzk(x) and note that the averaged margin can be expressed as NW,U(r) = 1 x Xk MW,U(x, j) Dzk(x) D ˆUk ˆUj, Wζ(x) E D ˆUk ˆUj, Xk E D ˆUk, WXk E F 1 K 1 |Sr| D ˆUj, WXk E Published as a conference paper at ICLR 2024 a(r) k,j = 1 if dist(zk, zj) = r 0 otherwise and rewrite the second term in (59) as D ˆUj, WXk E j=1 a(r) k,j Uj, WXk k=1 a(r) j,k D ˆUk, WXj E D ˆUk, WXj E * ˆUk , W 1 |Sr| Combining this with (59) we obtain NW,U(r) = 1 From formula (26), we see that row α of the matrix Q is given by the formula β=1 ζ(α, β) µβ. (61) We then write zk = [α1, . . . , αL] and note that the ℓth column of Xk can be expressed as β=1 ζ(αℓ, β)µβ = QT eαℓ. (62) From this we obtain that and therefore (60) becomes NW,U(r) = 1 * ˆUk , WQT Zk 1 |Sr| j Sr(k) Zj + D ˆUk , WQT θr Zk + Ar E where we have used the identity Zk 1 |Sr| P j Sr(k) Zj = θr Zk +Ar to obtain the second equality. Finally, we use the fact that P k ˆUk = 0 to obtain NW,U(r) = θr D ˆUk , WQT Zk E D ˆU, WQT Z E Combining lemma D and E concludes the proof of theorem D. Published as a conference paper at ICLR 2024 E PROOF OF THEOREM 1 AND ITS CONVERSE In this section we prove theorem 1 under assumption A, and its converse under assumptions A and B. We start by recalling the definition of a type-I collapse configuration. Definition C (Type-I Collapse). The weights (W, U) of the network h W,U form a type-I collapse configuration if and only if the conditions i) There exists c 0 so that w(α,β) = c fα for all (α, β) V. ii) There exists c 0 so that uk,ℓ= c fα for all (k, ℓ) satisfying zk,ℓ= α and all α C. hold for some collection f1, . . . , fnc Rd of equiangular vectors. It will prove convenient to reformulate this definition using matrix notations. Toward this goal, we define equiangular matrices as follow: Definition D. (Equiangular Matrices) A matrix F Rd nc is said to be equiangular if and only if the relations F 1nc = 0 and FT F = nc nc 1 Inc 1 nc 1 1nc1T nc hold. Comparing the above definition with the definition of equiangular vectors provided in the main paper, we easily see that a matrix " | | | f1 f2 fnc | | | is equiangular if and only if its columns f1, . . . , fnc Rd are equiangular. Relations (i) and (ii) defining a type-I collapse configuration can now be expressed in matrix format as W = c F P and ˆU = c F Z for some equiangular matrix F where the matrices Z and P are given by formula (22) and (23). We then let ΩI c := n (W, U) : There exist an equiangular matrix F such that W = c F P and ˆU = c r nw KL F Z o (63) and note that ΩI c is simply the set of weights (W, U) which are in a type-I collapse configuration with constant c and c = c p nw/(KL). We now state the main theorem of this section. Theorem E. Assume uniform sampling µβ = 1/sc for each word distribution. Let τ 0 denote the unique minimizer of the strictly convex function H(t) := log 1 K 1 + (nc 1)e ηt L + λt where η = nc nc 1 1 nw KL and let c = p τ/nw. Then we have the following: (i) If the latent variables z1, . . . , z K are mutually distinct and satisfy assumption A, then ΩI c arg min R (ii) If the latent variables z1, . . . 
, z K are mutually distinct and satisfy assumptions A and B, then ΩI c = arg min R Note that (i) states that any (W, U) ΩI c is a minimizer of the regularized risk this corresponds to theorem 1 from the main paper. Statement (ii) assert that any minimizer of the regularized risk must belong to ΩI c this is the converse of theorem 1. The remainder of this section is devoted to the proof of theorem E. We will assume uniform sampling µβ = 1/sc for all β [sc] everywhere in this section all lemmas and propositions are proven under this assumption, even when not explicitly stated. Published as a conference paper at ICLR 2024 E.1 THE BILINEAR OPTIMIZATION PROBLEM From theorem D, it is clear that the quantity D ˆU, WQT Z E F plays an important role in our analysis. In this subsection we consider the bilinear optimization problem maximize D ˆU, WQT Z E subject to 1 2 W 2 F + ˆU 2 F = c2 nw (65) where c R is some constant. The following lemma identifies all solutions of this optimization problem. Lemma F. Assume the latent variables satisfy assumption A. Then (W, U) is a solution of the optimization problem (64) (65) if and only if it belongs to the set BI c = n (W, U) : There exist a matrix F Rd nc with F 2 F = nc such that W = c FP and ˆU = c r nw KL FZ o (66) Note that the set BI c is very similar to the set ΩI c that defines type-I collapse configuration (c.f. (92)). In particular, since an equiangular matrix has nc columns of norm 1, it always satisfies F 2 F = nc, and therefore we have the inclusion ΩI c BI c. (67) The remainder of this subsection is devoted to the proof of the lemma. First note that the lemma is trivially true if c = 0, so we may assume c = 0 for the remainder of the proof. Second, we note that since µβ = 1/sc, then the matrices P and Q defined by (23) and (26) are scalar multiple of one another. We may therefore replace the matrix Q appearing in (64) by P, wich leads to maximize D ˆU, WP T Z E subject to 1 2 W 2 F + ˆU 2 F = c2 nw (69) We now show that any (W, ˆU) BI c satisfies the constraint (69) and have objective value equal to sc c2 KLnw. Claim A. If (W, ˆU) BI c, then 1 2 W 2 F + ˆU 2 F = c2 nw and D ˆU, WP T Z E F = c2 sc p Proof. Assume (W, U) BI c. From definition (23) of the matrix P, we have PP T = sc Inc, and therefore W 2 F = c2 FP 2 F = c2 FP, FP F = c2 FPP T , F F = c2 sc F 2 F = c2 sc nc = c2 nw where we have used the fact that sc = nw/nc. Using ZZT = KL nc I from lemma B, we obtain FZ 2 F = FZ, FZ F = FZZT , F As a consequence we have ˆU 2 F = c2 nw KL FZ 2 F = c2 nw and, using PP T = sc Inc one more time, D ˆU, WP T Z E F = c2 r nw KL FZ, FPP T Z KL FZ, FZ F = c2 sc p Published as a conference paper at ICLR 2024 We then prove that W and ˆU must have same Frobenius norm if they solve the optimization problem. Claim B. If (W, U) is a solution of (68) (69), then W 2 F = ˆU 2 F = c2 nw (70) Proof. We prove it by contradiction. Suppose (W, ˆU) is a solution of (64) (65) with W 2 F = ˆU 2 F . Since the average of W 2 F and ˆU 2 F is equal to c2nw > 0 according to the constaint, there must then exists ϵ = 0 such that W 2 F = c2nw + ϵ and ˆU 2 F = c2nw ϵ c2nw c2nw + ϵ W and ˆU0 = c2nw c2nw ϵ ˆU and note that W0 2 F = ˆU0 2 F = c2 nw and therefore (W0, ˆU0) clearly satisfies the constraint. We also have D ˆU0, W0P T Z E c4n2w c4n2w ϵ2 D ˆU, WP T Z E F > D ˆU, WP T Z E since ϵ = 0 and therefore (W, ˆU) can not be a maximizer, which is a contradiction. 
As a consequence of the above claim, the optimization problem (68) (69) is equivalent to maximize D ˆU, WP T Z E subject to W 2 F = c2 nw and ˆU 2 F = c2 nw (72) We then have Claim C. If (W, ˆU) is a solution of (71) (72), then (W, ˆU) BI c. Note that according to the first claim, all (W, ˆU) BI c have same objective value, and therefore, according to the above claim, they must all be maximizer. As a consequence, proving the above claim will conclude the proof of lemma F. Proof of the claim. Maximizing (71) over ˆU first gives ˆU = c nw WP T Z WP T Z F (73) and therefore the optimization problem (71) (72) reduces to maximize WP T Z 2 F subject to W 2 F = c2 nw Using ZZT = KL nc I from lemma B we then get WP T Z 2 F = WP T Z, WP T Z F = WP T ZZT , WP T nc WP T 2 F and therefore the problem further reduces to maximize WP T 2 F subject to W 2 F = c2 nw The KKT conditions for this optimization problem are WP T P = νW (74) W 2 F = c2 nw (75) Published as a conference paper at ICLR 2024 where ν R is the Lagrange multiplier. Assume that (W, ˆU) is a solution of the original optimization problem (71) (72). Then, according to the above discussion, W must satisfy (74) (75). Right multiplying (74) by P T , and using PP T = sc Inc, gives sc WP T = νWP T So either ν = sc or WP T = 0. The latter is not possible since the choice WP T = 0 leads to an objective value equal to zero in the original optimization problem (71) (72). We must therefore have ν = sc, and equation (74) becomes sc WP T P (76) which can obviously be written as W = c FP by setting F := 1 c sc WP T . Since W satisfies (75) we must have c2 nw = W 2 F = c2 FP 2 F = c2 FP, FP F = c2 FPP T , F F = c2 sc F 2 F , (77) and so F 2 F = nw/sc = nc. According to (73), ˆU bust be a scalar multiple of the matrix WP T Z = (c FP)P T Z = c sc FZ Using the fact that ZZT = KL nc I and F 2 F = nc we then obtain that FZ 2 F = FZ, FZ F = FZZT , F nc F 2 F = KL (78) and so equation (73) becomes ˆU = c nw WP T Z WP T Z F = c nw FZ which concludes the proof. E.2 PROOF OF COLLAPSE Recall that the regularized risk associated with the network h W,U is defined by R(W, U) = R0(W, U) + λ 2 W 2 F + U 2 F (80) and recall that the set of weights in type-I collapse configuration is ΩI c = n (W, U) : There exist an equiangular matrix F such that W = c F P and ˆU = c r nw KL F Z o (81) This subsection is devoted to the proof of the following proposition. Proposition A. We have the following: (i) If the latent variables z1, . . . , z K are mutually distinct and satisfy assumption A, then there exists c R such that ΩI c arg min R (ii) If the latent variables z1, . . . , z K are mutually distinct and satisfy assumptions A and B, then any (W, U) that minimizes R must belong to ΩI c for some c R. Published as a conference paper at ICLR 2024 This proposition states that, under appropriate symmetry assumption, the weights of the network h W,U do collapse into a type-I configuration. This proposition however does not provide the value of the constant c involved in the collapse. Determining this constant will be done in the subsection E.3. We start with a simple lemma. Lemma G. Any global minimizer of (80) must belong to N. Proof. Let (W , U ) be a global minimizer. Define B = 1 K PK k=1 U k and U0 = [U 1 B U 2 B U K B] From the definition of the unregularized risk we have R0(W ; U0) = R0(W ; U ) and therefore 1 K (R(W ; U0) R(W ; U )) = λ U k B 2 F U k 2 F B 2 F 2 B, U k F So B must be equal to zero, otherwise we would have R(W , U0) < R(W , U ). 
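Since the sets ΩI c and BI c are used repeatedly below, the following sketch may help fix ideas. It builds one equiangular matrix F (a normalized simplex embedded in Rd, which is just one possible construction), forms weights in a type-I collapse configuration, and checks that the relations of definition D hold, that the configuration lies in N, and that the margins depend only on dist(zj, zk), i.e. the equimargin property. The sizes and the constant c are our illustrative choices; in particular c is not the optimal constant of theorem E.

```python
import itertools
import numpy as np

# Sketch of a type-I collapse configuration (Omega^I_c) and of the properties it satisfies.
nc, sc, d, L = 3, 4, 6, 3
nw = nc * sc
latents = list(itertools.product(range(nc), repeat=L))    # K = nc**L, so assumption A holds
K = len(latents)
c = 0.7
c_hat = c * np.sqrt(nw / (K * L))                          # c' = c sqrt(nw / KL), cf. (63)

M = np.eye(nc) - np.ones((nc, nc)) / nc                    # centered simplex
F = np.zeros((d, nc))
F[:nc, :] = np.sqrt(nc / (nc - 1)) * M                     # unit-norm equiangular columns
assert np.allclose(F.sum(axis=1), 0)                                              # F 1 = 0
assert np.allclose(F.T @ F, nc/(nc-1)*np.eye(nc) - 1/(nc-1)*np.ones((nc, nc)))    # def. D

P = np.kron(np.eye(nc), np.ones((1, sc)))                  # partition matrix, cf. (23)
Zk = [np.eye(nc)[:, list(z)] for z in latents]             # one-hot matrices, cf. (21)
W = c * F @ P                                              # W = c F P
U_hat = [c_hat * F @ Zk[k] for k in range(K)]              # U_hat_k = c' F Z_k
assert np.allclose(sum(U_hat), 0)                          # the configuration belongs to N

def margin(k, j, betas):
    """<U_hat_k - U_hat_j, W zeta(x)>_F for a sentence x in X_k with word indices betas."""
    cols = np.stack([W[:, a * sc + b] for a, b in zip(latents[k], betas)], axis=1)
    return float(np.sum((U_hat[k] - U_hat[j]) * cols))

def hamming(z, w):
    return sum(u != v for u, v in zip(z, w))

k, betas = 0, [1] * L                                      # one sentence generated by z_k
vals = {}
for j in range(1, K):
    r = hamming(latents[j], latents[k])
    vals.setdefault(r, set()).add(round(margin(k, j, betas), 8))
assert all(len(v) == 1 for v in vals.values())             # equimargin property
print({r: vals[r].pop() for r in sorted(vals)})
```

The printed margins grow linearly with the distance between latent variables, which is consistent with the computation carried out in the proof of the next lemma.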
The next lemma bring together the bilinear optimization problem from subsection E.1 and the sharp lower bound on the unregularized risk that we derived in section D. Lemma H. Assume the latent variables satisfy assumption A. Assume also that (W , U ) is a global minimizer of (80) and let c R be such that 1 2 W 2 F + U 2 F = c2 nw. Then the following hold: (i) Any (W, U) that belongs to N E BI c is also a global minimizer of (80). (ii) If N E BI c = , then (W , U ) must belong to N E BI c. Proof. Recall from theorem D that R0(W, U) = g D ˆU, WQT Z E for all (W, U) N E (82) R0(W, U) > g D ˆU, WQT Z E for all (W, U) N Ec (83) We start by proving (i). If (W, U) N E BI c, then we have R0(W , U ) g D ˆU , W QT Z E [because (W , U ) N due to lemma G ] g D ˆU, WQT Z E [because (W, U) BI c and g is increasing] = R0(W, U) [because (W, U) N E ] Since (W, U) BI c we must have 1 2 W 2 F + U 2 F = c2 nc = 1 2 W 2 F + U 2 F . Therefore R(W, U) R(W , U ) and (W, U) is a minimizer. We now prove (ii) by contradiction. Suppose that (W , U ) / N E BI c. This must mean that (W , U ) / E BI c since it clearly belongs to N. If (W , U ) / E then the first inequality in the above computation is strict according to (83). If (W , U ) / BI c then the second inequality is strict because g is strictly increasing. Published as a conference paper at ICLR 2024 The above lemma establishes connections between the set of minimizers of the risk and the set E N BI c. The next two lemmas shows that the set E N BI c is closely related to the set of collapsed configurations ΩI c. In other words we use the set E N BI c as a bridge between the set of minimizers and the set of type-I collapse configurations. Lemma I. If the latent variables satisfy the symmetry assumption A, then ΩI c E N BI c Proof. We already know from (67) that ΩI c BI c. We now show that ΩI c E. Suppose (W, U) ΩI c. Then there exists an equiangular matrix F Rd nc such that W = c F P and ˆU = c F Z where c = c p nw/(KL). Recall from (25) that Pζ(x) = Zk for all x Xk. Consider two latent variables zk = [α1, . . . , αL] and zj = [α 1, . . . , α L] and assume x is generated by zk, meaning that x Xk. We then have MW,U(x, j) = D ˆUk ˆUj, Wζ(x) E F = c c F Zk F Zj, F Pζ(x) F = c c F Zk F Zj, F Zk F D fαℓ fα ℓ, fαℓ E D fα ℓ, fαℓ E Since f1, . . . , fnc are equiangular, we have D fα ℓ, fαℓ E F = L dist(zj, zk) 1 nc 1dist(zj, zk) = L nc nc 1dist(zj, zk). Therefore MW,U(x, j) = cc nc nc 1dist(zj, zk) and it is clear that the margin only depends on dist(zj, zk), and therefore (W, U) satisfies the equimargin property. Finally we show that ΩI c N. Suppose (W, U) ΩI c. From property (ii) of lemma B we have k=1 ˆUk = c K X k=1 F Zk = c K nc F 1nc1T L = 0 where we have used the fact that F 1nc = 0. Lemma J. If the latent variables satisfy assumptions A and B, then ΩI c = E N BI c Published as a conference paper at ICLR 2024 Proof. From the previous lemma we know that ΩI c E N BI c so we need to show that E N BI c ΩI c. Let (W, U) E N BI c. Since (W, U) belongs to BI c, there exists a matrix F Rd nc with F 2 F = nc such that W = c F P and U = c F Z (84) where c = c p nw/(KL). Our goal is to show that F is equiangular, meaning that it satisfies the two relations F 1nc = 0 and F T F = nc nc 1 Inc 1 nc 1 1nc1T nc. (85) The first relation is easily obtained. Indeed, using the fact that (W, U) N together with the identity PK k=1 Zk = K nc 1nc1T L (which hold due to lemma B), we obtain k=0 Uk = c K X k=0 FZk = c K nc F1nc1T L. 
We then note that the matrix F1nc1T L is the zero matrix if and only if F1nc = 0. We now prove the second equality of (85). Assume that x Xk. Using the fact that Pζ(x) = Zk together with (84), we obtain MW,U(x, j) = D ˆUk ˆUj, Wζ(x) E F = c c F Zk F Zj, F Pζ(x) F = c c F Zk F Zj, F Zk F = c c F T F(Zk Zj), Zk = c c D F T F , Γ(k,j) E We recall that the matrices Γ(k,j) = Zk(Zk Zj)T Rnc nc. are precisely the ones involved in the statement of assumption B. Since (W, U) E, the margins must only depend on the distance between the latent variables. Due to (86), we can be express this as D F T F , Γ(j,k)E F = D F T F , Γ(j ,k )E F j, k, j , k [K] s.t. dist(zj, zk) = dist(zj , zk ) Since the F T F is clearly positive semi-definite, we may then use assumption B to conclude that F T F A. Recalling definition (42) of the set A, we therefore have F T F = a Inc + b 1nc1T nc (87) for some a, b R. To conclude our proof, we need to show that a = nc nc 1 and b = 1 nc 1. (88) Combining (87) with the first equality of (85), we obtain 0 = F T F 1nc = a 1nc + b 1nc1T nc1nc = (a + bnc)1nc (89) Combining (87) with the fact that F 2 F = nc, we obtain nc = F 2 F = Tr(F T F) = nc(a + b) (90) The constants a, b R, according to (89) and (90) must therefore solve the system a + bnc = 0 a + b = 1 and one can easily check that the solution of this system is precisely given by (88). Published as a conference paper at ICLR 2024 We conlude this subsection by proving proposition A. Proof of Proposition A. Let (W , U ) be a global minimizer of R and let c R be such that 1 2 W 2 F + U 2 F = c2 nw If the latent variables satisfies assumption A, we can use lemma I together with the first statement of lemma H to obtain ΩI c E N BI c arg min R, which is precisely statement (i) of the proposition. We now prove statement (ii) of the proposition. If the latent variables satisfies assumption A and B then lemma J asserts that ΩI c = E N BI c The set ΩI c is clearly not empty (because the set of equiangular matrices is not empty), and we may therefore use the second statement of lemma H to obtain that (W , U ) E N BI c = ΩI c E.3 DETERMINING THE CONSTANT c The next lemma provides an explicit formula for the regularized risk of a network whose weights are in type-I collapse configuration with constant c. Lemma K. Assume the latent variables satisfy assumption A. If the pair of weights (W, U) belongs to ΩI c, then R(W, U) = log 1 K 1 + (nc 1)e η nwc2 L + λ nwc2 (91) where η = nc nc 1 q From the above lemma it is clear that if the pair (W, U) ΩI c minimizes R, then the constant c must minimize the right hand side of (91). Therefore combining lemma K with proposition A concludes the proof of theorem E. Remark In the previous subsections, we only relied on relations (i), (ii) and (iii) of lemma B to prove collapse. Assumption A was never fully needed. In this section however, in order to determine the specific values of the constant involved in the collapse, we will need the actual combinatorial values provided by assumption A. The remainder of this section is devoted to the proof of lemma K. Proof of lemma K. Recall from (45) that the unregularized risk can be expressed as R0(W, U) = 1 j =k e MW,U(x,j) We also recall that the set ΩI c is given by ΩI c = n (W, U) : There exist an equiangular matrix F such that W = c F P and ˆU = c r nw KL F Z o (92) and that Pζ(x) = Zk for all x Xk (see equation (25) from section B). Consider two latent variables zk = [α1, . . . , αL] and zj = [α 1, . . . 
, α L] Published as a conference paper at ICLR 2024 and assume x is generated by zk, meaning that x Xk. MW,U(x, j) = D ˆUk ˆUj, Wζ(x) E KL F Zk F Zj, F Pζ(x) F KL F Zk F Zj, F Zk F D fαℓ fα ℓ, fαℓ E D fα ℓ, fαℓ E Since f1, . . . , fnc are equiangular, we have D fα ℓ, fαℓ E F = L dist(zj, zk) 1 nc 1dist(zj, zk) = L nc nc 1dist(zj, zk). MW,U(x, j) = c2 r nw KL nc nc 1dist(zj, zk) Letting ω = p nw KL nc nc 1 we therefore obtain R0(W, U) = 1 j =k e ωc2dist(zj,zk) j =k e ωc2dist(zj,zk) where we have used the quantity inside the log does not depends on x. We proved in section C (see equation (38)) that if the latent variables satisfy assumption A, then Using this identity we obtain X j =k e ωc2dist(zj,zk) = r=1 |{j : dist(zj, zk) = r}| e ωc2r (nc 1)r e ωc2r (nc 1)r e ωc2r 1 + (nc 1)e ωc2 L where we have used the binomial theorem to obtain the last equality. The above quantity does not depends on k, therefore (93) can be expressed as R0(W, U) = log 1 K 1 + (nc 1)e ω c2 L We then remark that the matrix F P has nw columns, and that each of these columns has norm 1. Similarly, the F Z has KL columns of length 1. We therefore have 1 2 W 2 F + ˆU 2 F = 1 c2 F P 2 F + c2 nw KL F Z 2 F = c2nw. To conclude the proof we simply remark that ω = nwη. Published as a conference paper at ICLR 2024 F PROOF OF THEOREM 3 AND ITS CONVERSE In this section we prove theorem 3 under assumption A, and its converse under assumptions A and B. We start by recalling the definition of a type-II collapse configuration. Definition E (Type-II Collapse). The weights (W, U) of the network h W,U form a type-II collapse configuration if and only if the conditions i) φ(w(α,β)) = d fα for all (α, β) V. ii) There exists c 0 so that uk,ℓ= c fα for all (k, ℓ) satisfying zk,ℓ= α and all α C. hold for some collection f1, . . . , fnc Rd of mean-zero equiangular vectors. As in the previous section we will reformulate the above definition using matrix notations. Toward this aim we make the following definition: Definition F. (Mean-Zero Equiangular Matrices) A matrix F Rd nc is said to be a mean-zero equiangular matrix if and only if the relations 1T d F = 0, F 1nc = 0 and FT F = nc nc 1 Inc 1 nc 1 1nc1T nc Comparing the above definition with the definition of equiangular vectors provided in the main paper, we easily see that F is a mean-zero equiangular matrix if and only if its columns are meanzero equiangular vectors. Relations (i) and (ii) of definition F can be conveniently expressed as d F P and ˆU = c F Z for some equiangular matrix F. We then set ΩII c = n (W, U) : There exist a mean-zero equiangular matrix F such that d F P and ˆU = c F Z o (94) and note that ΩII c is simply the set of weights (W, U) which are in a type-II collapse configuration. We now state the main theorem of this section. Theorem F. Assume the non-degenerate condition µβ > 0 holds. Let τ 0 denote the unique minimizer of the strictly convex function H (t) = log 1 K 1 + (nc 1)e η t L + λ 2 t2 where η = nc nc 1 1 p and let c = τ/ KL. Then we have the following: (i) If the latent variables z1, . . . , z K are mutually distinct and satisfy assumption A, then ΩII c arg min R (ii) If the latent variables z1, . . . , z K are mutually distinct and satisfy assumptions A and B, then ΩII c = arg min R Note that statement (i) corresponds to theorem 3 of the main paper, whereas statement (ii) can be viewed as its converse. To prove F we will follow the same steps than in the previous section. 
The main difference occurs in the study of the bilinear problem, as we will see in the next subsection. We will assume µβ > 0 everywhere in this section all lemmas and propositions are proven under this assumption, even when not explicitly stated. Before to go deeper in our study let us state a very simple lemma that expresses the regularized risk R associated with network h in term of the function R0 defined by equation (44). Published as a conference paper at ICLR 2024 Lemma L. Given a pair of weights (W, U), we have R (W, U) = R0 φ(W) , U + λ 2 U 2 F (95) Proof. Recall from section B that h W,U(x) = U Vec [Wζ(x)] h W,U(x) = U Vec h φ Wζ(x) i Note that since ζ(α, β) is a one hot vector, we obviously have that φ (Wζ(α, β)) = φ (W) ζ(α, β). Therefore the the network h and h are related as follow: h W,U(x) = U Vec h φ Wζ(x) i = U Vec h φ(W) ζ(x) i = hφ(W ), U(x) As a consequence, the regularized risk associated with the network h W,U can be expressed as R (W, U) = 1 h ℓ(h W,U(x), k) i + λ h ℓ(hφ(W ),U(x), k) i + λ = R0( φ(W) , U) + λ where R0 is the unregularized risk defined in (44). F.1 THE BILINEAR OPTIMIZATION PROBLEM Let Range(φ) = {V Rd nw : There exist W Rd nw such that V = φ(W) } and consider the optimization problem maximize D ˆU, V QT Z E subject to V Range(φ) and ˆU 2 F = KL c2 (97) where the optimization variables are the matrix V Rd nw and the matrix ˆU Rd KL. Lemma M. Assume the latent variables satisfy assumption A. Then (V, U) is a solution of the optimization problem (96) (97) if and only if it belongs to the set BII c = n (V, U) : There exist a matrix F F such that V = d FP and ˆU = c FZ o (98) where F denotes the set of matrices whose columns have unit length and mean zero, that is F = {F Rd nc : 1T d F = 0 and the columns of F have unit length}. The remainder of this subsection is devoted to the proof of the above lemma. We start by showing that all (V, U) BII c have same objective values and satisfy the constraints. Claim D. If (V, U) BII c , then V Range(φ) , ˆU 2 F = KL c2, and D ˆU, V QT Z E Proof. Assume (V, U) BII c . Since the columns of P are one hot vectors in Rnc, the columns of FP have unit length and mean zero. Therefore the columns of V have norm equal to d and mean zero. Therefore V Range(φ). Published as a conference paper at ICLR 2024 Using ZZT = KL nc I from lemma B, together with the fact that F 2 F = nc since its columns have unit length, we obtain FZ 2 F = FZ, FZ F = FZZT , F F 2 F = KL (99) As a consequence we have ˆU 2 F = c2 KL. Finally, note that PQT = Inc as can clearly be seen from formulas (23) and (26). We therefore have D ˆU, V QT Z E d FZ, FPQT Z d FZ, FZ F = c We then prove that Claim E. If (V, ˆU) is a solution of (96) (97), then (V, ˆU) BII c . Note that according to the first claim, all (V, ˆU) BII c have same objective value, and therefore, according to the above claim, they must all be maximizer. As a consequence, proving the above claim will conclude the proof of lemma M. Proof of the claim. Maximizing (96) (97) over ˆU first gives KL V QT Z V QT Z F (100) and therefore the optimization problem reduces to maximize V QT Z 2 F (101) subject to V Range(φ) (102) Using the fact that ZZT = KL nc I we then get V QT Z 2 F = V QT Z, V QT Z F = V QT ZZT , V QT nc V QT 2 F (103) and so the problem further reduces to maximize V QT 2 F (104) subject to V Range(φ) (105) Let us define v(α,β) := V ζ(α, β) In other words v(α,β) is the jth column of V , where j = (α 1)sc + β. 
The KKT conditions for the optimization problem (104) (105) then amount to solving the system V QT Q = V Dν + 1d λT (106) v(α,β), 1d = 0 for all (α, β) V (107) v(α,β) 2 = d for all (α, β) V (108) for Dν some nw nw diagonal matrix of Lagrange multipliers for the constraint (108) and λ Rnw a vector of Lagrange multipliers for the mean zero constraints. Left multiplying the first equation by 1T d and using the second shows λ = 0nw, and so it proves equivalent to find solutions of the reduced system V QT Q = V Dν (109) v(α,β), 1d = 0 for all (α, β) V (110) v(α,β) 2 = d for all (α, β) V (111) Published as a conference paper at ICLR 2024 instead. Recalling the identity Q ζ(α, β) = µβeα (see (27) in section B) we obtain QT Q ζ(α, β) = µβ QT eα and so right multiplying (109) by ζ(α, β) gives V QT eα = ν(α, β) µβ v(α, β) for all (α, β) V where we have denoted by ν(α, β) the Lagrange multiplier corresponding to the constraint (111). Define the support sets Ξα := {β [sc] : ν(α, β) = 0} and Ξ := {α : Ξα = } of the Lagrange multipliers. If α Ξ then imposing the norm constraint (111) gives V QT eα = ν(α, β) and so V QT eα > 0 if α Ξ since ν(α, β) > 0 for some β [sc] by definition. This implies that the relation d V QT eα V QT eα for all (α, β) Ξ [sc] must hold. As a consequence there exist mean-zero, unit length vectors f1, . . . , fnc (namely the normalized V QT eα) so that v(α, β) = d fα holds for all pairs (α, β) with α Ξ. Taking a look at (26), we easily see that its αth row of the matrix Q can be written as QT eα = P β µβζ(α, β), and therefore V QT eα = X β [sc] µβV ζ(α, β) = X β [sc] µβv(α, β) = holds as well. If α / Ξ then V QT eα = 0 since the corresponding Lagrange multiplier vanishes. It therefore follows that V QT 2 F = X α [nc] V QT eα 2 = d X α Ξ fα 2 = d |Ξ| and so global maximizers of (104) (105) must have full support. In other words, there exist meanzero, unit-length vectors f1, . . . , fnc so that holds. Equivalently V = d FP for some F F. We then recover ˆU using (100). KL V QT Z V QT Z F = c KL FPQT Z FPQT Z F = c KL FZ FZ F (113) where we have used the fact that PQT = Inc. To conclude the proof, we use the fact FZ F = KL, as was shown in (99). F.2 PROOF OF COLLAPSE Recall from lemma L that the regularized risk associated with the network h W,U can be expressed as R (W, U) = R0 φ(W) , U + λ 2 U 2 F (114) and recall that the set of weights in type-II collapse configuration is ΩII c = n (W, U) : There exist a mean-zero equiangular matrix F such that d F P and ˆU = c F Z o (115) This subsection is devoted to the proof of the following proposition. Published as a conference paper at ICLR 2024 Proposition B. We have the following: (i) If the latent variables z1, . . . , z K are mutually distinct and satisfy assumption A, then there exists c R such that ΩII c arg min R (ii) If the latent variables z1, . . . , z K are mutually distinct and satisfy assumptions A and B, then any (W, U) that minimizes R must belong to ΩII c for some c R. As in the previous section, we have the following lemma. Lemma N. Any global minimizer of (114) must belong to N. The proof is identical to the proof of lemma G. The next lemma bring together the bilinear optimization problem from subsection F.1 and the sharp lower bound on the unregularized risk that we derived in section D. Lemma O. Assume the latent variables satisfy assumption A. 
Assume also that (W , U ) is a global minimizer of (114) and let c R be such that U 2 F = KL c2 The the following hold: (i) Any (W, U) that satisfies (φ(W), U) N E BII c is also a global minimizer of R . (ii) If N E BII c = , then (φ(W ), U ) N E BII c Proof. Recall from theorem D that R0(V, U) = g D ˆU, V QT Z E for all (V, U) N E (116) R0(V, U) > g D ˆU, V QT Z E for all (V, U) N Ec (117) We start by proving (i). Define V = φ(W ), and assume that U, V, W are such that φ(W) = V and (V, U) N E Bc. Then we have R0(φ(W ), U ) = R0(V , U ) g ( U , V QZ F ) [because (V , U ) N ] g ( U, V QZ F ) [because (V, U) BII c ] = R0(V, U) [because (V, U) N E ] = R0(φ(W), U) Since U 2 F = KL c2 = U 2 F , we have R (W, U) R (W , U ) and therefore (W, U) is a minimizer. We now prove (ii) by contradiction. Suppose that (φ(W ), U ) / N E BII c . This must mean that (φ(W ), U ) / E BII c since it clearly belongs to N. If (φ(W ), U ) / E then the first inequality in the above computation is strict according to (117). If (φ(W ), U ) / BII c then the second inequality is strict because g is strictly increasing. The next two lemmas shows that the set E N BII c is closely related to the set of collapsed configurations ΩII c . In order to states these lemmas, the following definition will prove convenient Ω II c = n (V, U) : There exist a mean-zero equiangular matrix F such that d F P and ˆU = c F Z o (118) Published as a conference paper at ICLR 2024 Note that (W, U) ΩII c if and only if (φ(W), U) Ω II c . Also, in light of (98), the inclusion Ω II c BII c is obvious. We now prove the following lemma. Lemma P. If the latent variables satisfy the symmetry assumption A, then Ω II c E N BII c Proof. The proof is almost identical to the one of lemma I. We repeat it for completeness. We already know that Ω II c BII c . We the show that Ω II c E. Suppose (V, U) Ω II c . Then there exists a mean-zero equiangular matrix F Rd nc such that d F P and ˆU = c F Z Recall from (25) that Pζ(x) = Zk for all x Xk. Consider two latent variables zk = [α1, . . . , αL] and zj = [α 1, . . . , α L] and assume x is generated by zk, meaning that x Xk. We then have MV,U(x, j) = D ˆUk ˆUj, V ζ(x) E d F Zk F Zj, F Pζ(x) F = c d F Zk F Zj, F Zk F D fαℓ fα ℓ, fαℓ E d dist(zj, zk) From the above computation it is clear that the margin only depends on dist(zj, zk), and therefore (V, U) satisfies the equimargin property. Finally we show that Ω II c N. Suppose (V, U) Ω II c . Using the identity PK k=1 Zk = K nc 1nc1T L we obtain k=1 ˆUk = c k=1 F Zk = c K nc F 1nc1T L = 0 where we have used the fact that F 1nc = 0. Finally, we have the following lemma. Lemma Q. If the latent variables satisfy assumptions A and B, then Ω II c = E N BII c Proof. The proof, again, is very similar to the one of lemma J. From the previous lemma we know that Ω II c E N BII c so we need to show that E N BII c Ω II c . Let (V, U) E N BII c . Since (V, U) belongs to BII c , there exists a matrix F Rd nc whose columns have unit length and mean 0 such that d F P and U = c F Z Our goal is to show that F is a mean-zero equiangular matrix, meaning that it satisfies the three relations 1T nc F = 0, F 1nc = 0 and F T F = nc nc 1 Inc 1 nc 1 1nc1T nc. (119) Published as a conference paper at ICLR 2024 We already know that the first relation is satisfied since the columns of F have mean 0. The second relation is easily obtained. 
Indeed, using the fact that (V, U) N together with the identity PK k=1 Zk = K nc 1nc1T L (which hold due to lemma B), we obtain k=0 Uk = c K X k=0 FZk = c K nc F1nc1T L. which implies F1nc = 0. We now prove the third equality of (119). Assume that x Xk. Using the fact that Pζ(x) = Zk together with (84), we obtain MV,U(x, j) = D ˆUk ˆUj, V ζ(x) E d F Zk F Zj, F Pζ(x) F = c d F Zk F Zj, F Zk F = c d F T F(Zk Zj), Zk d D F T F , Γ(k,j) E Since (V, U) E, the margins must only depend on the distance between the latent variables. Due to (120), we can be express this as D F T F , Γ(j,k)E F = D F T F , Γ(j ,k )E F j, k, j , k [K] s.t. dist(zj, zk) = dist(zj , zk ) Since the F T F is clearly positive semi-definite, we may then use assumption B to conclude that F T F A. Recalling definition (42) of the set A, we therefore have F T F = a Inc + b 1nc1T nc (121) for some a, b R. To conclude our proof, we need to show that a = nc nc 1 and b = 1 nc 1. (122) Combining (121) with the first equality of (119), we obtain 0 = F T F 1nc = a 1nc + b 1nc1T nc1nc = (a + bnc)1nc Since the columns of F have unit length, the diagonal entries of F T F must all be equal to 1, and therefore (121) implies that a + b = 1. The constants a, b R, according must therefore solve the system a + bnc = 0 a + b = 1 and one can easily check that the solution of this system is precisely given by (122). We conlude this subsection by proving proposition B. Proof of Proposition B. Let (W , U ) be a global minimizer of R and let c R be such that U 2 F = KL c2 We first prove statement (i) of the proposition. If the latent variables satisfies assumption A then lemma P asserts that Ω II c E N BII c Assume (W, U) ΩII c . This implies that (φ(W), U) Ω II c , and and therefore (φ(W), U) E N BII c . We can then use lemma O to conclude that (W, U) is a global minimizer of R . Published as a conference paper at ICLR 2024 We now prove statement (ii) of the proposition. If the latent variables satisfies assumption A and B then lemma Q asserts that Ω II c = E N BII c The set Ω II c is clearly not empty (because the set of mean-zero equiangular matrices is not empty), and we may therefore use the second statement of lemma O to obtain that (φ(W ), U ) E N BII c = Ω II c which in turn implies (W , U ) ΩII c . F.3 DETERMINING THE CONSTANT c The next lemma provides an explicit formula for the regularized risk of a network h W,U whose weights are in type-II collapse configuration with constant c. Lemma R. Assume the latent variables satisfy assumption A. If the pair of weights (W, U) belongs to ΩII c , then R (W, U) = log 1 K 1 + (nc 1)e η KL c 2 (123) where η = nc nc 1 Combining lemma R with proposition B concludes the proof of theorem F. Proof of lemma R. We recall that R0(W, U) = 1 j =k e MW,U(x,j) ΩII c = n (W, U) : There exist a mean-zero equiangular matrix F such that d F P and ˆU = c F Z o (124) Consider two latent variables zk = [α1, . . . , αL] and zj = [α 1, . . . , α L] and assume x Xk. Using the identity Pζ(x) = Zk we then obtain Mφ(W ),U(x, j) = D ˆUk ˆUj, φ(W)ζ(x) E d F Zk F Zj, F Pζ(x) F = c d F Zk F Zj, F Zk F D fαℓ fα ℓ, fαℓ E D fα ℓ, fαℓ E d nc nc 1dist(zj, zk) Letting ω = d nc nc 1 we therefore obtain R0(W, U) = 1 j =k e c ω dist(zj,zk) j =k e c ω dist(zj,zk) Published as a conference paper at ICLR 2024 where we have used the quantity inside the log does not depends on x. 
Using the identity |Sr| = K n L c L r (nc 1)r we then obtain obtain j =k e c ω dist(zj,zk) = r=1 |{j : dist(zj, zk) = r}| e c ω r (nc 1)r e c ω r (nc 1)r e c ω r 1 + (nc 1)e c ω L where we have used the binomial theorem to obtain the last equality. The above quantity does not depends on k, therefore (125) can be expressed as R0(W, U) = log 1 K 1 + (nc 1)e c ω L We then remark that the matrix F Z has KL columns, and that each of these columns has norm 1. We therefore have ˆU 2 F = c F Z 2 F = c2KL for all (W, U) ΩII c To conclude the proof we simply note that ω = G PROOF OF THEOREM 2 This section is devoted to the proof of theorem 2 from the main paper, which we recall below for convenience. Theorem 2 (Directional Collapse of h). Assume K = n L c and {z1, . . . , z K} = Z. Assume also that the regularization parameter λ satisfies λ2 < L n L+1 c β=1 µ2 β (126) Finally, assume that (W, U) is in a type-III collapse configuration for some constants c, r1, . . . , rsc 0. Then (W, U) is a critical point of R if and only if (c, r1, . . . , rsc) solve the system λ L rβ nc 1 + exp nc nc 1c rβ = µβ for all 1 β sc (127) 2 = Ln L 1 c . (128) At the end of this section, we also show that if (149) holds, then the system (150) (151) has a unique solution (see proposition D in subsection G.2). The strategy to prove theorem 2 is straightforward: we simply need to evaluate the gradient of the risk on weights (W, U) which are in a type-III collapse configuration. Setting this gradient to zero will then lead to a system for the constants c, r1, . . . , rsc defining the configuration. While conceptually simple, the gradient computation is quite lengthy. We start by deriving formulas for the partial derivatives of R0 with respect to the linear weights uk,ℓ and the word embeddings w(α,β). As we will see, R0/ uk,ℓand R0/ w(α,β) plays symmetric roles. In order to observe this symmetry, the following notation will prove convenient: Φ(α,β),(k,ℓ)(W, U) := 1 x Xj 1{xℓ=(α,β)} 1{j=k} qk,W,U(x) Dzj(x) (129) Published as a conference paper at ICLR 2024 qk,W,U(x) := e ˆUk,W ζ(x) F PK k =1 e Uk ,W ζ(x) F We may now state the first lemma of this section: Lemma S. The partial derivatives of R0 with respect to uk,ℓand w(α,β) are given by uk,ℓ (W, U) = β=1 Φ(α,β),(k,ℓ)(W, U) w(α,β) R0 w(α,β) (W, U) = ℓ=1 Φ(α,β),(k,ℓ)(W, U) uk,ℓ Proof. Given K matrices V1, . . . , VK Rnw KL, we define f(V1, . . . , VK) := 1 x Xk ℓ V1, ζ(x) F , . . . , VK, ζ(x) F ; k Dzk(x) where ℓ(y1, . . . , y K; k) is the cross entropy loss ℓ(y1, . . . , y K; k) = log exp (yk) PK k =1 exp (yk ) The partial derivative of f with respect to the matrix Vj can easily be found to be Vj (V1, . . . , VK) = 1 1{j=k} e Vj,ζ(x) F PK k =1 e Vk ,ζ(x) F ζ(x) Dzk(x) (130) We then recall from (19) that the kth entry of the vector y = h W,U(x) is yk = D ˆUk , W ζ(x) E F = D W T ˆUk , ζ(x) E F and so the unregularized risk can be expressed in term of the function f: R0(W, U) = 1 x Xk ℓ W T ˆU1, ζ(x) F , . . . , W T ˆUK, ζ(x) F ; k Dzk(x) = f(W T ˆU1, . . . , W T ˆUK) The chain rule then gives Vj (W T ˆU1, . . . , W T ˆUK) T (131) ˆUj (W, U) = W f Vj (W T ˆU1, . . . , W T ˆUK) (132) Using formula (130) for f/ Vj and the notation qj,W,U(x) := e W T ˆUj,ζ(x) F PK k =1 e W T Uk ,ζ(x) F we can express (131) and (132) as follow 1{j=k} qj,W,U(x) ζ(x) Dzk(x) ˆUj (W, U) = W 1{j=k} qj,W,U(x) ζ(x) Dzk(x) Published as a conference paper at ICLR 2024 We now compute the partial derivative of R0 with respect to uj,ℓ. Let eℓ RL be the ℓth basis vector. 
We then have
\frac{∂R_0}{∂u_{j,ℓ}}(W, U) = − \frac{1}{K} \sum_{k=1}^{K} \sum_{x ∈ X_k} \big( 1_{\{j=k\}} − q_{j,W,U}(x) \big) \big( W ζ(x) e_ℓ \big) D_{z_k}(x).
Recall from (13) that W ζ(x) is the matrix that contains the d-dimensional embeddings of the words that constitute the sentence x ∈ X. So W ζ(x) e_ℓ is simply the embedding of the ℓ-th word of the sentence x, and we can write it as
W ζ(x) e_ℓ = \sum_{α=1}^{n_c} \sum_{β=1}^{s_c} 1_{\{x_ℓ = (α,β)\}} \, w_{(α,β)}.
We therefore have
\frac{∂R_0}{∂u_{j,ℓ}}(W, U) = − \frac{1}{K} \sum_{k=1}^{K} \sum_{x ∈ X_k} \big( 1_{\{j=k\}} − q_{j,W,U}(x) \big) \sum_{α=1}^{n_c} \sum_{β=1}^{s_c} 1_{\{x_ℓ=(α,β)\}} w_{(α,β)} \, D_{z_k}(x)
= − \sum_{α=1}^{n_c} \sum_{β=1}^{s_c} \Big( \frac{1}{K} \sum_{k=1}^{K} \sum_{x ∈ X_k} \big( 1_{\{j=k\}} − q_{j,W,U}(x) \big) 1_{\{x_ℓ=(α,β)\}} D_{z_k}(x) \Big) w_{(α,β)}
= − \sum_{α=1}^{n_c} \sum_{β=1}^{s_c} Φ_{(α,β),(j,ℓ)}(W, U) \, w_{(α,β)},
which is the desired formula.
We now compute the gradient with respect to w_{(α,β)}. Recalling that ζ(α, β) is the one-hot vector associated with the word (α, β), we have
\frac{∂R_0}{∂w_{(α,β)}}(W, U) = \frac{∂R_0}{∂W}(W, U) \, ζ(α, β) = − \frac{1}{K} \sum_{k=1}^{K} \sum_{x ∈ X_k} \sum_{j=1}^{K} \big( 1_{\{j=k\}} − q_{j,W,U}(x) \big) \hat{U}_j ζ(x)^T ζ(α, β) \, D_{z_k}(x).
Recall that the ℓ-th column of ζ(x) is the one-hot encoding of the ℓ-th word of the sentence x. Therefore, the ℓ-th entry of the vector ζ(x)^T ζ(α, β) ∈ R^L is given by the formula
[ζ(x)^T ζ(α, β)]_ℓ = 1 if x_ℓ = (α, β), and 0 otherwise.
As a consequence,
\hat{U}_j ζ(x)^T ζ(α, β) = \sum_{ℓ=1}^{L} 1_{\{x_ℓ = (α,β)\}} u_{j,ℓ},
which leads to
\frac{∂R_0}{∂w_{(α,β)}}(W, U) = − \frac{1}{K} \sum_{k=1}^{K} \sum_{x ∈ X_k} \sum_{j=1}^{K} \big( 1_{\{j=k\}} − q_{j,W,U}(x) \big) \sum_{ℓ=1}^{L} 1_{\{x_ℓ=(α,β)\}} u_{j,ℓ} \, D_{z_k}(x)
= − \sum_{j=1}^{K} \sum_{ℓ=1}^{L} \Big( \frac{1}{K} \sum_{k=1}^{K} \sum_{x ∈ X_k} \big( 1_{\{j=k\}} − q_{j,W,U}(x) \big) 1_{\{x_ℓ=(α,β)\}} D_{z_k}(x) \Big) u_{j,ℓ}
= − \sum_{j=1}^{K} \sum_{ℓ=1}^{L} Φ_{(α,β),(j,ℓ)}(W, U) \, u_{j,ℓ},
which is the desired formula.
G.1 GRADIENT OF THE RISK FOR WEIGHTS IN TYPE-III COLLAPSE CONFIGURATION
In lemma S we computed the gradient of the risk for any possible weights (W, U) and for any possible latent variables z_1, …, z_K. In this section we derive a formula for the gradient when the weights are in a type-III collapse configuration and when the latent variables satisfy {z_1, …, z_K} = Z. We start by recalling the definition of a type-III collapse configuration.
Definition G (Type-III Collapse). The weights (W, U) of the network h_{W,U} form a type-III collapse configuration if and only if
i) there exist scalars r_β ≥ 0 such that w_{(α,β)} = r_β f_α for all (α, β) ∈ V;
ii) there exists c ≥ 0 such that u_{k,ℓ} = c f_α for all (k, ℓ) satisfying z_{k,ℓ} = α and all α ∈ C;
hold for some collection f_1, …, f_{n_c} ∈ R^d of equiangular vectors.
We also define the constant γ ∈ R and the sigmoid σ : R → R as follows:
γ := \frac{1}{n_c − 1} and σ(x) := \frac{1}{1 + γ e^{(1+γ) x}}. (133)
The goal of this subsection is to prove the following proposition.
Proposition C. Suppose K = n_c^L and {z_1, …, z_K} = Z. If the weights (W, U) are in a type-III collapse configuration with constants c, r_1, …, r_{s_c} ≥ 0, then
\frac{∂R_0}{∂u_{k,ℓ}}(W, U) = − \frac{1 + γ}{c \, n_c^L} \Big( \sum_{β=1}^{s_c} μ_β \, σ(c r_β) \, r_β \Big) u_{k,ℓ},
\frac{∂R_0}{∂w_{(α,β)}}(W, U) = − \frac{c L (1 + γ)}{n_c} \, \frac{μ_β \, σ(c r_β)}{r_β} \, w_{(α,β)}.
Importantly, note that the above proposition states that ∂R_0/∂u_{k,ℓ} and u_{k,ℓ} are aligned, and that ∂R_0/∂w_{(α,β)} and w_{(α,β)} are aligned.
We start by introducing some notation which will make these gradient computations easier. The latent variables z_1, …, z_K will be written as
z_k = [ z_{k,1}, z_{k,2}, …, z_{k,L} ] ∈ Z, where 1 ≤ z_{k,ℓ} ≤ n_c.
We remark that any sentence x generated by the latent variable z_k must be of the form x = [(z_{k,1}, β_1), …, (z_{k,L}, β_L)] for some (β_1, …, β_L) ∈ [s_c]^L, and that this sentence has probability μ_{β_1} μ_{β_2} ⋯ μ_{β_L} of being sampled. In light of this, we make the following definitions. For every β = (β_1, …, β_L) ∈ [s_c]^L we let
x_{k,β} := [(z_{k,1}, β_1), …, (z_{k,L}, β_L)] ∈ X, (134)
μ[β] := μ[β_1] \, μ[β_2] ⋯ μ[β_L] ∈ [0, 1], (135)
where we write μ[β_ℓ] instead of μ_{β_ℓ} in order to avoid the double subscript. With these definitions at hand we have that
D_{z_j}(x_{k,β}) = μ[β] if k = j, and 0 otherwise.
We are now ready to prove proposition C. We break the computation into four lemmas. The first one simply uses the notation that we just introduced in order to express Φ_{(α,β),(k,ℓ)} in a more convenient format.
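Before turning to these lemmas, the conclusion of proposition C can be sanity-checked numerically on a tiny instance. The sketch below builds a type-III collapse configuration and compares finite-difference gradients of R_0 with the corresponding weight vectors; it assumes the form of R_0 used in this appendix (uniform average over the K latent variables, sentence weights μ[β]), and the instance sizes, frequencies mu, constants c and r, and the particular equiangular frame are illustrative choices rather than values from the paper.

import itertools
import numpy as np

# Small illustrative instance (sizes, frequencies, constants and the frame are
# assumptions made for this check, not values from the paper).
n_c, s_c, L = 3, 2, 2
d = n_c
mu = np.array([0.7, 0.3])                  # word frequencies mu_beta, summing to 1
r = np.array([1.3, 0.6])                   # constants r_beta of the configuration
c = 0.9

C = np.eye(n_c) - np.ones((n_c, n_c)) / n_c
f = (C / np.linalg.norm(C, axis=0)).T      # f[alpha]: unit-norm equiangular vectors

latents = list(itertools.product(range(n_c), repeat=L))    # all of Z, K = n_c^L
K = len(latents)
betas = list(itertools.product(range(s_c), repeat=L))      # tuples beta in [s_c]^L

def risk(W, U):
    # R_0: average of the cross entropy over latent variables (uniform) and
    # over the sentences x_{j,beta}, weighted by mu[beta_1]...mu[beta_L].
    total = 0.0
    for j, z in enumerate(latents):
        for beta in betas:
            prob = np.prod(mu[list(beta)])
            emb = np.array([W[z[l], beta[l]] for l in range(L)])   # embeddings of x_{j,beta}
            logits = np.einsum('kld,ld->k', U, emb)                # y_k = sum_l <u_{k,l}, w_{x_l}>
            total += prob * (np.log(np.exp(logits).sum()) - logits[j])
    return total / K

# Type-III collapse configuration: w_{(alpha,beta)} = r_beta f_alpha, u_{k,l} = c f_{z_{k,l}}.
W = np.array([[r[b] * f[a] for b in range(s_c)] for a in range(n_c)])   # shape (n_c, s_c, d)
U = np.array([[c * f[z[l]] for l in range(L)] for z in latents])        # shape (K, L, d)

def fd_grad(param, index, eps=1e-5):
    # Central finite differences of R_0 with respect to one d-dimensional weight vector.
    g = np.zeros(d)
    for i in range(d):
        p, m = param.copy(), param.copy()
        p[index + (i,)] += eps
        m[index + (i,)] -= eps
        g[i] = ((risk(p, U) - risk(m, U)) if param is W else (risk(W, p) - risk(W, m))) / (2 * eps)
    return g

cos = lambda u, v: u @ v / (np.linalg.norm(u) * np.linalg.norm(v))
print(cos(fd_grad(U, (0, 0)), U[0, 0]))    # gradient w.r.t. a vector u_{k,l}: cosine close to -1
print(cos(fd_grad(W, (1, 0)), W[1, 0]))    # gradient w.r.t. a vector w_{(alpha,beta)}: cosine close to -1

Both printed cosines come out close to −1: each gradient is a negative multiple of the corresponding weight vector, which is what reduces the critical-point condition of theorem 2 to a scalar system.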
Lemma T. The quantity Φ_{(α',β'),(k,ℓ)}(W, U) can be expressed as
Φ_{(α',β'),(k,ℓ)}(W, U) = \frac{1}{K} \sum_{β ∈ [s_c]^L} 1_{\{β_ℓ = β'\}} \Big( 1_{\{z_{k,ℓ} = α'\}} − \sum_{j=1}^{K} 1_{\{z_{j,ℓ} = α'\}} \, q_{k,W,U}(x_{j,β}) \Big) μ[β].
Proof. Using the above notation, we rewrite Φ_{(α',β'),(k,ℓ)}(W, U) as follows:
Φ_{(α',β'),(k,ℓ)}(W, U) = \frac{1}{K} \sum_{j=1}^{K} \sum_{x ∈ X_j} 1_{\{x_ℓ = (α',β')\}} \big( 1_{\{j=k\}} − q_{k,W,U}(x) \big) D_{z_j}(x)
= \frac{1}{K} \sum_{j=1}^{K} \sum_{β ∈ [s_c]^L} 1_{\{(z_{j,ℓ}, β_ℓ) = (α',β')\}} \big( 1_{\{j=k\}} − q_{k,W,U}(x_{j,β}) \big) D_{z_j}(x_{j,β})
= \frac{1}{K} \sum_{j=1}^{K} \sum_{β ∈ [s_c]^L} 1_{\{z_{j,ℓ} = α'\}} 1_{\{β_ℓ = β'\}} \big( 1_{\{j=k\}} − q_{k,W,U}(x_{j,β}) \big) μ[β]
= \frac{1}{K} \sum_{β ∈ [s_c]^L} 1_{\{β_ℓ = β'\}} \Big( \sum_{j=1}^{K} 1_{\{z_{j,ℓ} = α'\}} 1_{\{j=k\}} − \sum_{j=1}^{K} 1_{\{z_{j,ℓ} = α'\}} q_{k,W,U}(x_{j,β}) \Big) μ[β].
To conclude the proof we simply remark that \sum_{j} 1_{\{z_{j,ℓ} = α'\}} 1_{\{j=k\}} = 1_{\{z_{k,ℓ} = α'\}}.
The following notation will be needed in our next lemma:
δ(α, α') = 1 if α = α', and δ(α, α') = −γ if α ≠ α', for all α, α' ∈ [n_c], (136)
where we recall that γ = 1/(n_c − 1). We think of δ(α, α') as a biased Kronecker delta on the concepts. Importantly, note that if f_1, …, f_{n_c} are equiangular, then ⟨f_α, f_{α'}⟩ = δ(α, α'), which is the motivation behind this definition. We may now state our second lemma.
Lemma U. Assume K = n_c^L and {z_1, …, z_K} = Z. Assume also that the weights (W, U) are in a type-III collapse configuration with constants c, r_1, …, r_{s_c} ≥ 0. Then
q_{k,W,U}(x_{j,β}) = \frac{ \prod_{ℓ=1}^{L} \exp\big( c \, r_{β_ℓ} δ(z_{j,ℓ}, z_{k,ℓ}) \big) }{ \prod_{ℓ=1}^{L} ψ(c \, r_{β_ℓ}) }, where ψ(x) = e^{x} + (n_c − 1) e^{−γ x},
for all j, k ∈ [K] and all β = (β_1, …, β_L) ∈ [s_c]^L.
Proof. Recalling that x_{j,β} := [(z_{j,1}, β_1), …, (z_{j,L}, β_L)], we obtain
\langle \hat{U}_k, W ζ(x_{j,β}) \rangle_F = \sum_{ℓ=1}^{L} \langle u_{k,ℓ}, w_{(z_{j,ℓ}, β_ℓ)} \rangle = \sum_{ℓ=1}^{L} \langle c f_{z_{k,ℓ}}, r_{β_ℓ} f_{z_{j,ℓ}} \rangle = c \sum_{ℓ=1}^{L} r_{β_ℓ} δ(z_{k,ℓ}, z_{j,ℓ}).
We then have
q_{k,W,U}(x_{j,β}) = \frac{ e^{\langle \hat{U}_k, W ζ(x_{j,β}) \rangle_F} }{ \sum_{k'=1}^{K} e^{\langle \hat{U}_{k'}, W ζ(x_{j,β}) \rangle_F} } = \frac{ \exp\big( c \sum_{ℓ} r_{β_ℓ} δ(z_{k,ℓ}, z_{j,ℓ}) \big) }{ \sum_{k'=1}^{K} \exp\big( c \sum_{ℓ} r_{β_ℓ} δ(z_{k',ℓ}, z_{j,ℓ}) \big) } = \frac{ \prod_{ℓ=1}^{L} \exp\big( c r_{β_ℓ} δ(z_{k,ℓ}, z_{j,ℓ}) \big) }{ \sum_{k'=1}^{K} \prod_{ℓ=1}^{L} \exp\big( c r_{β_ℓ} δ(z_{k',ℓ}, z_{j,ℓ}) \big) }.
Since {z_1, …, z_K} = Z, the latent variables z_{k'} = [z_{k',1}, …, z_{k',L}] achieve all possible tuples [α'_1, …, α'_L] ∈ [n_c]^L. The denominator can therefore be expressed as
\sum_{k'=1}^{K} \prod_{ℓ=1}^{L} \exp\big( c r_{β_ℓ} δ(z_{k',ℓ}, z_{j,ℓ}) \big) = \sum_{α'_1=1}^{n_c} ⋯ \sum_{α'_L=1}^{n_c} \exp\big( c r_{β_1} δ(α'_1, z_{j,1}) \big) ⋯ \exp\big( c r_{β_L} δ(α'_L, z_{j,L}) \big) = \prod_{ℓ=1}^{L} \sum_{α'=1}^{n_c} \exp\big( c r_{β_ℓ} δ(α', z_{j,ℓ}) \big).
Recalling the definition of δ(α, α'), we find that, for any concept α ∈ [n_c],
\sum_{α'=1}^{n_c} \exp\big( c r_{β_ℓ} δ(α', α) \big) = \exp(c r_{β_ℓ}) + (n_c − 1) \exp(−γ c r_{β_ℓ}) = ψ(c r_{β_ℓ}), (137)
which concludes the proof.
We now find a convenient expression for the term appearing between parentheses in the statement of lemma T.
Lemma V. Assume K = n_c^L and {z_1, …, z_K} = Z. Assume also that the weights (W, U) are in a type-III collapse configuration with constants c, r_1, …, r_{s_c} ≥ 0. Then
1_{\{z_{k,ℓ} = α'\}} − \sum_{j=1}^{K} 1_{\{z_{j,ℓ} = α'\}} \, q_{k,W,U}(x_{j,β}) = δ(z_{k,ℓ}, α') \, σ(c \, r_{β_ℓ}) (138)
for all k ∈ [K], ℓ ∈ [L], α' ∈ [n_c] and all β = (β_1, …, β_L) ∈ [s_c]^L.
Proof. For simplicity we are going to prove equation (138) in the case ℓ = 1. Using the previous lemma we obtain
\sum_{j=1}^{K} 1_{\{z_{j,1} = α'\}} q_{k,W,U}(x_{j,β}) = \sum_{j=1}^{K} 1_{\{z_{j,1} = α'\}} \frac{ \prod_{ℓ=1}^{L} \exp\big( c r_{β_ℓ} δ(z_{j,ℓ}, z_{k,ℓ}) \big) }{ \prod_{ℓ=1}^{L} ψ(c r_{β_ℓ}) }.
Since the latent variables z_j = [z_{j,1}, …, z_{j,L}] achieve all possible tuples [α_1, …, α_L] ∈ [n_c]^L, we can rewrite the above as
\sum_{α_1=1}^{n_c} ⋯ \sum_{α_L=1}^{n_c} 1_{\{α_1 = α'\}} \frac{ \prod_{ℓ=1}^{L} \exp\big( c r_{β_ℓ} δ(α_ℓ, z_{k,ℓ}) \big) }{ \prod_{ℓ=1}^{L} ψ(c r_{β_ℓ}) } = \frac{ \exp\big( c r_{β_1} δ(α', z_{k,1}) \big) }{ \prod_{ℓ=1}^{L} ψ(c r_{β_ℓ}) } \prod_{ℓ=2}^{L} \Big( \sum_{α_ℓ=1}^{n_c} \exp\big( c r_{β_ℓ} δ(α_ℓ, z_{k,ℓ}) \big) \Big). (139)
Repeating computation (137), we find that
\sum_{α_ℓ=1}^{n_c} \exp\big( c r_{β_ℓ} δ(α_ℓ, z_{k,ℓ}) \big) = ψ(c r_{β_ℓ}).
Going back to (139) we therefore have
\sum_{j=1}^{K} 1_{\{z_{j,1} = α'\}} q_{k,W,U}(x_{j,β}) = \frac{ \exp\big( c r_{β_1} δ(α', z_{k,1}) \big) }{ \prod_{ℓ=1}^{L} ψ(c r_{β_ℓ}) } \prod_{ℓ=2}^{L} ψ(c r_{β_ℓ}) = \frac{ \exp\big( c r_{β_1} δ(α', z_{k,1}) \big) }{ ψ(c r_{β_1}) },
and therefore
1_{\{z_{k,1} = α'\}} − \sum_{j=1}^{K} 1_{\{z_{j,1} = α'\}} q_{k,W,U}(x_{j,β}) = 1 − \frac{\exp(c r_{β_1})}{ψ(c r_{β_1})} if z_{k,1} = α', and − \frac{\exp(−γ c r_{β_1})}{ψ(c r_{β_1})} if z_{k,1} ≠ α'. (140)
We now manipulate the above formula.
Recalling that γ = 1/(n_c − 1) and recalling the definition of ψ(x), we get
1 − \frac{e^{x}}{ψ(x)} = \frac{(n_c − 1) e^{−γ x}}{e^{x} + (n_c − 1) e^{−γ x}} = \frac{1}{1 + γ e^{(1+γ) x}} = σ(x), (141)
\frac{e^{−γ x}}{ψ(x)} = \frac{e^{−γ x}}{e^{x} + (n_c − 1) e^{−γ x}} = γ \, \frac{1}{1 + γ e^{(1+γ) x}} = γ \, σ(x).
Plugging these two identities into (140) gives δ(z_{k,1}, α') σ(c r_{β_1}) in both cases, which concludes the proof.
Our last lemma provides a formula for the quantity Φ_{(α,β),(k,ℓ)}(W, U) when the weights are in a type-III collapse configuration.
Lemma W. Assume K = n_c^L and {z_1, …, z_K} = Z. Assume also that the weights (W, U) are in a type-III collapse configuration with constants c, r_1, …, r_{s_c} ≥ 0. Then
Φ_{(α,β),(k,ℓ)}(W, U) = \frac{μ_β}{n_c^L} \, σ(c r_β) \, δ(z_{k,ℓ}, α) (142)
for all k ∈ [K], ℓ ∈ [L], α ∈ [n_c] and β ∈ [s_c].
Proof. Combining lemmas T and V, and recalling that K = n_c^L, we obtain
Φ_{(α',β'),(k,ℓ)}(W, U) = \frac{1}{n_c^L} \sum_{β ∈ [s_c]^L} 1_{\{β_ℓ = β'\}} \Big( 1_{\{z_{k,ℓ}=α'\}} − \sum_{j=1}^{K} 1_{\{z_{j,ℓ}=α'\}} q_{k,W,U}(x_{j,β}) \Big) μ[β]
= \frac{1}{n_c^L} \sum_{β ∈ [s_c]^L} 1_{\{β_ℓ = β'\}} \, δ(z_{k,ℓ}, α') \, σ(c r_{β_ℓ}) \, μ[β]
= \frac{δ(z_{k,ℓ}, α')}{n_c^L} \sum_{β ∈ [s_c]^L} 1_{\{β_ℓ = β'\}} \, σ(c r_{β_ℓ}) \, μ[β].
Choosing ℓ = 1 for simplicity, we get
\sum_{β ∈ [s_c]^L} 1_{\{β_1 = β'\}} σ(c r_{β_1}) μ[β] = \sum_{β_1=1}^{s_c} ⋯ \sum_{β_L=1}^{s_c} 1_{\{β_1 = β'\}} σ(c r_{β_1}) μ[β_1] μ[β_2] ⋯ μ[β_L] = σ(c r_{β'}) μ[β'] \sum_{β_2=1}^{s_c} ⋯ \sum_{β_L=1}^{s_c} μ[β_2] ⋯ μ[β_L] = μ[β'] \, σ(c r_{β'})
(since the frequencies μ_β sum to one), which concludes the proof.
We now prove the proposition.
Proof of proposition C. Combining lemmas S and W, and using the fact that w_{(α,β)} = r_β f_α, we obtain
\frac{∂R_0}{∂u_{k,ℓ}}(W, U) = − \sum_{α=1}^{n_c} \sum_{β=1}^{s_c} Φ_{(α,β),(k,ℓ)}(W, U) \, w_{(α,β)} = − \sum_{α=1}^{n_c} \sum_{β=1}^{s_c} \frac{μ_β}{n_c^L} σ(c r_β) δ(z_{k,ℓ}, α) \, r_β f_α = − \frac{1}{n_c^L} \Big( \sum_{β=1}^{s_c} μ_β σ(c r_β) r_β \Big) \sum_{α=1}^{n_c} δ(z_{k,ℓ}, α) f_α.
Using the fact that \sum_{α=1}^{n_c} f_α = 0, we get
\sum_{α=1}^{n_c} δ(z_{k,ℓ}, α) f_α = f_{z_{k,ℓ}} − γ \sum_{α ≠ z_{k,ℓ}} f_α = f_{z_{k,ℓ}} + γ f_{z_{k,ℓ}} − γ \sum_{α=1}^{n_c} f_α = (1 + γ) f_{z_{k,ℓ}}. (143)
Using the fact that u_{k,ℓ} = c f_{z_{k,ℓ}}, we then get
\frac{∂R_0}{∂u_{k,ℓ}}(W, U) = − \frac{1}{n_c^L} \Big( \sum_{β=1}^{s_c} μ_β σ(c r_β) r_β \Big) (1 + γ) f_{z_{k,ℓ}} = − \frac{1 + γ}{c \, n_c^L} \Big( \sum_{β=1}^{s_c} μ_β σ(c r_β) r_β \Big) u_{k,ℓ},
which is the desired formula. Moving to the other gradient, we get
\frac{∂R_0}{∂w_{(α,β)}}(W, U) = − \sum_{k=1}^{K} \sum_{ℓ=1}^{L} Φ_{(α,β),(k,ℓ)}(W, U) \, u_{k,ℓ} = − \sum_{k=1}^{K} \sum_{ℓ=1}^{L} \frac{μ_β}{n_c^L} σ(c r_β) δ(z_{k,ℓ}, α) \, c f_{z_{k,ℓ}} = − \frac{μ_β}{n_c^L} σ(c r_β) \, c \sum_{ℓ=1}^{L} \sum_{k=1}^{K} δ(z_{k,ℓ}, α) f_{z_{k,ℓ}}.
Since the latent variables z_k = [z_{k,1}, …, z_{k,L}] achieve all possible tuples [α'_1, …, α'_L] ∈ [n_c]^L, we have, fixing ℓ = 1 for simplicity,
\sum_{k=1}^{K} δ(z_{k,1}, α) f_{z_{k,1}} = \sum_{α'_1=1}^{n_c} ⋯ \sum_{α'_L=1}^{n_c} δ(α'_1, α) f_{α'_1} = n_c^{L−1} \sum_{α'_1=1}^{n_c} δ(α'_1, α) f_{α'_1}. (144)
Repeating computation (143) shows that the above is equal to n_c^{L−1} (1 + γ) f_α. We then use the fact that w_{(α,β)} = r_β f_α to obtain
\frac{∂R_0}{∂w_{(α,β)}}(W, U) = − \frac{μ_β}{n_c^L} σ(c r_β) \, c \, L \, n_c^{L−1} (1 + γ) f_α = − \frac{μ_β}{n_c^L \, r_β} σ(c r_β) \, c \, L \, n_c^{L−1} (1 + γ) \, w_{(α,β)} = − \frac{μ_β}{n_c \, r_β} σ(c r_β) \, c \, L (1 + γ) \, w_{(α,β)},
which is the desired formula.
G.2 PROOF OF THE THEOREM AND STUDY OF THE NON-LINEAR SYSTEM
In this subsection we start by proving theorem 2, and then we show that the system (150)–(151) has a unique solution if the regularization parameter λ is small enough.
Proof of theorem 2. Recall that the regularized risk associated with the network h_{W,U} is defined by
R(W, U) = R_0(W, U) + \frac{λ}{2} \big( \|W\|_F^2 + \|U\|_F^2 \big) (145)
= R_0(W, U) + \frac{λ}{2} \Big( \sum_{(α,β) ∈ V} \|w_{(α,β)}\|^2 + \sum_{k=1}^{K} \sum_{ℓ=1}^{L} \|u_{k,ℓ}\|^2 \Big), (146)
and therefore (W, U) is a critical point if and only if
\frac{∂R_0}{∂u_{k,ℓ}}(W, U) = −λ \, u_{k,ℓ} and \frac{∂R_0}{∂w_{(α,β)}}(W, U) = −λ \, w_{(α,β)}
for all k, ℓ, α, β. According to proposition C, if (W, U) is in a type-III collapse configuration, then the above equations become
\frac{1 + γ}{c \, n_c^L} \Big( \sum_{β=1}^{s_c} μ_β σ(c r_β) r_β \Big) u_{k,ℓ} = λ \, u_{k,ℓ} and \frac{c L (1 + γ)}{n_c} \frac{μ_β σ(c r_β)}{r_β} \, w_{(α,β)} = λ \, w_{(α,β)}.
So (W, U) is critical if and only if the constants r_1, …, r_{s_c} and c satisfy the s_c + 1 equations
\frac{1 + γ}{c \, n_c^L} \sum_{β=1}^{s_c} μ_β σ(c r_β) r_β = λ, (147)
\frac{c L (1 + γ)}{n_c} \frac{μ_β σ(c r_β)}{r_β} = λ for all β ∈ [s_c]. (148)
From the second equation we have that
(1 + γ) μ_β σ(c r_β) r_β = \frac{n_c λ r_β^2}{L c}.
Using this we can rewrite the first equation as
\frac{1}{c \, n_c^L} \sum_{β=1}^{s_c} \frac{n_c λ r_β^2}{L c} = λ, which simplifies to \sum_{β=1}^{s_c} (r_β / c)^2 = L \, n_c^{L−1},
which is the desired equation (see (151)).
We now rewrite the second equation as
\frac{λ r_β}{c L} \, \frac{n_c}{(1 + γ) σ(c r_β)} = μ_β.
We then recall that σ(x) := \frac{1}{1 + γ e^{(1+γ) x}}, and therefore
\frac{n_c}{(1 + γ) σ(c r_β)} = \frac{n_c}{1 + γ} \big( 1 + γ e^{(1+γ) c r_β} \big) = n_c − 1 + \exp\Big( \frac{n_c}{n_c − 1} c r_β \Big),
so that the second equation can be written as
\frac{λ r_β}{c L} \Big( n_c − 1 + \exp\Big( \frac{n_c}{n_c − 1} c r_β \Big) \Big) = μ_β,
which is the desired equation (see (150)). This concludes the proof of theorem 2.
We now prove that if the regularization parameter λ is small enough then the system has a unique solution.
Proposition D. Assume μ_1 ≥ μ_2 ≥ … ≥ μ_{s_c} > 0 and
λ^2 < \frac{L}{n_c^{L+1}} \sum_{β=1}^{s_c} μ_β^2. (149)
Then the system of s_c + 1 equations
\frac{λ r_β}{c L} \Big( n_c − 1 + \exp\Big( \frac{n_c}{n_c − 1} c r_β \Big) \Big) = μ_β for all 1 ≤ β ≤ s_c, (150)
\sum_{β=1}^{s_c} (r_β / c)^2 = L \, n_c^{L−1}, (151)
has a unique solution (c, r_1, …, r_{s_c}) ∈ R_{+}^{s_c+1}. Moreover this solution satisfies r_1 ≥ r_2 ≥ … ≥ r_{s_c} > 0.
Proof. Letting ρ_β := r_β / c, the system is equivalent to
g(c, ρ_β) = \frac{L}{λ n_c} μ_β for all β ∈ [s_c], (152)
\sum_{β=1}^{s_c} ρ_β^2 = L \, n_c^{L−1}, (153)
for the unknowns (c, ρ_1, ρ_2, …, ρ_{s_c}), where
g(c, x) = x \big( 1 + γ e^{(1+γ) c^2 x} \big) / (1 + γ) and γ = 1/(n_c − 1).
Note that
\frac{∂g}{∂x}(c, x) ≥ \big( 1 + γ e^{(1+γ) c^2 x} \big) / (1 + γ) ≥ 1 for all (c, x) ∈ R × [0, +∞),
and therefore x ↦ g(c, x) is strictly increasing on [0, +∞). Also note that we have
g(c, 0) = 0 and \lim_{x → +∞} g(c, x) = +∞.
So x ↦ g(c, x) is a bijection from [0, +∞) to [0, +∞) as well as a bijection from (0, +∞) to (0, +∞). Recall that μ_β ∈ (0, +∞) for all β ∈ [s_c]. Therefore, given c ∈ R and β ∈ [s_c], the equation g(c, x) = \frac{L}{λ n_c} μ_β has a unique solution in (0, +∞) that we denote by ϕ_β(c). In other words, the function ϕ_β(c) is implicitly defined by
g(c, ϕ_β(c)) = \frac{L}{λ n_c} μ_β. (154)
Also, since g(0, x) = x, we have ϕ_β(0) = \frac{L}{λ n_c} μ_β.
Claim F. The function ϕ_β : [0, +∞) → (0, +∞) is continuous, strictly decreasing, and satisfies \lim_{c → +∞} ϕ_β(c) = 0.
Proof. We first show that c ↦ ϕ_β(c) is continuous. Since \frac{∂g}{∂x}(c, x) ≥ 1 for all x ≥ 0, we have
g(c, x_2) − g(c, x_1) = \int_{x_1}^{x_2} \frac{∂g}{∂x}(c, x) \, dx ≥ \int_{x_1}^{x_2} 1 \, dx = x_2 − x_1
for all c and all x_2 ≥ x_1 ≥ 0. As a consequence, for all c_1, c_2, we have
|ϕ_β(c_2) − ϕ_β(c_1)| ≤ |g(c_1, ϕ_β(c_2)) − g(c_1, ϕ_β(c_1))| = |g(c_1, ϕ_β(c_2)) − g(c_2, ϕ_β(c_2))|, (155)
where we have used the fact that g(c_1, ϕ_β(c_1)) = \frac{L}{λ n_c} μ_β = g(c_2, ϕ_β(c_2)). From (155) it is clear that the continuity of c ↦ g(c, x) implies the continuity of c ↦ ϕ_β(c).
We now prove that ϕ_β is strictly decreasing on [0, +∞). Let 0 ≤ c_1 < c_2. Note that for any x > 0, the function c ↦ g(c, x) is strictly increasing on [0, +∞). Since ϕ_β(c) > 0 we therefore have
g(c_2, ϕ_β(c_2)) = \frac{L}{λ n_c} μ_β = g(c_1, ϕ_β(c_1)) < g(c_2, ϕ_β(c_1)).
Since x ↦ g(c, x) is strictly increasing for all c, the above implies that ϕ_β(c_2) < ϕ_β(c_1).
Finally we show that \lim_{c → +∞} ϕ_β(c) = 0. Since ϕ_β is decreasing and non-negative on [0, +∞), the limit \lim_{c → +∞} ϕ_β(c) = A is well defined. We obviously have ϕ_β(c) ≥ A for all c ≥ 0. Since x ↦ g(c, x) is increasing, we have
\frac{L}{λ n_c} μ_β = g(c, ϕ_β(c)) ≥ g(c, A).
But the function c ↦ g(c, A) is unbounded for all A > 0. Therefore we must have A = 0.
System (152)–(153) is equivalent to
ρ_β = ϕ_β(c) for all β ∈ [s_c], (156)
\sum_{β=1}^{s_c} (ϕ_β(c))^2 = L \, n_c^{L−1}. (157)
Define the function
Φ(c) := \sum_{β=1}^{s_c} (ϕ_β(c))^2.
Then Φ clearly inherits the properties of the ϕ_β's: it is continuous, strictly decreasing, and satisfies
Φ(0) = \Big( \frac{L}{λ n_c} \Big)^2 \sum_{β=1}^{s_c} μ_β^2 and \lim_{c → +∞} Φ(c) = 0.
Therefore, if λ^2 < \frac{L}{n_c^{L+1}} \sum_{β=1}^{s_c} μ_β^2, then Φ(0) > L \, n_c^{L−1} and there is a unique c > 0 satisfying (157). Since x ↦ g(c, x) is increasing, equation (152) implies that the corresponding ρ_β's satisfy ρ_1 ≥ ρ_2 ≥ … ≥ ρ_{s_c} > 0.
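Proposition D guarantees existence and uniqueness but gives no closed form for (c, r_1, …, r_{s_c}). Following the monotonicity properties established in the proof, the solution can be computed by nested bisection. The sketch below is a minimal implementation of this idea for the system (150)–(151) as written above; the sizes n_c, L, s_c, the frequencies mu, and the value of λ are illustrative choices.

import numpy as np

# Illustrative sizes and frequencies; tolerances and bracketing are ad hoc.
n_c, L, s_c = 3, 4, 3
mu = np.array([0.5, 0.3, 0.2])             # mu_1 >= ... >= mu_{s_c} > 0
lam = 0.01
gamma = 1.0 / (n_c - 1)
assert lam**2 < L / n_c**(L + 1) * np.sum(mu**2)     # condition (149)

def g(c, x):
    return x * (1.0 + gamma * np.exp((1.0 + gamma) * c**2 * x)) / (1.0 + gamma)

def phi_beta(c, target):
    # Unique x >= 0 with g(c, x) = target; g is increasing in x and g(c, x) >= x,
    # so the solution lies in [0, target].
    lo, hi = 0.0, target
    for _ in range(200):
        mid = 0.5 * (lo + hi)
        lo, hi = (mid, hi) if g(c, mid) < target else (lo, mid)
    return 0.5 * (lo + hi)

def Phi(c):
    return sum(phi_beta(c, L * m / (lam * n_c))**2 for m in mu)

# Solve Phi(c) = L n_c^{L-1}: Phi is decreasing and Phi(0) > target under (149).
target = L * n_c**(L - 1)
lo, hi = 0.0, 1.0
while Phi(hi) > target:
    hi *= 2.0
for _ in range(200):
    mid = 0.5 * (lo + hi)
    lo, hi = (mid, hi) if Phi(mid) > target else (lo, mid)
c = 0.5 * (lo + hi)

rho = np.array([phi_beta(c, L * m / (lam * n_c)) for m in mu])
r = c * rho
res_150 = lam * r / (c * L) * (n_c - 1 + np.exp(n_c / (n_c - 1) * c * r)) - mu
print(c, r)                                 # r_1 >= r_2 >= r_3 > 0
print(res_150, rho @ rho - target)          # residuals of (150) and (151), both near zero

The printed residuals of (150) and (151) are close to zero, and the computed r_β inherit the ordering r_1 ≥ ⋯ ≥ r_{s_c} > 0, as the proposition predicts.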
H NO SPURIOUS LOCAL MINIMIZER FOR R(W, U)
In this section we prove that if d > min(n_w, KL), then R(W, U) does not have spurious local minimizers; all local minimizers are global. To do this, we introduce the function f : R^{n_w × KL} → R defined as follows. Any matrix V ∈ R^{n_w × KL} can be partitioned into K submatrices V_k ∈ R^{n_w × L} according to
V = [V_1 \; V_2 \; ⋯ \; V_K], where V_k ∈ R^{n_w × L}. (158)
The function f is then defined by the formula
f(V) := \frac{1}{K} \sum_{k=1}^{K} \sum_{x ∈ X_k} ℓ\big( \langle V_1, ζ(x) \rangle_F, …, \langle V_K, ζ(x) \rangle_F ; k \big) D_{z_k}(x),
where ℓ(y_1, …, y_K; k) denotes the cross entropy loss
ℓ(y_1, …, y_K; k) = − \log\Big( \frac{\exp(y_k)}{\sum_{k'=1}^{K} \exp(y_{k'})} \Big).
We remark that f is clearly convex and differentiable. We then recall from (19) that the k-th entry of the vector y = h_{W,U}(x) is
y_k = \langle \hat{U}_k, W ζ(x) \rangle_F = \langle W^T \hat{U}_k, ζ(x) \rangle_F.
Recalling that \hat{U} = [\hat{U}_1 \; ⋯ \; \hat{U}_K], we then see that the risk can be expressed as
R(W, U) = f(W^T \hat{U}) + \frac{λ}{2} \big( \|W\|_F^2 + \|\hat{U}\|_F^2 \big). (159)
The fact that R(W, U) does not have spurious local minimizers comes from the following general theorem.
Theorem G. Let g : R^{m×n} → R be a convex and differentiable function. Define
φ(A, B) := g(A^T B) + \frac{λ}{2} \big( \|A\|_F^2 + \|B\|_F^2 \big), where A ∈ R^{d×m} and B ∈ R^{d×n},
and assume λ > 0 and d > min(m, n). Then any local minimizer (A, B) of the function φ : R^{d×m} × R^{d×n} → R is also a global minimizer.
The above theorem directly applies to (159) and shows that the risk R(W, U) does not have spurious local minimizers when λ > 0 and d > min(n_w, KL). The remainder of the section is devoted to the proof of theorem G. We follow the exact same steps as in Zhu et al. (2021), and provide the proof mostly for completeness (and also to show how the techniques from Zhu et al. (2021) apply to our case). Finally, we refer to Laurent & Brecht (2018) for a proof of theorem G in the case λ = 0.
Proof of theorem G. To prove the theorem it suffices to consider the case d > m. To see this, note that the function \bar{g}(D) := g(D^T) is also convex and differentiable, and note that (A, B) is a local minimum of
(A, B) ↦ g(A^T B) + \frac{λ}{2} \big( \|A\|_F^2 + \|B\|_F^2 \big)
if and only if (B, A) is a local minimum of
(B, A) ↦ \bar{g}(B^T A) + \frac{λ}{2} \big( \|A\|_F^2 + \|B\|_F^2 \big).
So the theorem in the case d > n follows by appealing to the case d > m with the function \bar{g}. We may therefore assume d > m.
Following Zhu et al. (2021), we define the function ψ : R^{m×n} → R by
ψ(D) := g(D) + λ \|D\|_{*},
where \|D\|_{*} denotes the nuclear norm of D. We then have:
Claim G. For all A ∈ R^{d×m} and B ∈ R^{d×n}, we have that ψ(A^T B) ≤ φ(A, B).
Proof. This is a direct consequence of the inequality \|A^T B\|_{*} ≤ \frac{1}{2} \big( \|A\|_F^2 + \|B\|_F^2 \big), which we reprove here for completeness. Let A^T B = U Σ V^T be the compact SVD of A^T B. That is, Σ ∈ R^{r×r}, U ∈ R^{m×r}, V ∈ R^{n×r}, and r is the rank of A^T B. We then have
\|A^T B\|_{*} = Tr(Σ) = Tr(U^T A^T B V) = \langle A U, B V \rangle_F ≤ \frac{1}{2} \big( \|A U\|_F^2 + \|B V\|_F^2 \big) ≤ \frac{1}{2} \big( \|A\|_F^2 + \|B\|_F^2 \big).
Computing the derivatives of φ gives
\frac{∂φ}{∂A}(A, B) = B \big[ ∇g(A^T B) \big]^T + λ A and \frac{∂φ}{∂B}(A, B) = A \, ∇g(A^T B) + λ B. (160)
So (A, B) is a critical point of φ if and only if
λ A = − B \big[ ∇g(A^T B) \big]^T, (161)
λ B = − A \, ∇g(A^T B). (162)
Importantly, from the above we get
A A^T = B B^T ∈ R^{d×d}, (163)
which implies that A and B have the same singular values and the same left singular vectors. Let U ∈ R^{d×d} be an orthonormal matrix containing the eigenvectors of A A^T = B B^T. From this matrix we can construct an SVD for both A and B:
A = U Σ_A V_A^T and B = U Σ_B V_B^T,
where Σ_A ∈ R^{d×m} and Σ_B ∈ R^{d×n} have the same singular values. From this we get the SVD of A^T B,
A^T B = V_A Σ_A^T Σ_B V_B^T, (164)
and it is transparent that
\|A^T B\|_{*} = \|A\|_F^2 = \|B\|_F^2. (165)
In particular this implies that if (A, B) is a critical point of φ, then we must have φ(A, B) = ψ(A^T B). This also implies that
\langle ∇g(A^T B), A^T B \rangle_F = \langle A ∇g(A^T B), B \rangle_F = −λ \|B\|_F^2 = −λ \|A^T B\|_{*}. (166)
Using this together with the fact that the nuclear norm is the dual of the operator norm, that is \|C\|_{*} = \sup_{\|G\|_{op} ≤ 1} \langle G, C \rangle_F, we easily obtain the claim below.
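Both ingredients of this step, the inequality behind Claim G and the duality between the nuclear and operator norms, are straightforward to confirm numerically; the following minimal check uses random matrices of arbitrary sizes.

import numpy as np

rng = np.random.default_rng(0)
d, m, n = 7, 4, 5
A, B = rng.standard_normal((d, m)), rng.standard_normal((d, n))
C = rng.standard_normal((m, n))

nuc = lambda M: np.linalg.svd(M, compute_uv=False).sum()     # nuclear norm

# Inequality behind Claim G: ||A^T B||_* <= (||A||_F^2 + ||B||_F^2) / 2.
assert nuc(A.T @ B) <= 0.5 * (np.linalg.norm(A)**2 + np.linalg.norm(B)**2) + 1e-9

# Duality with the operator norm: sup_{||G||_op <= 1} <G, C>_F = ||C||_*, attained at G = U V^T.
U, S, Vt = np.linalg.svd(C, full_matrices=False)
G = U @ Vt                                                   # operator norm equal to 1
assert np.isclose(np.sum(G * C), nuc(C))
print("nuclear-norm inequality and duality verified")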
Claim H. Suppose (A, B) is a critical point of φ which satisfies \|∇g(A^T B)\|_{op} ≤ λ. Then (A, B) is a global minimizer of φ.
Proof. For any matrix C ∈ R^{m×n}, the convexity of g together with (166) gives
ψ(C) = g(C) + λ \|C\|_{*} ≥ g(A^T B) + \langle ∇g(A^T B), C − A^T B \rangle_F + λ \|C\|_{*} = g(A^T B) + λ \|A^T B\|_{*} + \langle ∇g(A^T B), C \rangle_F + λ \|C\|_{*}.
Moreover, since \|∇g(A^T B)\|_{op} ≤ λ, the duality between the nuclear and operator norms gives
− \frac{1}{λ} \langle ∇g(A^T B), C \rangle_F ≤ \sup_{\|G\|_{op} ≤ 1} \langle G, C \rangle_F = \|C\|_{*},
and therefore ψ(C) ≥ g(A^T B) + λ \|A^T B\|_{*} = ψ(A^T B). This implies that A^T B is a global minimizer of ψ. The fact that φ(A, B) = ψ(A^T B) (because (A, B) is a critical point of φ), together with Claim G, then implies that (A, B) is a global minimizer of φ.
We now show that all local minimizers (A, B) of φ with ker(A^T) ≠ {0} must satisfy \|∇g(A^T B)\|_{op} ≤ λ.
Claim I. Suppose (A, B) is a critical point of φ which satisfies
(i) ker(A^T) ≠ {0},
(ii) \|∇g(A^T B)\|_{op} > λ.
Then (A, B) is not a local minimizer.
Proof. We follow the computation from Zhu et al. (2021). Let (A, B) be a critical point of φ. Since A A^T = B B^T, we must have that ker(A^T) = ker(A A^T) = ker(B B^T) = ker(B^T). According to (i) these kernels are non-trivial, and we may choose a unit vector z ∈ R^d that belongs to them. We then consider the perturbations
dA = z a^T and dB = z b^T,
where a ∈ R^m and b ∈ R^n are unit vectors to be chosen later. Note that since z, a and b are unit vectors we have \|dA\|_F^2 = \|dB\|_F^2 = 1. Moreover, the columns of dA and dB are clearly in the kernels of A^T and B^T, and therefore A^T dA = A^T dB = B^T dA = B^T dB = 0. This implies that all the cross terms disappear when expanding the expression:
(A + ε dA)^T (B + ε dB) = A^T B + ε^2 dA^T dB = A^T B + ε^2 a b^T.
We also have
\|A + ε dA\|_F^2 = \|A\|_F^2 + ε^2 \|dA\|_F^2 = \|A\|_F^2 + ε^2,
and similarly \|B + ε dB\|_F^2 = \|B\|_F^2 + ε^2. We then get
φ(A + ε dA, B + ε dB) = g\big( (A + ε dA)^T (B + ε dB) \big) + \frac{λ}{2} \big( \|A + ε dA\|_F^2 + \|B + ε dB\|_F^2 \big)
= g(A^T B + ε^2 a b^T) + \frac{λ}{2} \big( \|A\|_F^2 + \|B\|_F^2 \big) + λ ε^2
= \Big[ g(A^T B) + \langle ∇g(A^T B), ε^2 a b^T \rangle_F + O(ε^4) \Big] + \frac{λ}{2} \big( \|A\|_F^2 + \|B\|_F^2 \big) + λ ε^2
= φ(A, B) + ε^2 \Big( \langle ∇g(A^T B), a b^T \rangle_F + λ \Big) + O(ε^4).
Let G = ∇g(A^T B) ∈ R^{m×n}. We want to choose the unit vectors a and b that make \langle G, a b^T \rangle_F as negative as possible. The best choice is to take b to be the first right singular vector of G and a to be the negative of the first left singular vector, since this gives the negative of the best rank-one approximation of G. With this choice G b = −σ_1 a, and therefore
\langle a b^T, G \rangle_F = Tr(b a^T G) = a^T G b = −σ_1 = −\|G\|_{op},
which gives
φ(A + ε dA, B + ε dB) = φ(A, B) + ε^2 \big( λ − \|∇g(A^T B)\|_{op} \big) + O(ε^4),
and (ii) implies that (A, B) is not a local minimizer.
Combining the previous two claims we can easily prove the theorem. Indeed, if d > m, then the kernel of A^T is non-trivial and (i) is always satisfied. As a consequence, if (A, B) is a local minimizer of φ, then \|∇g(A^T B)\|_{op} ≤ λ, and therefore (A, B) must be a global minimizer of φ.
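Theorem G can be illustrated numerically. For a toy convex function g (a quadratic chosen for convenience here, not the risk of the paper), the global minimum value of ψ(D) = g(D) + λ‖D‖_* is available in closed form via singular-value soft-thresholding, and plain gradient descent on the factored objective φ(A, B), started from a small random initialization with d > min(m, n), reaches essentially the same value. The step size, iteration count and dimensions below are illustrative choices; this is a sketch, not an experiment from the paper.

import numpy as np

# Toy convex g (an assumed example, not the risk of the paper):
# g(D) = 0.5 ||D - D0||_F^2, so that min_D g(D) + lam ||D||_* has a closed form
# given by soft-thresholding the singular values of D0.
rng = np.random.default_rng(1)
m, n, d, lam = 5, 6, 7, 0.5                                  # d > min(m, n)
D0 = rng.standard_normal((m, n))

U, S, Vt = np.linalg.svd(D0, full_matrices=False)
D_star = U @ np.diag(np.maximum(S - lam, 0.0)) @ Vt          # global minimizer of psi
psi_star = 0.5 * np.sum((D_star - D0)**2) + lam * np.linalg.svd(D_star, compute_uv=False).sum()

# Plain gradient descent on phi(A, B) = g(A^T B) + (lam/2)(||A||_F^2 + ||B||_F^2).
A = 0.1 * rng.standard_normal((d, m))
B = 0.1 * rng.standard_normal((d, n))
step = 0.02
for _ in range(30000):
    G = A.T @ B - D0                                         # gradient of g at A^T B
    A, B = A - step * (B @ G.T + lam * A), B - step * (A @ G + lam * B)

phi_final = 0.5 * np.sum((A.T @ B - D0)**2) + 0.5 * lam * (np.sum(A**2) + np.sum(B**2))
print(phi_final, psi_star)     # the two values should approximately agree when d > min(m, n)

If d were smaller than the rank of the soft-thresholded matrix, the factored problem could no longer attain this value, and the two printed numbers would generally differ.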