Published as a conference paper at ICLR 2025

# WASSERSTEIN DISTANCES, NEURONAL ENTANGLEMENT, AND SPARSITY

Shashata Sawmya¹, Linghao Kong¹, Ilia Markov², Dan Alistarh²,³,⁴, & Nir Shavit¹,³,⁴
¹MIT ²IST Austria ³Neural Magic ⁴Red Hat
{shashata, linghao, shanir}@mit.edu, {ilia.markov, dan.alistarh}@ist.ac.at

## ABSTRACT

Disentangling polysemantic neurons is at the core of many current approaches to interpretability of large language models. Here we attempt to study how disentanglement can be used to understand performance, particularly under weight sparsity, a leading post-training optimization technique. We suggest a novel measure for estimating neuronal entanglement: the Wasserstein distance of a neuron's output distribution to a Gaussian. Moreover, we show the existence of a small number of highly entangled "Wasserstein neurons" in each linear layer of an LLM, characterized by their highly non-Gaussian output distributions, their role in mapping similar inputs to dissimilar outputs, and their significant impact on model accuracy. To study these phenomena, we propose a new experimental framework for disentangling polysemantic neurons. Our framework separates each layer's inputs to create a mixture of experts in which each neuron's output is computed by a mixture of neurons with lower Wasserstein distance, each of which better maintains accuracy when sparsified without retraining. We provide strong evidence that this is because the mixture of sparse experts effectively disentangles the input-output relationship of individual neurons, in particular the difficult Wasserstein neurons.

## 1 INTRODUCTION

Disentangling polysemantic neurons into their component, human-understandable features has been a longstanding goal of machine learning interpretability research (Olah et al., 2020; Jermyn et al., 2022; Elhage et al., 2022; Gurnee et al., 2023; Templeton, 2024; Gurnee et al., 2024).
While neurons are the basic building blocks of neural network architectures, they do not map one-to-one with specific features. Instead, neurons frequently engage in polysemantic representations, where they are activated by multiple, unrelated concepts and detect diverse features (Arora et al., 2018; Mu & Andreas, 2020). It is suspected that every neuron is polysemantic to some degree (Lecomte et al., 2023), so we will refer to all neurons as polysemantic in this work. Given the importance of highly polysemantic neurons in a network's computation (Bricken et al., 2023), the question naturally arises of whether these neurons require more parameters. However, the effects of polysemanticity on network performance under weight sparsity have not been well explored. Weight sparsification (Hoefler et al., 2021) aims to reduce the number of executed parameters in large language models (LLMs) by setting certain weight values to zero to improve efficiency. Various sparsification algorithms have been developed for this process (Han et al., 2015; Sun et al., 2023; Frantar & Alistarh, 2023). This paper investigates the relationship between an individual neuron's degree of entanglement (which we formally define in a later section) and its ability to be sparsified in real-world models. To the best of our knowledge, this is the first work to explore this crucial perspective of entanglement-dependent model sparsification. To better understand the impact of entanglement on sparsification, we introduce a novel metric that quantifies a neuron's degree of entanglement. This metric is the Wasserstein distance between a neuron's output distribution and a Gaussian (Equation 1).

*Equal contribution. Author order determined by coin toss.
¹Code available at https://github.com/Shavit-Lab/Sparse-Expansion.
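To make the sparsification setting concrete, the snippet below implements one-shot magnitude pruning, the simplest baseline in this family of algorithms. This is an illustrative sketch, not any of the cited methods; the function name `magnitude_prune` is ours, and input-aware algorithms such as SparseGPT additionally use input statistics to decide which weights to remove.

```python
import numpy as np

def magnitude_prune(W: np.ndarray, sparsity: float) -> np.ndarray:
    """Zero out the smallest-magnitude fraction `sparsity` of weights.

    A minimal one-shot baseline: the pruned layer keeps only the
    largest-magnitude weights, with no retraining.
    """
    # Threshold below which weights are considered removable.
    thresh = np.quantile(np.abs(W), sparsity)
    return np.where(np.abs(W) > thresh, W, 0.0)
```

For example, `magnitude_prune(W, 0.9)` returns a copy of `W` with roughly 90% of its entries set to zero, keeping the 10% with the largest magnitudes.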
We find that neurons with a particularly high Wasserstein distance (Figure 1d, A8d) are crucial to the performance of a network and very sensitive to pruning. We provide evidence that a neuron's Wasserstein distance is related to its ability to map similar inputs to different outputs through its dot product, and we refer to these neurons as especially entangled (Equation 2). Akin to previous works investigating special types of neurons (Gurnee et al., 2023; Stolfo et al., 2024; Gurnee et al., 2024), this work explores the role of crucial neurons with implications for interpretability, specifically in the context of network sparsity.

Figure 1: The output distributions of neurons in Llama-2-7B computed densely and at 90% sparsity on Wikitext-2. WD refers to the Wasserstein distance of the output distribution to a Gaussian; RI refers to the relative improvement of Sparse Expansion over SparseGPT. (a) The dense output distribution of a random neuron with a WD of 0.050 is well captured by SparseGPT, and (b) expanding this neuron via Sparse Expansion imparts only a small (18%) increase in performance. (c) The cluster outputs are all concentrated in close proximity to each other. (d) SparseGPT struggles to capture the dense distribution of an entangled neuron with a WD of 0.524. (e) Following expansion, the sparse output of the entangled neuron is much better captured, leading to a larger improvement (77%). (f) Each expert specializes over a different portion of the distribution.

To analyze the phenomenon of neuronal superposition under sparsity in greater detail, we create an experimental framework, which we dub Sparse Expansion. It expands a model into a mixture of sparse experts by clustering input embeddings layer-wise. Based on this clustering, Sparse Expansion utilizes the input-aware nature of the SparseGPT (Frantar & Alistarh, 2023) pruning algorithm to specialize different sparse experts to different sets of inputs, starting from the same base weights.
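The pipeline just described can be sketched in a few lines. This is our own simplification, not the paper's implementation: inputs are clustered with k-means, and each expert is pruned with a magnitude-times-input-norm saliency score that stands in for SparseGPT's Hessian-based, input-aware criterion. The function `sparse_expansion` and its parameters are illustrative names.

```python
import numpy as np
from scipy.cluster.vq import kmeans2

def sparse_expansion(W, X, k=4, sparsity=0.9, seed=0):
    """Sketch of Sparse Expansion: cluster the layer's inputs, then
    prune a separate copy of W for each cluster of inputs.

    W: (n, m) weight matrix; X: (m, s) inputs, one column per vector.
    Returns k sparse experts (all derived from the same base W) and
    the cluster label of each input column.
    """
    # Cluster the input vectors (columns of X) into k groups.
    _, labels = kmeans2(X.T, k, minit='++', seed=seed)
    experts = []
    for c in range(k):
        Xc = X[:, labels == c]
        # Saliency proxy: weight magnitude scaled by how active each
        # input dimension is within this cluster (stand-in for the
        # Hessian-based SparseGPT criterion).
        score = np.abs(W) * np.linalg.norm(Xc, axis=1)
        thresh = np.quantile(score, sparsity)
        experts.append(np.where(score > thresh, W, 0.0))
    return experts, labels
```

At inference time, each input would be routed to the expert of its nearest cluster centroid, so different subsets of inputs are computed with different remaining edges.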
Through Sparse Expansion, we are able to analyze the entangled neurons in much more detail, since different subgroups of the inputs are now computed with different edges (Figure 1f, A8f). We find that as a neuron loses edges, its output distribution tends to shift toward a Gaussian (Figure A9). However, through Sparse Expansion, the original output distribution can be better preserved under sparse computation (Figure 1e, A8e). We relate our findings to recent theoretical work on the bounds of neural computation under superposition (Hänni et al., 2024; Adler & Shavit, 2024). Our main technical contribution is a detailed study of how model accuracy under sparsity is related to its degree of neuronal entanglement. In every LLM, there exist neurons that have striking, irregular output distributions (Figure 2c, A1). These neurons have an outsized effect on model performance and seem to be responsible for differentiating similar input vectors (Figure 2). We believe that the existence of these neurons is a manifestation of polysemanticity in real-world language models. We find that the Wasserstein distance to a Gaussian is a strong indicator of such neurons. In the next section, we explain these "Wasserstein neurons", neuronal entanglement, and the implications of ablating Wasserstein neurons in LLMs in detail. We then formulate our experimental framework, Sparse Expansion, and show how it effectively disentangles the input-output relationship of neurons, along with some empirical computational bounds. Finally, we present results on its performance relative to other state-of-the-art one-shot compression techniques, in the hopes of inspiring future sparsification algorithms.

## 2 WASSERSTEIN NEURONS

### 2.1 CHARACTERIZING NON-GAUSSIAN NEURONAL OUTPUT DISTRIBUTIONS

We investigate the output distributions of individual neurons in all linear layers of transformer feedforward networks (FFNs) during inference.
Specifically, consider a linear operation $Y = WX + b$, where $Y \in \mathbb{R}^{n \times s}$ is the output matrix, $W \in \mathbb{R}^{n \times m}$ is the weight matrix, $b \in \mathbb{R}^{n}$ is the bias vector, broadcast across all columns, and $X \in \mathbb{R}^{m \times s}$ is the input matrix, where each column represents an input vector. Each neuron is an individual row of $W$, and we collect the individual scalar elements from the corresponding row of $Y$ as the output distribution for that neuron. We focus our analysis on Pythia-1.4B (Biderman et al., 2023), Llama-2-7B (Touvron et al., 2023), and Llama-3-8B (Dubey et al., 2024). Most neurons exhibit a reasonably Gaussian output distribution after their dot product with the input vector (Figure 1a, 2a). However, we find a small group of neurons with highly non-Gaussian outputs (Figure 1d, 2c) in all FFNs (Figure A1). To characterize the difference in shape between the non-Gaussian output distributions of these neurons and the Gaussian-like output distributions of most neurons, we considered several metrics, such as entropy. However, the Wasserstein distance (WD) (Kantorovich, 2006; Villani et al., 2009) proved to be the most effective metric for quantifying this difference. In optimal transport theory, the WD measures the minimal transportation cost between two distributions, taking their geometry in real space into account. To find the WD of every neuron to the Gaussian $N$, we crucially first normalize the output distribution of each neuron $n$ to have zero mean and unit variance, and compare this normalized distribution $n'$ to $N(0, 1)$. This normalization is performed because the range of neuron output distributions is quite variable, and we want to prioritize differences in the shape of the distributions rather than in other properties. We use the 1-Wasserstein distance in one dimension, as shown in Equation 1:

$$W_1(n', N) = \int_0^1 \left| F^{-1}(z) - \phi^{-1}(z) \right| \, dz. \tag{1}$$
Here $F^{-1}$ and $\phi^{-1}$ are the inverse cumulative distribution functions of $n'$ and $N(0, 1)$, respectively, which can be approximated from empirical data. To compute the WD of every neuron efficiently, we use the SciPy implementation (Virtanen et al., 2020). When computing the difference metric in this way, we find that our originally observed neurons (Figure 1d, A8d) are indeed assigned a high WD to $N$. We thus term these neurons Wasserstein neurons. We also observe little overlap between neurons with high mean weight magnitudes and Wasserstein neurons (Figure A4a). We additionally analyze Pythia-1.4B across its training, from network initialization to the final step. We find that Wasserstein neurons do not seem to receive more weight updates than other neurons (Figure A2a). Interestingly, we also find that Wasserstein neurons arise relatively early in training, within 10-20 billion tokens (Figure A2b). This phenomenon is likely related to, and a manifestation of, other observations that fundamental model training dynamics rapidly stabilize, such as the rank of the gradient or the largest eigenvalue of the loss Hessian (Gur-Ari et al., 2018; Zhao et al., 2024; Noci et al., 2024). We leave further investigation of this crucial training period to future work.

### 2.2 WASSERSTEIN NEURONS AND ENTANGLEMENT

Here, we define and study the notion of entanglement of these Wasserstein neurons in greater detail by positing a new avenue for investigating entanglement. According to superposition theory, as the number of features increases relative to the number of neurons, features are forced to become non-orthogonal in order to represent more of them, thus increasing entanglement (Elhage et al., 2022). Consider neurons that must attend to several of these features. As the number of features increases, and different features are forced to become more similar in direction, such neurons must still manage to distinguish between them.
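For concreteness, Equation 1 can be approximated directly from a neuron's empirical outputs with SciPy's `wasserstein_distance`, which compares two samples via their empirical quantile functions. In this sketch (the helper name `neuron_wd` and the sampled-Gaussian reference are our choices, not necessarily the paper's exact procedure), the standard normal is represented by a large reference sample:

```python
import numpy as np
from scipy.stats import wasserstein_distance

def neuron_wd(outputs: np.ndarray, n_ref: int = 100_000, seed: int = 0) -> float:
    """Approximate Equation 1: the 1-Wasserstein distance between a
    neuron's normalized output distribution n' and N(0, 1)."""
    # Normalize to zero mean and unit variance so only shape matters.
    z = (outputs - outputs.mean()) / (outputs.std() + 1e-12)
    # Represent N(0, 1) by a large empirical reference sample.
    ref = np.random.default_rng(seed).standard_normal(n_ref)
    return wasserstein_distance(z, ref)
```

A neuron whose outputs are already near-Gaussian scores close to zero, while a strongly bimodal output distribution, like those of the Wasserstein neurons above, scores much higher.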
Therefore, in this context, highly entangled neurons have the task of differentiating between similar input vectors and mapping them to different output values. To explore this concept mathematically, we study the input-output (IO) relationship of individual neurons. We introduce the metric mapping difficulty (MD), which measures how often a neuron must generate dissimilar outputs from similar inputs through its dot product computation. The MD for a particular neuron, given its weights and a set of inputs, is calculated as follows (Equation 2): MD(w, X) = mean 1 i