# Memory Mosaics

Published as a conference paper at ICLR 2025

Jianyu Zhang, Niklas Nolte, Ranajoy Sadhukhan, Beidi Chen, Léon Bottou
FAIR, Meta; Carnegie Mellon University; New York University

Memory Mosaics are networks of associative memories working in concert to achieve a prediction task of interest. Like transformers, Memory Mosaics possess compositional capabilities and in-context learning capabilities. Unlike transformers, Memory Mosaics achieve these capabilities in a comparatively transparent way (predictive disentanglement). We illustrate these capabilities on a toy example and also show that Memory Mosaics perform as well as or better than transformers on medium-scale language modeling tasks.

1 Introduction

This paper presents a learning system architecture, Memory Mosaics, in which multiple associative memories work in concert to carry out a prediction task of interest. Such systems are closely related to memory networks (Weston et al., 2014; Sukhbaatar et al., 2015) and resemble transformers (Vaswani et al., 2017) despite significant differences. Like transformers, Memory Mosaics possess some of the disentanglement and compositional capabilities that have long eluded machine learning systems (Lake & Baroni, 2018). Unlike transformers, whose internal mechanisms are hard to decipher (Olsson et al., 2022; Bietti et al., 2024), Memory Mosaics achieve these capabilities in comparatively transparent ways. The three main contributions of this work are (a) defining an architecture that exploits the direct similarity between self-attention and associative memories implemented with kernel regression, (b) identifying and illustrating the predictive disentanglement principle, which explains how training decomposes the overall task in interesting ways, and (c) showing that this comparatively transparent architecture matches the i.i.d. performance of decoding transformers on a language modeling task, and outperforms them on o.o.d.
tasks such as in-context learning. Section 2 reviews related work. Section 3 describes simple associative memory units that can be inserted in a deep network. Section 4 explains how training such a network splits a prediction task into disentangled sub-tasks. Section 5 illustrates this predictive disentanglement using a network with only 54 parameters, showing that this is not a mysterious effect of scale but a property of the architecture. Section 6 extends these ideas to fully formed Memory Mosaics. Section 7 reports on medium-scale language modeling experiments.

2 Related Work

Several recent papers (e.g., Katharopoulos et al., 2020; Peng et al., 2023; Sun et al., 2023; Gu & Dao, 2023) propose transformer alternatives that use efficient recurrences to cut the quadratic computational cost of transformers. Closer to our interests, other authors (e.g., Ramsauer et al., 2020; Krotov, 2023; Hoover et al., 2024) rethink transformers with Hopfield-style associative memories and their associated energy function. In contrast, we leverage elementary associative memories that interpolate stored key/value pairs with a kernel regression (therefore incurring a quadratic runtime cost) in order to construct an architecture that remains very close to standard transformers but casts new light on properties that play an important role in their compositional learning capabilities. Closely related to predictive disentanglement, Bengio et al. (2019) propose a meta-learning training objective that achieves causal disentanglement by seeking quick adaptation to new distributions. We argue that a similar effect happens in our architecture, as a consequence of the normal training process interpreted as a meta-learning process, revealing an important aspect of the still mysterious compositional learning abilities of transformer-like architectures.
Figure 1: Elementary memory unit. The keys $k_T$ are computed as a function of past observations $(x_t)_{t \le T}$. The values $v_T$ peek into the future. In this example, the values also depend on the next observation $x_{T+1}$. At time $T$, the associative memory uses the known key $k_T$ to compute an estimate $y_T$ of $\mathbb{E}(v_T \mid k_T)$ using only the previously stored pairs $(k_t, v_t)$, $t < T$. One time step later, the input $x_{T+1}$ is revealed, the value $v_T$ can be computed, and the pair $(k_T, v_T)$ is added to the memory.

Associative memory Generally speaking, an associative memory is a device that can store key-value pairs and retrieve values given a corresponding key. This definition omits important details about dealing with duplicate keys and approximate matches. For our purposes, both keys and values shall be vectors in $\mathbb{R}^d$. The retrieval process can then be represented as a function of the queried key $k$ and all the stored pairs $(k_1, v_1) \dots (k_n, v_n)$:

$$\mathbb{R}^d \to \mathbb{R}^d \,, \qquad k \mapsto f\big(k;\, \{(k_1, v_1) \dots (k_n, v_n)\}\big)$$

Except perhaps when duplicate keys are involved, an associative memory stores key-value pairs without consideration for their temporal ordering. Therefore the retrieval function can be assumed invariant with respect to any permutation of the stored pairs. This exchangeability property suggests that we can also view an associative memory as a device that estimates a conditional probability distribution $P(V \mid K)$ on the basis of the sample $(k_1, v_1) \dots (k_n, v_n)$ of key-value pairs. The retrieval function is then a conditional expectation over this estimated distribution:

$$f\big(k;\, \{(k_1, v_1) \dots (k_n, v_n)\}\big) = \mathbb{E}(V \mid K = k) \,. \tag{1}$$

Such a conditional expectation can be constructed with Gaussian kernel regression,¹

$$f\big(k;\, \{(k_1, v_1) \dots (k_n, v_n)\}\big) = \frac{1}{Z} \sum_{i=1}^{n} e^{-\beta \|k - k_i\|^2}\, v_i \qquad \text{with} \quad Z = \sum_{i=1}^{n} e^{-\beta \|k - k_i\|^2} \,. \tag{2}$$

The close connection between this Gaussian kernel smoothing and attention (Bahdanau et al., 2015) is obvious when all key vectors $k_i$ share a same squared norm, because expression (2) becomes

$$f\big(k;\, \{(k_1, v_1) \dots (k_n, v_n)\}\big) = \sum_{i=1}^{n} \frac{e^{\beta\, k^\top k_i}}{\sum_{j=1}^{n} e^{\beta\, k^\top k_j}}\, v_i \,. \tag{3}$$

There are of course more advantageous ways to implement associative memories. Although some will certainly prove useful in the future, this paper only relies on associative memories implemented with Gaussian kernel smoothing, not least because that makes it easy to compute gradients.

Predicting with associative memories Consider now a sequence $(x_t)$ of observations, discrete tokens or continuous values. We would like to leverage the past observations $(x_t)_{t \le T}$ to predict some useful property of the future observations $(x_t)_{t > T}$. For instance, we might want to predict the next observation $x_{T+1}$ to construct an auto-regressive model of the sequence.

¹Expression (2) is known as the Nadaraya-Watson estimator (Nadaraya, 1964; Watson, 1964). It is known to converge to the true conditional expectation $\mathbb{E}(V \mid K)$ when $n \to \infty$ and the bandwidth parameter $\beta$ grows suitably with $n$.

Our elementary memory unit (Figure 1) consists of an associative memory and a trainable feature extractor that computes suitable keys and values for the memory. The keys $k_T$ are computed as a function of the past observations $(x_t)_{t \le T}$ and trainable weights $w$,

$$k_T = \phi(x_T, x_{T-1}, \dots;\, w) \,. \tag{4}$$

In contrast, the values $v_T$ are allowed to peek into the future because they represent what the memory module aims to predict. For instance, the systems described in this paper merely allow values to depend on the next observation $x_{T+1}$,

$$v_T = \psi(x_{T+1}, x_T, x_{T-1}, \dots;\, w) \,. \tag{5}$$

The memory units operate independently at inference time.
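The retrieval rule (2) is easy to check in a few lines of code. The following is a minimal NumPy sketch, not the paper's implementation; the keys and values are made-up toy data. It computes the Gaussian-kernel weights and returns the weighted average of the stored values.

```python
import numpy as np

def retrieve(k, keys, values, beta=1.0):
    """Gaussian kernel retrieval, expression (2): a normalized
    similarity-weighted average of the stored values."""
    # squared distances between the query and every stored key
    d2 = np.sum((keys - k) ** 2, axis=1)
    w = np.exp(-beta * d2)
    w = w / w.sum()          # division by the normalization constant Z
    return w @ values        # estimate of E(V | K = k)

# Toy memory: two nearby keys mapping to value 1.0, one far key to -1.0.
keys = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0]])
values = np.array([[1.0], [1.0], [-1.0]])
y = retrieve(np.array([0.05, 0.0]), keys, values, beta=10.0)
# y[0] is close to 1.0: the query matches the first two stored keys
```

Increasing $\beta$ sharpens the estimate toward the nearest stored value; decreasing it smooths the estimate over more neighbors, which is exactly the bandwidth trade-off of kernel regression.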
They start empty at the beginning of each input sequence. At time step $T$, each memory receives a key vector $k_T$ computed from the recent inputs $(x_T, x_{T-1}, \dots)$ and interpolates a response $y_T$ on the basis of the previously stored key/value pairs. The value $v_T$ is computed one time step later, when the next input $x_{T+1}$ is revealed, and the pair $(k_T, v_T)$ is added to the memory. Although the value $v_T$ depends on the near future, the output $y_T$ does not depend on $v_T$ but merely leverages the previously stored key/value pairs to estimate $v_T$. Therefore there is no leak of future information: each memory unit is a little machine that predicts a bit of future information (described by $v_T$) on the basis of recent information (described by $k_T$) and previously stored key/value pairs.

The exact form of the feature extraction functions can vary in complexity. For instance, when each observation $x_T$ carries sufficient information, the keys $k_T$ and values $v_T$ can be computed as linear functions of respectively $x_T$ and $x_{T+1}$, that is $k_T = W_\phi\, x_T$ and $v_T = W_\psi\, x_{T+1}$. However, we find it useful to consider feature extraction functions that summarize the recent past using short convolutions or quickly vanishing leaky averages. For instance, the language experiments of Section 7 use feature extractors of the following form:²

$$k_T = \mathrm{Norm}\big(\bar{k}_T\big) \quad \text{with} \quad \underbrace{\bar{k}_T = \tilde{k}_T + \lambda_\phi\, \bar{k}_{T-1}}_{\text{leaky average over } t = T,\, T-1,\, \dots,\, 1} \,, \quad \tilde{k}_T = W_\phi\, x_T$$
$$v_T = \mathrm{Norm}\big(\bar{v}_T\big) \quad \text{with} \quad \underbrace{\bar{v}_T = \tilde{v}_T + \lambda_\psi\, \tilde{v}_{T+1}}_{\text{convolution over } t = T \text{ and } T+1} \,, \quad \tilde{v}_T = W_\psi\, x_T \tag{6}$$

Since this expression produces keys with unit norm ($\mathrm{Norm}(x) = x / \|x\|$), the effective kernel bandwidth is determined by the trainable parameter $\beta$ in equation (3).

Training networks of memory units Consider now a deep network whose architecture includes layers of associative memory units.
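The inference-time loop just described can be sketched as a small Python class. This is our own illustrative reading of Figure 1 under simplifying assumptions: linear extractors (the $k_T = W_\phi x_T$, $v_T = W_\psi x_{T+1}$ case), a leaky-average key state, and the retrieval rule of equation (3). Names such as `MemoryUnit`, `W_phi`, and `lam` are ours, not the paper's.

```python
import numpy as np

class MemoryUnit:
    """Sketch of the elementary memory unit: keys summarize the
    recent past, values peek one step into the future."""
    def __init__(self, W_phi, W_psi, lam=0.5, beta=10.0):
        self.W_phi, self.W_psi = W_phi, W_psi
        self.lam, self.beta = lam, beta
        self.keys, self.values = [], []
        self.k_bar = None    # leaky-average key state
        self.pending = None  # key k_T waiting for its value v_T

    def _norm(self, x):
        return x / (np.linalg.norm(x) + 1e-12)

    def step(self, x_t):
        # x_t reveals the value for the previous step's key;
        # only now is the pair (k_T, v_T) stored.
        if self.pending is not None:
            self.keys.append(self.pending)
            self.values.append(self.W_psi @ x_t)
        # Current key: leaky average of linearly extracted features.
        k_tilde = self.W_phi @ x_t
        self.k_bar = k_tilde if self.k_bar is None else k_tilde + self.lam * self.k_bar
        k = self._norm(self.k_bar)
        self.pending = k
        # Retrieval: kernel-smoothed estimate of the upcoming value, eq. (3).
        if not self.keys:                      # memory starts empty
            return np.zeros(self.W_psi.shape[0])
        K, V = np.array(self.keys), np.array(self.values)
        w = np.exp(self.beta * (K @ k))
        return (w / w.sum()) @ V

# Feeding an alternating sequence a, b, a, b, ... : after a few steps,
# the unit predicts the next observation from the stored pairs.
a, b = np.array([1.0, 0.0]), np.array([0.0, 1.0])
m = MemoryUnit(np.eye(2), np.eye(2), lam=0.0, beta=10.0)
for x in [a, b, a, b, a, b, a]:
    m.step(x)
y = m.step(b)   # y is close to a: "b is followed by a"
```

Note that the output never uses the value of the current step, so no future information leaks: the pair is only stored once the next input arrives.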
When the associative memories are implemented with differentiable kernel smoothing mechanisms, training such a deep network is simply a matter of unrolling the network in time and back-propagating the gradients, in ways that users of modern deep learning software will find very familiar. Unsurprisingly, unrolling equation (3) along an input sequence $(x_1 \dots x_D)$ of duration $D$ yields an expression that very much resembles masked self-attention (Vaswani et al., 2017):

$$\forall\, T \in \{1 \dots D\} \qquad y_T = \sum_{i=1}^{T-1} \frac{e^{\beta\, k_T^\top k_i}}{\sum_{j=1}^{T-1} e^{\beta\, k_T^\top k_j}}\, v_i \,. \tag{7}$$

Implementing associative memories with kernel smoothing therefore provides a particularly direct illustration of the connection between self-attention and associative memories (e.g., Ramsauer et al., 2020). However, Memory Mosaics differ because the value extraction function is allowed to peek into the near future of the input time series $(x_t)$. This slight change has important consequences. Each memory unit operates as a little predictor whose outputs $y_T$ can be interpreted as a conditional expectation (1) that estimates features of the near future ($v_T$) of the input time series on the basis of its past observations ($k_T$). The parameters of the value extraction function ($\psi$) specify what is being predicted, and the parameters of the key extraction function ($\phi$) specify how it is predicted.

²The leaky average in expression (6) is far too simple to effectively encode long-range dependencies, as demonstrated in (Voelker et al., 2019; Peng et al., 2023; Gu & Dao, 2023).

Figure 2: The curve plots the prediction losses for all indices $t \in \{1 \dots D\}$ in the training sequence. Minimizing their sum (the area under the curve) favors memories that produce useful value estimates after fewer time steps.

Equation (7) must therefore account for the number of future time steps needed to compute $v_T$. In our experiments, for example, $v_T$ can look one step ahead in the future.
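The unrolled expression (7) can be checked mechanically: looping over positions and taking the softmax over strictly earlier pairs reproduces the shifted attention mask. A minimal NumPy sketch of our own, assuming unit-norm keys and the dot-product form of equation (3):

```python
import numpy as np

def unrolled_memory(K, V, beta=10.0):
    """Equation (7): at position T, attend only over stored pairs
    (k_i, v_i) with i <= T-1; position 0 has an empty memory."""
    D = K.shape[0]
    Y = np.zeros_like(V, dtype=float)
    S = np.exp(beta * (K @ K.T))        # pairwise key similarities
    for T in range(1, D):
        w = S[T, :T]                    # strictly-past scores only
        Y[T] = (w / w.sum()) @ V[:T]    # kernel-smoothed value estimate
    return Y

# Alternating keys a, b with values equal to the *next* observation:
K = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]])
V = np.array([[0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [1.0, 0.0]])
Y = unrolled_memory(K, V)
# Y[3] is close to [1, 0]: after key b, the memory predicts value a
```

The `[:T]` slice is exactly the "more aggressive" mask discussed below equation (7): because each value peeks one step ahead, the diagonal itself must be excluded.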
This amounts to having a more aggressive attention mask. Therefore the main diagonal must be excluded from the attention mask, justifying the $T-1$ upper bound in the sum.³ Because each memory unit acts as a predictor, a single layer of memory units is sufficient to address the induction head problem of Bietti et al. (2024). In contrast, a decoding transformer needs at least two self-attention layers for the same task. Equation (7) makes no provision for position encoding and no distinction between query and key vectors. In other words, we are betting that these transformer complications are no longer needed, because our associative memory units do not need them to implement induction heads.

4 Predictive Disentanglement

Training and meta-learning The training process determines which future bit of information is predicted by each associative memory unit (through the parameters that control the computation of the values $v_T$) and which kernels are used to perform the predictions (through the parameters that control the computation of the keys $k_T$). In contrast, the relation between keys and predicted values is determined for each input sequence at inference time, through the memorization of key/value pairs specific to each sequence. The training procedure should therefore be seen as a meta-learning process, distinct from the memory-based learning that occurs at inference time when new key/value pairs are added into the memories.

Predictive disentanglement This meta-learning interpretation reveals a remarkable phenomenon that we call predictive disentanglement: the gradient training algorithm splits the overall prediction task (e.g., predicting the next token in a natural language sentence) into disentangled prediction sub-tasks assigned to each memory unit. Consider a training set composed of long enough sequences $(x_1, \dots, x_D)$ extracted from underlying time series governed by possibly different stationary processes.
The goal of our network is to predict each $x_{T+1}$ using the previous observations $x_1 \dots x_T$. Unrolling the network in time along each sequence $(x_1 \dots x_D)$ and collecting the prediction losses measured at each position $t$ can be summarized by a curve that shows the prediction cost (or loss) at each time step $1 \dots D$, as illustrated in Figure 2. We can expect that the prediction cost observed at position $T$ becomes smaller when $T$ increases, because more information $(x_1 \dots x_T)$ is available to predict each $x_{T+1}$. The training process minimizes the total prediction cost, that is, the area under the curve in Figure 2 viewed as a collection of vertical slices. We can also view this area as a collection of horizontal slices, each representing the context length required to drive the prediction cost below a certain threshold. Therefore the training process can also be viewed as minimizing the context length needed to produce good enough predictions. Because the associative memory retrieval function (2) is known to converge to stationary conditional expectations $\mathbb{E}(V \mid K)$, each memory unit is driven to produce a good conditional expectation estimate as soon as possible. This can be achieved in two ways.

Let us first assume that each memory unit has a frozen value extraction function $\psi$. The training procedure can still make each memory unit statistically more efficient by tuning the parameters of the key extraction function $\phi$, that is, by learning how to compare the current prediction context $(x_T, x_{T-1}, x_{T-2}, \dots)$ with past prediction contexts $(x_t, x_{t-1}, x_{t-2}, \dots)$ for $t < T$. Learning a similarity metric (a kernel) is a well-known way to make non-parametric estimators more efficient (e.g., Bach et al., 2004).

³One could of course use a more aggressive masking to allow $v_T$ to peek several time steps into the future.
For instance, the training procedure can construct keys that summarize the relevant contextual information, discarding noise factors that could increase the distance between keys associated with similar values. It can also adjust the effective kernel bandwidth, for instance using the parameter $\beta$ in equation (7).

When multiple memory units are available, the training procedure can also distribute the overall prediction task among the available memory units. As long as the memory unit outputs can still be combined to address the overall task, the training algorithm can optimize the parameters of the value extraction functions $\psi$ to produce values $v_T$ that are more efficiently modeled by their respective memory units. Because each memory unit operates independently at inference time, this works best when the overall prediction task is disentangled into smaller prediction sub-tasks that can be modeled independently and efficiently. More precisely, the sub-tasks must be chosen so that each memory can carry out its assigned modeling task at inference time without having to account for the combined impact of the operation of all memory units. Their outputs can then be recombined to provide predictions for inputs that are globally very different from the training inputs, but whose disentangled components are individually predictable, as illustrated in Section 5.

Disentanglement has long been recognized as desirable (Bengio, 2013) but has been hard to pinpoint (Comon, 1994; Roth et al., 2022; Thomas et al., 2018). Predictive disentanglement is closely related to the meta-transfer objective of Bengio et al. (2019) but arises as a side effect of a specific predictive architecture trained with the usual gradient procedure. Although predictive disentanglement is easier to understand in the case of a network of associative memory units, we conjecture that something similar also occurs in standard transformers.
5 Tracking three moons

We give an illustrative example of predictive disentanglement: three moons orbit a remote planet. Although the local astronomers are very far from understanding celestial mechanics,⁴ they nevertheless observe periodic motions and debate how to predict future moon positions. A first astronomer proposes to compile a single table containing the daily positions of all three moons, arguing that if the current set of moon positions matches a previous observation, the future moon positions will match the following observations. A second astronomer suggests instead making three tables, one for each moon, arguing that the future positions of each moon can be independently predicted by matching its current position with a previously observed one.

To make reliable predictions, the first astronomer needs a table that contains at least one record for each of the possible moon configurations. Our astronomer therefore needs to log the daily moon positions until all three moons return to their original configuration, after a number of days equal to the least common multiple $\mathrm{lcm}(p_1, p_2, p_3)$ of the individual moon periods. In contrast, the second astronomer only needs to log daily moon positions until each of the moons returns to a previously observed position, that is, for a number of days equal to the period $\max(p_1, p_2, p_3)$ of the slowest moon.

One could argue that the proposal of the second astronomer is obviously superior because the three moons are distinct objects, well separated in space and time. One could instead argue that we view the moons as separate objects precisely because their respective futures can in general be independently predicted. Space and time separation merely suggests the possibility of independent predictions, as long as the moons do not collide.

Model For our purposes, each observation $x_t$ consists of three complex numbers $e^{i\theta_k}$ that encode the angular positions $\theta_k$ of the three moons inside their respective orbital planes.
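The two bookkeeping costs can be compared directly. With hypothetical integer periods (illustrative values of our choosing; the text does not fix them), the joint table grows with the least common multiple while the per-moon tables grow only with the slowest period:

```python
from math import lcm

# Hypothetical moon periods, in days (illustrative values only)
p1, p2, p3 = 16, 21, 50

joint_days    = lcm(p1, p2, p3)  # first astronomer: one joint table
per_moon_days = max(p1, p2, p3)  # second astronomer: three tables

print(joint_days, per_moon_days)  # 8400 50
```

Even for these modest periods, the disentangled bookkeeping needs 168 times fewer observations; the gap widens quickly as the periods grow or become coprime.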
We consider two single-layer models (Figure 3) with either $N_h = 1$ or $N_h = 3$ memory units whose added dimensions match the input dimension. The trainable parameters of the linear key and value extraction are collected in two $3 \times 3$ complex matrices $W_\phi$ and $W_\psi$. The memory units follow equation (3) with a fixed parameter $\beta = 50$. A third $3 \times 3$ complex matrix $W_z$ combines the memory unit predictions into an output $z_T$ that hopefully predicts $x_{T+1}$. Both networks share an interesting analytic solu-

⁴We do not seek to discuss subtleties such as elliptical orbits or multi-body problems. Our primitive astronomers are best compared to the ancient sky watchers whose efforts eventually gave the Ptolemaic model.

Figure 3 (excerpt): with $N_h = 1$ or $3$,
$$\mathrm{Stack}_{h=1 \dots N_h}\big[k^{(h)}_T\big] = W_\phi\, x_T \,, \quad W_\phi \in \mathbb{C}^{3 \times 3} \qquad \mathrm{Stack}_{h=1 \dots N_h}\big[v^{(h)}_T\big] = W_\psi\, x_{T+1} \,, \quad W_\psi \in \mathbb{C}^{3 \times 3}$$
$$y^{(h)}_t = \frac{1}{Z_T}