# Attention as Implicit Structural Inference

Ryan Singh
School of Engineering and Informatics, University of Sussex. rs773@sussex.ac.uk

Christopher L. Buckley
School of Engineering and Informatics, University of Sussex. VERSES AI Research Lab, Los Angeles, CA, USA.

Abstract

Attention mechanisms play a crucial role in cognitive systems by allowing them to flexibly allocate cognitive resources. Transformers, in particular, have become a dominant architecture in machine learning, with attention as their central innovation. However, the underlying intuition and formalism of attention in Transformers is based on ideas of keys and queries in database management systems. In this work, we pursue a structural inference perspective, building upon, and bringing together, previous theoretical descriptions of attention such as Gaussian Mixture Models, alignment mechanisms, and Hopfield Networks. Specifically, we demonstrate that attention can be viewed as inference over an implicitly defined set of possible adjacency structures in a graphical model, revealing the generality of such a mechanism. This perspective unifies different attentional architectures in machine learning and suggests potential modifications and generalizations of attention. Here we investigate two and demonstrate their behaviour on explanatory toy problems: (a) extending the value function to incorporate more nodes of a graphical model, yielding a mechanism with a bias toward attending to multiple tokens; (b) introducing a geometric prior (with conjugate hyper-prior) over the adjacency structures, producing a mechanism which dynamically scales the context window depending on input. Moreover, by describing a link between structural inference and precision regulation in Predictive Coding Networks, we discuss how this framework can bridge the gap between attentional mechanisms in machine learning and Bayesian conceptions of attention in Neuroscience. We hope that by providing a new lens on attention architectures, our work can guide the development of new and improved attentional mechanisms.

1 Introduction

Designing neural network architectures with favourable inductive biases lies behind many recent successes in Deep Learning. The Transformer, and in particular the attention mechanism, has allowed language models to achieve human-like generation abilities previously thought impossible [1, 2]. The success of the attention mechanism as a domain-agnostic architecture has prompted adoption across a diverse range of tasks beyond language modelling, notably reaching state-of-the-art performance in visual reasoning and segmentation tasks [3, 4]. This depth and breadth of success indicates the attention mechanism expresses a useful computational primitive. Recent work has shown interesting theoretical links to kernel methods [5, 6, 7], Hopfield networks [8], and Gaussian mixture models [9, 10, 11, 12, 13]; however, a formal understanding that captures the generality of this computation remains outstanding. In this paper, we show the attention mechanism can naturally be described as inference on the structure of a graphical model, agreeing with observations that transformers are able to flexibly choose between models based on context [14, 15]. This Bayesian perspective complements previous theory [16, 8, 12], adding new methods for reasoning about inductive biases and the functional role of attention variables.
Further, understanding the core computation as inference permits a unified description of multiple attention mechanisms in the literature, as well as narrowing the explanatory gap to ideas in neuroscience. This paper proceeds in three parts. First, in Sec. 3, we show that soft attention mechanisms (e.g. self-attention, cross-attention, graph attention; which we call transformer attention hereafter) can be understood as taking an expectation over possible connectivity structures, providing an interesting link between softmax-based attention and marginal likelihood. Second, in Sec. 4, we extend the inference over connectivity to a Bayesian setting which, in turn, provides a theoretical grounding for iterative attention mechanisms (slot attention and block-slot attention) [17, 18, 19], Modern Continuous Hopfield Networks [8] and Predictive Coding Networks. Finally, in Sec. 5, we leverage the generality of this description in order to design new mechanisms with predictable properties.

Intuitively, the attention matrix can be seen as the posterior distribution over edges $E$ in a graph $G = (K \cup Q, E)$ consisting of a set of query and key nodes $Q, K$, each of dimension $d$, where the full mechanism computes an expectation of a function defined on the graph, $V : G \to \mathbb{R}^{d \times |G|}$, with respect to this posterior:

$$\mathrm{Attention}(Q, K, V) = \overbrace{\mathrm{softmax}(Q W_Q W_K^T K^T)}^{p(E \mid Q, K)}\, V = \mathbb{E}_{p(E \mid Q, K)}[V]$$

Crucially, when $G$ is seen as a graphical model, the posterior over edges becomes an inference about dependency structure and the functional form becomes natural. This formalism provides an alternate Bayesian theoretical framing within which to understand attention models, shifting the explanation from one centred around retrieval to one that is fundamentally concerned with in-context inference of probabilistic relationships (including retrieval). Within this framework, different attention architectures can be described by considering different implicit probabilistic models; by making these explicit, we hope to support more effective analysis and the development of new architectures.

2 Related Work

A key benefit of the perspective outlined here is to tie together different approaches taken in the literature. Specifically, structural variables can be seen as the alignment variables discussed in previous Bayesian descriptions [16, 20, 21]; on the other hand, Gaussian Mixture Models (GMMs) can be seen as a specific instance of the framework developed here. This description maintains the explanatory power of GMMs by constraining the alignment variables to be the edges of an implicit graphical model, while offering the increased flexibility of alignment approaches to describe multiple forms of attention.

Latent alignment and Bayesian attention. Several attempts have been made to combine the benefits of soft (differentiable) and stochastic attention, often viewing attention as a probabilistic alignment problem. Most approaches proceed by sampling, e.g., using the REINFORCE estimator [20] or a top-K approximation [22]. Two notable exceptions are [16], which embeds an inference algorithm within the forward pass of a neural network, and [21], which employs the re-parameterisation trick for the alignment variables. In this work, rather than treating attention weights as an independent learning problem, we aim to provide a parsimonious implicit model that would give rise to the attention weights.
We additionally show that soft attention weights arise naturally in variational inference, from either collapsed variational inference or a mean-field approximation.

Relationship to Gaussian mixture models. Previous works that have taken a probabilistic perspective on the attention mechanism note the connection to inference in a Gaussian mixture model [11, 10, 12, 13]. Indeed, [23] directly show the connection between the Hopfield energy and the variational free energy of a Gaussian mixture model. Although Gaussian mixture models, a special case of the framework we present here, are enough to explain cross-attention, they do not capture slot or self-attention, obscuring the generality underlying attention mechanisms. In contrast, the description presented here extends to structural inductive biases beyond what can be expressed in a Gaussian mixture model, additionally offering a route to describing the whole transformer block.

Attention as bi-level optimisation. Mapping a feed-forward architecture to a minimisation step on a related energy function has been called unfolded optimisation [24]. Taking this perspective can lead to insights about the inductive biases involved for each architecture. It has been shown that the cross-attention mechanism can be viewed as an optimisation step on the energy function of a form of Hopfield Network [8], providing a link between attention and associative memory, while [25] extend this view to account for self-attention. Our framework distinguishes Hopfield attention, which does not allow an arbitrary value matrix, from transformer attention. Although there remains a strong theoretical connection, we interpret the Hopfield energy as an instance of variational free energy, aligning more closely with iterative attention mechanisms such as slot attention.

3 Transformer Attention

3.1 Attention as Expectation

We begin by demonstrating that transformer attention can be seen as calculating an expectation over graph structures. Specifically, let $x = (x_1, \dots, x_n)$ be observed input variables, $\phi$ be some set of discrete latent variables representing edges in a graphical model of $x$ given by $p(x \mid \phi)$, and $y$ a variable we need to predict. Our goal is to find $\mathbb{E}_{y|x}[y]$; however, the graph structure $\phi$ is unobserved, so we calculate the marginal likelihood,

$$\mathbb{E}_{y|x}[y] = \sum_\phi p(\phi \mid x)\,\mathbb{E}_{y|x,\phi}[y].$$

Importantly, the softmax function is a natural representation for the posterior,

$$p(\phi \mid x) = \frac{p(x, \phi)}{\sum_{\phi'} p(x, \phi')} = \mathrm{softmax}(\ln p(x, \phi)).$$

In order to expose the link to transformer attention, let the model of $y$ given the graph $(x, \phi)$ be parameterised by a function $\mathbb{E}_{y|x,\phi}[y] = v(x, \phi)$. Then

$$\mathbb{E}_{y|x}[y] = \sum_\phi \mathrm{softmax}(\ln p(x, \phi))\,v(x, \phi) = \mathbb{E}_{\phi|x}[v(x, \phi)] \quad (1)$$

In general, transformer attention can be seen as weighting $v(x, \phi)$ by the posterior distribution $p(\phi \mid x)$ over different graph structures. We show Eq. 1 is exactly the equation underlying self- and cross-attention by presenting the specific generative models corresponding to them.
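As an illustrative aside (not part of the original presentation), Eq. 1 can be evaluated directly for a toy model in which each candidate structure $\phi$ attaches a single query to one of $n$ observed nodes; the projection matrices and the value function below are arbitrary stand-ins for the quantities defined above.

```python
import numpy as np

# Minimal sketch of Eq. 1 (an illustrative toy, not the paper's code):
# attention as an expectation of a value function v(x, phi) under the
# posterior p(phi | x) = softmax(ln p(x, phi)) over candidate structures phi.

rng = np.random.default_rng(0)
d, n = 4, 5
x = rng.normal(size=(n, d))          # observed nodes x_1..x_n

# Assumed ingredients: one log-joint ln p(x, phi) per candidate structure,
# and a value function v; here phi indexes which node x_j is linked to a query.
W = rng.normal(size=(d, d))          # plays the role of W_Q^T W_K
W_V = rng.normal(size=(d, d))        # parameterises v
query = rng.normal(size=d)

log_joint = np.array([query @ W @ x[j] for j in range(n)])   # ln p(x, phi=j), up to a constant
posterior = np.exp(log_joint - log_joint.max())
posterior /= posterior.sum()                                  # softmax(ln p(x, phi))

values = np.stack([W_V @ x[j] for j in range(n)])             # v(x, phi=j)
expectation = posterior @ values                              # E_{phi|x}[v(x, phi)]  (Eq. 1)
print(expectation)
```

The point of the sketch is only that the softmax weights form a posterior over structures and the output is the corresponding expectation of $v$.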
In this description the latent variables $\phi$ are identified as edges between observed variables $x$ (keys and queries) in a pairwise Markov Random Field, parameterised by matrices $W_K$ and $W_Q$, while the function $v$ is parameterised by $W_V$. Pairwise Markov Random Fields are a natural tool for modelling the dependencies of random variables, with prominent examples including Ising models (Boltzmann Machines) and multivariate Gaussians. While they are typically defined given a known structure, the problem of inferring the latent graph is commonly called structural inference.

Formally, given a set of random variables $X = (X_v)_{v \in V}$ with probability distribution $p$ and a graph $G = (V, E)$, the variables form a pairwise Markov Random Field (pMRF) [26] with respect to $G$ if the joint density function $P(X = x) = p(x)$ factorises as

$$p(x) = \frac{1}{Z} \exp\Big(\sum_{v \in V} \psi_v(x_v) + \sum_{e = (u,v) \in E} \psi_{u,v}(x_u, x_v)\Big),$$

where $Z$ is the partition function and $\psi_v(x_v)$ and $\psi_e = \psi_{u,v}(x_u, x_v)$ are known as the node and edge potentials respectively. Bayesian structural inference also requires a structural prior $p(\phi)$ over the space of possible adjacency structures, $\phi \in \Phi$, of the underlying graph.

Factorisation. Without constraints this space grows exponentially in the number of nodes ($2^{|V|}$ possible graphs, leading to intractable softmax calculations), so all the models we explore here implicitly assume a factorised prior (additionally placing zero probability mass on much of the space, for example disconnected graphs). We briefly remark that Eq. 1 respects factorisation of $p$ in the following sense: if the distribution admits a factorisation (a partition of the space of graphs $\Phi = \prod_i \Phi_i$) with respect to the latent variables, $p(x, \phi) = \prod_i e^{f_i(x, \phi_i)}$ where $\phi_i \in \Phi_i$, and the value function distributes over the same partition of edges, $v(x, \phi) = \sum_i v_i(x, \phi_i)$, then each of the factors can be marginalised independently:

$$\mathbb{E}_{\phi|x}[v(x, \phi)] = \sum_i \mathbb{E}_{\phi_i|x}[v_i] \quad (2)$$

To recover cross-attention and self-attention we need to specify the structural prior, the potential functions and a value function. (In order to ease notation, when $\Phi_i$ is a set of edges involving a common node $x_i$, such that $\phi_i = (x_i, x_j)$ represents a single edge, we use the notation $\phi_i = [j]$, suppressing the shared index.)

3.2 Cross Attention and Self Attention

We first define the model that gives rise to cross-attention:

- Key nodes $K = (x_1, \dots, x_n)$ and query nodes $Q = (x'_1, \dots, x'_m)$.
- Structural prior $p(\phi) = \prod_{i=1}^m p(\phi_i)$, where $\Phi_i = \{(x_1, x'_i), \dots, (x_n, x'_i)\}$ is the set of edges involving $x'_i$ and $\phi_i \sim \mathrm{Uniform}(\Phi_i)$, such that each query node is uniformly likely to connect to each key node.
- Edge potentials $\psi(x_j, x'_i) = x_i'^T W_Q^T W_K x_j$, in effect measuring the similarity of $x_j$ and $x'_i$ in a projected space.
- Value functions $v_i(x, \phi_i = [j]) = W_V x_j$, a linear transformation applied to the node at the start of the edge $\phi_i$.

Taking the expectation with respect to the posterior in each of the factors defined in Eq. 2 gives the standard cross-attention mechanism,

$$\mathbb{E}_{p(\phi_i|Q,K)}[v_i] = \sum_j \mathrm{softmax}_j(x_i'^T W_Q^T W_K x_j)\, W_V x_j.$$

If the key nodes are in fact the same as the query nodes and the prior is instead over a directed graph, we recover self-attention (A.8.1); a minimal sketch of this construction follows below.
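The following sketch (our illustration, with randomly initialised matrices standing in for learned $W_Q$, $W_K$, $W_V$) assembles cross-attention from exactly these three ingredients and notes the self-attention special case.

```python
import numpy as np

# Sketch of Sec. 3.2 under a uniform structural prior (illustrative assumptions only):
# cross-attention recovered by taking, for each query x'_i, the posterior over the
# edge set Phi_i = {(x_1, x'_i), ..., (x_n, x'_i)} and the expectation of v_i = W_V x_j.

def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(1)
d, n, m = 4, 6, 3
K = rng.normal(size=(n, d))            # key nodes x_1..x_n
Q = rng.normal(size=(m, d))            # query nodes x'_1..x'_m
W_Q, W_K, W_V = (rng.normal(size=(d, d)) for _ in range(3))

# Edge potentials psi(x_j, x'_i) = x'_i^T W_Q^T W_K x_j; a uniform prior adds a constant
# to every log-potential, so it drops out of the softmax.
logits = (Q @ W_Q.T) @ (K @ W_K.T).T              # shape (m, n)
posterior = softmax(logits, axis=-1)               # p(phi_i = [j] | Q, K)
out = posterior @ (K @ W_V.T)                      # E_{phi_i|Q,K}[W_V x_j], per query

# Self-attention is the special case K = Q with a prior over directed edges.
self_logits = (K @ W_Q.T) @ (K @ W_K.T).T
self_out = softmax(self_logits, axis=-1) @ (K @ W_V.T)
print(out.shape, self_out.shape)
```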
4 Iterative Attention

We continue by extending attention to a latent variable setting, where not all the nodes are observed; in essence, applying the attention trick, i.e. a marginalisation of structural variables, to a variational free energy (Evidence Lower Bound). This allows us to recover models such as slot attention [17] and block-slot attention [18]. These mechanisms utilise an EM-like procedure, using the current estimate of the latent variables to infer the structure and then using the inferred structure to improve the estimate of the latent variables. Interestingly, Modern Continuous Hopfield Networks fit within this paradigm rather than the one discussed in Sec. 3; collapsed variational inference produces an energy function identical to the one proposed by Ramsauer et al. [8].

4.1 Collapsed Inference

We present a version of collapsed variational inference [27], where the collapsed variables $\phi$ are again structural, showing how this results in a Bayesian attention mechanism. In contrast to the previous section, we have a set of (non-structural) latent variables $z$. The goal is to infer $z$ given the observed variables, $x$, and a latent variable model $p(x, z, \phi)$. Collapsed inference proceeds by marginalising out the extraneous latent variables $\phi$ [27]:

$$p(x, z) = \sum_\phi p(x, z, \phi) \quad (3)$$

We define a Gaussian recognition density $q(z) \sim \mathcal{N}(z; \mu, \Sigma)$ and optimise the variational free energy $F(\lambda) = \mathbb{E}_q[\ln q_\lambda(z) - \ln p(x, z)]$ with respect to the parameters, $\lambda = (\mu, \Sigma)$, of this distribution.

Figure 1: Comparison of models involved in different attention mechanisms (panels: Cross Attention (Q, K); Self Attention (K = Q); Slot Attention (Observed: X, Latent: Z); Block-Slot Attention (Observed: X, Latent: Z, Empirical Priors: M)). In each case, the highlighted edges indicate $\Phi_i$, the support of the uniform prior over $\phi_i$. Attention proceeds by calculating a posterior over these edges, given the current state of the nodes, before using this inference to calculate an expectation of the value function $v$. For iterative attention mechanisms the value function can be identified as the gradient of a variational free energy; in contrast, transformer attention uses a learnable function.

Application of Laplace's method yields approximate derivatives of the variational free energy: $\nabla_\mu F \approx -\nabla_\mu \ln p(x, \mu)$, with the update for $\Sigma$ determined by the curvature $\nabla^2_\mu \ln p(x, \mu)$. Here we focus on the first-order terms (they are independent of the second-order ones; see A.7.1 for details). Substituting in Eq. 3:

$$\nabla_\mu F \approx -\nabla_\mu \ln \sum_\phi p(x, \mu, \phi) \quad (4)$$
$$= -\frac{\sum_\phi \nabla_\mu p(x, \mu, \phi)}{\sum_\phi p(x, \mu, \phi)} \quad (5)$$

In order to make the link to attention, we employ the log-derivative trick, substituting $p(\cdot) = e^{\ln p(\cdot)}$, and re-express Eq. 5 in two ways:

$$\nabla_\mu F \approx -\sum_\phi \mathrm{softmax}_\phi(\ln p(x, \mu, \phi))\,\nabla_\mu \ln p(x, \mu, \phi) \quad (6)$$
$$= -\mathbb{E}_{\phi|x,\mu}[\nabla_\mu \ln p(x, \mu, \phi)] \quad (7)$$

The first form reveals the softmax which is ubiquitous in all attention models. The second suggests the variational update should be evaluated as the expectation of the typical variational gradient (the term within the square brackets) with respect to the posterior over the parameters represented by the random variable $\phi$. In other words, iterative attention is exactly transformer attention applied iteratively, where the value function is the variational free energy gradient. We derive updates for a general pMRF before again recovering (iterative) attention models in the literature by specifying particular distributions.

Free energy of a marginalised pMRF. Recall the factorised pMRF, $p(x, \phi) = \frac{1}{Z}\prod_i e^{f_i(x, \phi_i)}$. Again, independence properties simplify the calculation: the marginalisation can be expressed as a product of local marginals, $\sum_\phi p(x, \phi) = \frac{1}{Z}\prod_i \sum_{\phi_i} e^{f_i(x, \phi_i)}$. Returning to the inference setting, the nodes are partitioned into observed nodes, $x$, and variational parameters $\mu$. Hence the (approximate) collapsed variational free energy (cf. Eq. 5) can be expressed as $F(x, \mu) = -\sum_i \ln \sum_{\phi_i} e^{f_i(x, \mu, \phi_i)} + C$, and its derivative is

$$\frac{\partial F}{\partial \mu_j} = -\sum_i \sum_{\phi_i} \mathrm{softmax}(f_i)\,\frac{\partial f_i}{\partial \mu_j}.$$

Finally, we follow [8] in using the Convex-Concave Procedure (CCCP) to derive a simple fixed-point equation which necessarily reduces the free energy.

Quadratic potentials and the Convex-Concave Procedure. Assume the node potentials are quadratic, $\psi(x_i) = -\tfrac{1}{2}x_i^2$, the edge potentials have the form $\psi(x_i, x_j) = x_i W x_j$, and define $\tilde f_i = \sum_{e \in \phi_i} \psi_e$, collecting only the edge potentials. Consider the following fixed-point equation:

$$\mu_j' = \sum_i \sum_{\phi_i} \mathrm{softmax}(\tilde f_i)\,\frac{\partial \tilde f_i}{\partial \mu_j} \quad (8)$$

Since (under mild conditions) the node potentials contribute a convex term to the energy and the edge potentials a concave one (A.7.1.1), we can invoke the CCCP [28] to show this fixed-point equation descends on the energy, $F(x, \mu_j') \le F(x, \mu_j)$, with equality if and only if $\mu_j'$ is a stationary point of $F$.
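As a toy numerical check (ours; a single latent node, identity projections and a handful of observed nodes are assumed), the update of Eq. 8 reduces to a softmax-weighted average of the observed nodes and, as the CCCP guarantees, never increases the collapsed free energy.

```python
import numpy as np

# Toy check of the fixed-point update (Eq. 8) under the stated assumptions:
# quadratic node potential -1/2 ||mu||^2, bilinear edge potentials mu^T x_j, one latent node.
# The collapsed free energy is F(mu) = 1/2 ||mu||^2 - logsumexp_j(mu^T x_j); the CCCP
# update mu' = sum_j softmax_j(mu^T x_j) x_j should never increase F.

rng = np.random.default_rng(3)
d, n = 6, 12
X = rng.normal(size=(n, d))          # observed nodes
mu = rng.normal(size=d)              # variational mean of the latent node

def free_energy(mu):
    a = X @ mu
    return 0.5 * mu @ mu - (a.max() + np.log(np.exp(a - a.max()).sum()))

for t in range(8):
    a = X @ mu
    w = np.exp(a - a.max()); w /= w.sum()                  # softmax over candidate edges
    new_mu = w @ X                                         # expectation of the edge-potential gradient
    assert free_energy(new_mu) <= free_energy(mu) + 1e-9   # monotone descent property
    mu = new_mu
print(round(free_energy(mu), 4))
```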
We follow Sec. 3 in specifying particular structural priors and potential functions that recover different iterative attention mechanisms.

4.2 Modern Continuous Hopfield Network

Let the observed, or memory, nodes $x = (x_1, \dots, x_n)$ and latent nodes $z = (z_1, \dots, z_m)$ have the following structural prior: $p(\phi) = \prod_{i=1}^m p(\phi_i)$, where $\phi_i \sim \mathrm{Uniform}\{(x_1, z_i), \dots, (x_n, z_i)\}$, meaning each latent node is uniformly likely to connect to a memory node. Define edge potentials $\psi(x_j, z_i) = z_i^T x_j$. Application of Eq. 8 gives

$$\mu_i' = \sum_j \mathrm{softmax}_j(\mu_i^T x_j)\,x_j.$$

When $\mu_i$ is initialised to some query $\xi$, the fixed-point update is given by $\mu_i'(\xi) = \mathbb{E}_{\phi_i|x,\xi}[x_{[j]}]$. If the patterns $x$ are well separated, $\mu_i'(\xi) \approx x_{j^*}$, where $x_{j^*}$ is the closest vector, and hence the system can be used as an associative memory.

4.3 Slot Attention

Slot attention [17] is an object-centric learning module centred around an iterative attention mechanism. Here we show this is a simple adjustment of the prior beliefs on our edge set. With edge potentials of the form $\psi(x_j, z_i) = z_i^T W_Q^T W_K x_j$, replace the prior over edges with $p(\phi) = \prod_{j=1}^n p(\phi_j)$, $\phi_j \sim \mathrm{Uniform}\{(x_j, z_1), \dots, (x_j, z_m)\}$. Notice that, in comparison to the MCHN, the prior over edges is swapped: each observed node is uniformly likely to connect to a latent node, in turn altering the index of the softmax,

$$\mu_i' = \sum_j \mathrm{softmax}_i(\mu_i^T W_Q^T W_K x_j)\,W_Q^T W_K x_j.$$

While the original slot attention employed an RNN to aid the basic update shown here, the important feature is that the softmax is taken over the slots. This forces competition between slots to account for the observed variables, creating object-centric representations.
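To make the contrast concrete, the sketch below (ours; identity projections are assumed and the RNN/normalisation refinements of slot attention are omitted) runs the same bilinear potentials through both updates, changing only the axis the softmax normalises over.

```python
import numpy as np

# Illustrative sketch (not the original implementations): the MCHN update (Sec. 4.2) and
# the slot-attention update (Sec. 4.3) are the same fixed point; they differ only in which
# index the softmax normalises over. Plain potentials mu_i^T x_j are assumed (W_Q^T W_K = I).

def softmax(a, axis):
    a = a - a.max(axis=axis, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(2)
d, n, m = 8, 10, 3
X = rng.normal(size=(n, d))            # observed / memory nodes x_1..x_n
mu = rng.normal(size=(m, d))           # latent nodes mu_1..mu_m (queries or slots)

for _ in range(5):
    logits = mu @ X.T                                 # edge potentials mu_i^T x_j, shape (m, n)
    mu = softmax(logits, axis=1) @ X                  # MCHN: normalise over memories j (retrieval)

slot_update = softmax(mu @ X.T, axis=0) @ X           # slot attention: normalise over slots i,
print(mu.shape, slot_update.shape)                    # so latents compete to explain each x_j
```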
4.4 Predictive Coding Networks

Predictive Coding Networks (PCNs) have emerged as an influential theory in Computational Neuroscience [29, 30, 31]. Building on theories of perception as inference and the Bayesian brain, PCNs perform approximate Bayesian inference by minimising the variational free energy of a graphical model, where incoming sensory data are used as observations. Typical implementations use a hierarchical model with Gaussian conditionals, resulting in a local prediction-error-minimising scheme. The minimisation happens on two distinct time-scales, which can be seen as E-steps and M-steps on the variational free energy: a (fast) inference phase encoded by neural activity, corresponding to perception, and a (slow) learning phase associated with synaptic plasticity. Gradient descent on the free energy gives the inference dynamics for a particular neuron $\mu_i$ [32],

$$\dot\mu_i \propto -\sum_{\phi^-} k_\phi \epsilon_\phi + \sum_{\phi^+} k_\phi \epsilon_\phi w_\phi,$$

where $\epsilon$ are prediction errors, $w$ represent synaptic strengths, $k$ are node-specific precisions representing uncertainty in the generative model, and $\phi^-$, $\phi^+$ represent pre-synaptic and post-synaptic terminals respectively. Applying a uniform prior over the incoming synapses results in slightly modified dynamics,

$$\dot\mu_i \propto -\sum_{\phi^-} \mathrm{softmax}(-\epsilon_\phi^2)\,k_\phi \epsilon_\phi + \sum_{\phi^+} \mathrm{softmax}(-\epsilon_\phi^2)\,k_\phi \epsilon_\phi w_\phi,$$

where the softmax function induces a normalisation across the prediction errors received by a neuron. This dovetails with theories of attention as normalisation in Psychology and Neuroscience [33, 34, 35]. In contrast, previous predictive coding based theories of attention have focused on the precision terms, $k$, due to their ability to up- and down-regulate the impact of prediction errors [36, 37, 38]. Here we see the softmax terms play a functionally equivalent role to precision variables, inheriting their ability to account for bottom-up and top-down attention, while exhibiting the fast winner-takes-all dynamics that are associated with cognitive attention.

5 New Designs

By identifying the attention mechanism in terms of an implicit probabilistic model, we can review and modify the underlying modelling assumptions in a principled manner to design new attention mechanisms. Recall that transformer attention can be written as the marginalisation $\mathbb{E}_{y|x}[y] = \sum_\phi p(\phi \mid x)\,\mathbb{E}_{y|x,\phi}[y]$; the specific mechanism is therefore informed by three pieces of data: (a) the value function $p(y \mid x, \phi)$, (b) the likelihood $p(x \mid \phi)$ and (c) the prior $p(\phi)$. Here, we explore modifying (a) and (c) and show they can exhibit favourable biases on toy problems.

5.1 Multi-hop Attention

Our description makes it clear that the value function employed by transformer attention can be extended to any function over the graph. For example, in the calculation of $\mathbb{E}_{y|x,\phi}[y_i]$ in transformer attention, a linear transformation is applied to the most likely neighbour, $x_j$, of $x_i$. A natural extension is to include a two-hop neighbourhood, additionally using the most likely neighbour $x_k$ of $x_j$. The attention mechanism then takes a different form, $\mathbb{E}_{p(\phi_j|\phi_i)p(\phi_i|x)}[V(x_{\phi_i} + x_{\phi_j})] = (P_\phi + P_\phi^2)VX$, where $P_\phi$ is the typical attention matrix. While containing the same number of parameters as a single layer of transformer attention, for some datasets two-hop attention should be able to approximate the behaviour of two layers of transformer attention.

Task Setup. We simulate a simple dataset that has this property using the following data generation process: initialise a projection matrix $W_y \in \mathbb{R}^{d \times 1}$ and a relationship matrix $W_r \in \mathbb{R}^{d \times d}$. $X$ is then generated causally, using the relationship $x_{i+1} = W_r x_i + \mathcal{N}(0, \sigma)$ to generate $x_0$, $x_1$ and $x_2$, while the remaining nodes are sampled from the noise distribution $\mathcal{N}(0, \sigma)$. Finally, the target $y$ is generated from the history of $x_2$, $y = W_y(x_1 + x_0)$, and the nodes of $X$ are shuffled. Importantly, $W_r$ is designed to be low rank, such that performance on the task requires paying attention to both $x_1$ and $x_0$ (Figure 2).

Figure 2: Multihop Attention. (Left) Graphical description of the toy problem: $x_2$ is generated causally from $x_1$ and $x_0$, which are used to generate $y$. (Centre) Comparison of the attention employed by Multihop, which takes two steps on the attention graph (top), contrasted with Self-Attention (bottom). (Right) Multihop Attention has the correct bias to learn the task, approaching the performance of two-layer Self-Attention, while a single layer of Self-Attention is unable to (top right). Empirically examining the attention weights, Multihop Attention is able to balance attention across two positions, while self-attention favours a single position.
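The mechanism itself is a one-line change to standard attention; the sketch below (ours, with random parameters rather than the trained toy model) computes the one-hop and two-hop outputs side by side.

```python
import numpy as np

# Sketch of the multi-hop value function of Sec. 5.1 (illustrative, not the training setup):
# with the usual attention matrix P (rows p(phi_i = [j] | x)), the two-hop mechanism is
# (P + P^2) V X, using the same parameters as a single attention layer.

def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(4)
d, n = 4, 7
X = rng.normal(size=(n, d))
W_Q, W_K, W_V = (rng.normal(size=(d, d)) for _ in range(3))

P = softmax((X @ W_Q.T) @ (X @ W_K.T).T, axis=-1)   # one-hop posterior over edges
one_hop = P @ (X @ W_V.T)                            # standard self-attention
two_hop = (P + P @ P) @ (X @ W_V.T)                  # E[V(x_{phi_i} + x_{phi_j})]
print(one_hop.shape, two_hop.shape)
```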
5.2 Expanding Attention

One major limitation of transformer attention is the reliance on a fixed context window. From one direction, a small context window cannot represent long-range relationships; on the other hand, a large window does an unnecessary amount of computation when modelling a short-range relationship. By replacing the uniform prior with a geometric distribution, $\phi \mid q \sim \mathrm{Geo}(q)$, and a conjugate hyper-prior, $q \sim \mathrm{Beta}(\alpha, \beta)$, we derive a mechanism that dynamically scales depending on input. We use a (truncated) mean-field variational inference procedure [39] to iteratively approximate $p(\phi \mid x)$ using the updates: (1) $q_t \leftarrow \frac{\beta_t}{\alpha_t + \beta_t}$; (2) $p_t = p(\phi \mid x, q_t)$; (3) $\alpha_{t+1} \leftarrow \alpha_t + 1$, $\beta_{t+1} \leftarrow \beta_t + \mathbb{E}_{p_t(\phi)}[\phi]$. Provided the full context is much longer than the window the posterior concentrates on, the time complexity of expanding attention should be favourable.

Task Setup. Input and target sequences are generated similarly to above (without $x_0$). Here $x_1$ is moved away from $x_2$ according to a draw from a geometric distribution (Figure 3).

Figure 3: Expanding Attention. (Left) Graphical description of the toy problem: $x_2$ and $y$ are generated from $x_1$, which is shuffled with an (exponentially decaying) recency bias. (Centre) Comparison of the geometric prior used by Expanding Attention, with different shades of red representing the iterative refinements during inference, and the uniform prior used by Self-Attention. (Right) The relative number of operations used by Expanding Attention is beneficial when either the recency bias (1/p) or the number of feature dimensions (d) is large; training curves (overlaid) across each of these settings remained roughly equivalent.
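The sketch below is our illustration of the three updates for a single query attending backwards over a sequence; the exact parameterisation of the geometric prior and the truncation rule are assumptions made for the example.

```python
import numpy as np

# Toy sketch of the expanding-attention updates of Sec. 5.2 (parameterisation assumed:
# offset phi >= 1, p(phi = k | q) proportional to q^k (1 - q) with continuation probability q,
# conjugate Beta hyper-prior, truncated mean-field updates as in the main text).

def softmax(a):
    a = a - a.max()
    e = np.exp(a)
    return e / e.sum()

rng = np.random.default_rng(5)
d, n = 4, 64
X = rng.normal(size=(n, d))
W = rng.normal(size=(d, d))                  # stands in for W_Q^T W_K
query = X[-1]                                # attend backwards from the last position
alpha, beta = 2.0, 2.0

for t in range(4):
    q = beta / (alpha + beta)                # 1. q_t <- beta_t / (alpha_t + beta_t)
    K = min(n - 1, max(4, int(np.ceil(5 / (1 - q)))))      # truncation grows with 1 / (1 - q)
    offsets = np.arange(1, K + 1)
    log_prior = offsets * np.log(q) + np.log(1 - q)
    log_like = np.array([query @ W @ X[-1 - k] for k in offsets])
    post = softmax(log_like + log_prior)     # 2. p_t = p(phi | x, q_t)
    alpha, beta = alpha + 1, beta + post @ offsets   # 3. conjugate update of the hyper-prior
    print(t, K, round(post @ offsets, 2))    # current window size and expected offset
```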
6 Discussion

6.1 The Full Transformer Block

Transformer attention is typically combined with residual connections and a feedforward network, both of which have been shown to be important in preventing "token collapse". Here we briefly touch upon how these features might relate to the framework presented here.

Feedforward layer. It has previously been noticed that the feedforward component can also be understood as a key-value memory where the memories are stored as persistent weights [40, 41]. This is due to the observation that $\mathrm{ff}(x) = W_2\,\sigma(W_1 x)$ is equivalent to attention when the non-linearity $\sigma$ is a softmax, although a ReLU is typically used. We speculate the framework presented here could be extended to explain this discrepancy; intuitively, the ReLU relates to an edge prior that fully factorises into binary variables.

Residual connections have been shown to encourage iterative inference [42]. This raises the possibility that transformer attention, rather than applying an arbitrary transformation $v$ as presented in Sec. 3, is in fact approximately implementing the iterative inference of Sec. 4 through a form of iterative amortised inference [43]. The view that the transformer is performing iterative refinement is additionally supported by empirical studies of early decoding [44].

Temperature and positional encodings. Both positional encodings and temperature scaling can be seen as adjustments to the prior edge probability: in the case of relative positional encodings, by breaking the permutation invariance of the prior (A.8.2), while the temperature may be understood in terms of tempered (or generalised) Bayesian inference [45], adjusting the strength of the prior relative to the likelihood.

6.2 Limitations

The connection to structural inference presented here is limited to the attention computation of a single transformer head; an interesting future direction would be to investigate whether the multiple layers and multiple heads typically used in a transformer can also be interpreted within this framework. Additionally, the extension to iterative inference employed a crude approximation to the variational free energy, arguably destroying the favourable properties of Bayesian methods. This suggests the possibility of creating iterative attention mechanisms with alternative inference schemes, possibly producing more robust mechanisms.

6.3 Conclusion

In this paper, we presented a probabilistic description of the attention mechanism, formulating attention as structural inference within a probabilistic model. This approach builds upon previous research that connects cross-attention to inference in a Gaussian Mixture Model. By considering the discrete inference step in a Gaussian Mixture Model as inference on marginalised structural variables, we bridge the gap with alignment-focused descriptions. This framework naturally extends to self-attention, graph attention, and iterative mechanisms, such as Hopfield Networks. We hope this work will contribute to a more unified understanding of the functional advantages and disadvantages brought by Transformers.

Furthermore, we argue that viewing Transformers from a structural inference perspective provides different insights into their central mechanism. Typically, optimising structure is considered a learning problem, changing on a relatively slow timescale compared to inference. However, understanding Transformers as fast structural inference suggests that their remarkable success stems from their ability to change effective connectivity on the same timescale as inference. This general idea can potentially be applied to various architectures and systems. For instance, Transformers employ relatively simple switches in connectivity compared to the complex dynamics observed in the brain [46]. Exploring inference over more intricate structural distributions, such as connectivity motifs or modules in network architecture, could offer artificial systems even more flexible control of resources.

Acknowledgements

This work was supported by The Leverhulme Trust through the be.AI Doctoral Scholarship Programme in biomimetic embodied AI. Additional thanks to Alec Tschantz, Tommaso Salvatori, Miguel Aguilera and Tomasz Korbak for their invaluable feedback and discussions.

References

[1] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin. Attention Is All You Need, December 2017.
[2] Tom Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared D Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, Sandhini Agarwal, Ariel Herbert-Voss, Gretchen Krueger, Tom Henighan, Rewon Child, Aditya Ramesh, Daniel Ziegler, Jeffrey Wu, Clemens Winter, Chris Hesse, Mark Chen, Eric Sigler, Mateusz Litwin, Scott Gray, Benjamin Chess, Jack Clark, Christopher Berner, Sam McCandlish, Alec Radford, Ilya Sutskever, and Dario Amodei. Language Models are Few-Shot Learners. In Advances in Neural Information Processing Systems, volume 33, pages 1877–1901. Curran Associates, Inc., 2020.
[3] Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, and Neil Houlsby. An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale, June 2021.
[4] Wenhui Wang, Hangbo Bao, Li Dong, Johan Bjorck, Zhiliang Peng, Qiang Liu, Kriti Aggarwal, Owais Khan Mohammed, Saksham Singhal, Subhojit Som, and Furu Wei. Image as a Foreign Language: BEiT Pretraining for All Vision and Vision-Language Tasks, August 2022.
[5] Yifan Chen, Qi Zeng, Heng Ji, and Yun Yang. Skyformer: Remodel Self-Attention with Gaussian Kernel and Nyström Method, October 2021.
[6] Yao-Hung Hubert Tsai, Shaojie Bai, Makoto Yamada, Louis-Philippe Morency, and Ruslan Salakhutdinov. Transformer Dissection: A Unified Understanding of Transformer's Attention via the Lens of Kernel, November 2019.
[7] Xing Han, Tongzheng Ren, Tan Minh Nguyen, Khai Nguyen, Joydeep Ghosh, and Nhat Ho. Robustify Transformers with Robust Kernel Density Estimation, October 2022.
[8] Hubert Ramsauer, Bernhard Schäfl, Johannes Lehner, Philipp Seidl, Michael Widrich, Thomas Adler, Lukas Gruber, Markus Holzleitner, Milena Pavlović, Geir Kjetil Sandve, Victor Greiff, David Kreil, Michael Kopp, Günter Klambauer, Johannes Brandstetter, and Sepp Hochreiter. Hopfield Networks is All You Need, April 2021.
[9] Javier R. Movellan and Prasad Gabbur. Probabilistic Transformers, November 2020.
[10] Prasad Gabbur, Manjot Bilkhu, and Javier Movellan. Probabilistic Attention for Interactive Segmentation, July 2021.
[11] Xia Li, Zhisheng Zhong, Jianlong Wu, Yibo Yang, Zhouchen Lin, and Hong Liu. Expectation-Maximization Attention Networks for Semantic Segmentation. In 2019 IEEE/CVF International Conference on Computer Vision (ICCV), pages 9166–9175, Seoul, Korea (South), October 2019. IEEE. ISBN 978-1-72814-803-8. doi: 10.1109/ICCV.2019.00926.
[12] Alexander Shim. A Probabilistic Interpretation of Transformers, April 2022.
[13] Tam Minh Nguyen, Tan Minh Nguyen, Dung D. D. Le, Duy Khuong Nguyen, Viet-Anh Tran, Richard Baraniuk, Nhat Ho, and Stanley Osher. Improving Transformers with Probabilistic Attention Keys. In Proceedings of the 39th International Conference on Machine Learning, pages 16595–16621. PMLR, June 2022.
[14] Johannes von Oswald, Eyvind Niklasson, Ettore Randazzo, João Sacramento, Alexander Mordvintsev, Andrey Zhmoginov, and Max Vladymyrov. Transformers learn in-context by gradient descent, December 2022.
[15] Shivam Garg, Dimitris Tsipras, Percy Liang, and Gregory Valiant. What Can Transformers Learn In-Context? A Case Study of Simple Function Classes.
[16] Yoon Kim, Carl Denton, Luong Hoang, and Alexander M. Rush. Structured Attention Networks, February 2017.
[17] Francesco Locatello, Dirk Weissenborn, Thomas Unterthiner, Aravindh Mahendran, Georg Heigold, Jakob Uszkoreit, Alexey Dosovitskiy, and Thomas Kipf. Object-Centric Learning with Slot Attention, October 2020.
[18] Gautam Singh, Yeongbin Kim, and Sungjin Ahn. Neural Block-Slot Representations, November 2022.
[19] Helen C. Barron, Ryszard Auksztulewicz, and Karl Friston. Prediction and memory: A predictive coding account. Progress in Neurobiology, 192:101821, September 2020. ISSN 0301-0082. doi: 10.1016/j.pneurobio.2020.101821.
[20] Yuntian Deng, Yoon Kim, Justin Chiu, Demi Guo, and Alexander Rush. Latent Alignment and Variational Attention. In Advances in Neural Information Processing Systems, volume 31. Curran Associates, Inc., 2018.
[21] Xinjie Fan, Shujian Zhang, Bo Chen, and Mingyuan Zhou. Bayesian Attention Modules. In Advances in Neural Information Processing Systems, volume 33, pages 16362–16376. Curran Associates, Inc., 2020.
[22] Shiv Shankar, Siddhant Garg, and Sunita Sarawagi. Surprisingly Easy Hard-Attention for Sequence to Sequence Learning. In Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing, pages 640–645, Brussels, Belgium, October 2018. Association for Computational Linguistics. doi: 10.18653/v1/D18-1065.
[23] Louis Annabi, Alexandre Pitti, and Mathias Quoy. On the Relationship Between Variational Inference and Auto-Associative Memory, October 2022.
[24] Jordan Frecon, Gilles Gasso, Massimiliano Pontil, and Saverio Salzo. Bregman Neural Networks. In Proceedings of the 39th International Conference on Machine Learning, pages 6779–6792. PMLR, June 2022.
[25] Yongyi Yang, Zengfeng Huang, and David Wipf. Transformers from an Optimization Perspective, May 2022.
[26] Martin J. Wainwright and Michael I. Jordan. Graphical Models, Exponential Families, and Variational Inference. Foundations and Trends in Machine Learning, 1(1–2):1–305, November 2008. ISSN 1935-8237, 1935-8245. doi: 10.1561/2200000001.
[27] Yee Teh, David Newman, and Max Welling. A Collapsed Variational Bayesian Inference Algorithm for Latent Dirichlet Allocation. In Advances in Neural Information Processing Systems, volume 19. MIT Press, 2006.
[28] Alan L Yuille and Anand Rangarajan. The Concave-Convex Procedure (CCCP). In Advances in Neural Information Processing Systems, volume 14. MIT Press, 2001.
[29] Rajesh P. N. Rao and Dana H. Ballard. Predictive coding in the visual cortex: A functional interpretation of some extra-classical receptive-field effects. Nature Neuroscience, 2(1):79–87, January 1999. ISSN 1546-1726. doi: 10.1038/4580.
[30] Karl Friston and Stefan Kiebel. Predictive coding under the free-energy principle. Philosophical Transactions of the Royal Society B: Biological Sciences, 364(1521):1211–1221, May 2009. ISSN 0962-8436. doi: 10.1098/rstb.2008.0300.
[31] Christopher L. Buckley, Chang Sub Kim, Simon McGregor, and Anil K. Seth. The free energy principle for action and perception: A mathematical review. Journal of Mathematical Psychology, 81:55–79, December 2017. ISSN 0022-2496. doi: 10.1016/j.jmp.2017.09.004.
[32] Beren Millidge, Yuhang Song, Tommaso Salvatori, Thomas Lukasiewicz, and Rafal Bogacz. A Theoretical Framework for Inference and Learning in Predictive Coding Networks, August 2022.
[33] John H. Reynolds and David J. Heeger. The Normalization Model of Attention. Neuron, 61(2):168–185, January 2009. ISSN 0896-6273. doi: 10.1016/j.neuron.2009.01.002.
[34] Matteo Carandini and David J. Heeger. Normalization as a canonical neural computation. Nature Reviews Neuroscience, 13(1):51–62, January 2012. ISSN 1471-0048. doi: 10.1038/nrn3136.
[35] Grace W. Lindsay. Attention in Psychology, Neuroscience, and Machine Learning. Frontiers in Computational Neuroscience, 14, 2020. ISSN 1662-5188.
[36] Andy Clark. The many faces of precision (Replies to commentaries on "Whatever next? Neural prediction, situated agents, and the future of cognitive science"). Frontiers in Psychology, 4, 2013. ISSN 1664-1078.
[37] Harriet Feldman and Karl Friston. Attention, Uncertainty, and Free-Energy. Frontiers in Human Neuroscience, 4, 2010. ISSN 1662-5161.
[38] M. Berk Mirza, Rick A. Adams, Karl Friston, and Thomas Parr. Introducing a Bayesian model of selective attention based on active inference. Scientific Reports, 9(1):13915, September 2019. ISSN 2045-2322. doi: 10.1038/s41598-019-50138-8.
[39] O. Zobay. Mean field inference for the Dirichlet process mixture model. Electronic Journal of Statistics, 3:507–545, January 2009. ISSN 1935-7524. doi: 10.1214/08-EJS339.
[40] Sainbayar Sukhbaatar, Edouard Grave, Guillaume Lample, Herve Jegou, and Armand Joulin. Augmenting Self-attention with Persistent Memory, July 2019.
[41] Mor Geva, Roei Schuster, Jonathan Berant, and Omer Levy. Transformer Feed-Forward Layers Are Key-Value Memories, September 2021.
[42] Stanisław Jastrzębski, Devansh Arpit, Nicolas Ballas, Vikas Verma, Tong Che, and Yoshua Bengio. Residual Connections Encourage Iterative Inference, March 2018.
[43] Joe Marino, Yisong Yue, and Stephan Mandt. Iterative Amortized Inference. In Proceedings of the 35th International Conference on Machine Learning, pages 3403–3412. PMLR, July 2018.
[44] Nora Belrose, Zach Furman, Logan Smith, Danny Halawi, Igor Ostrovsky, Lev McKinney, Stella Biderman, and Jacob Steinhardt. Eliciting Latent Predictions from Transformers with the Tuned Lens, March 2023.
[45] Jeremias Knoblauch, Jack Jewson, and Theodoros Damoulas. Generalized Variational Inference: Three arguments for deriving new Posteriors, December 2019.
[46] Emmanuelle Tognoli and J. A. Scott Kelso. The Metastable Brain. Neuron, 81(1):35–48, January 2014. ISSN 0896-6273. doi: 10.1016/j.neuron.2013.12.022.

Table 1: Different attention modules

| Name | Graph (G) | Prior (p(ϕ)) | Potentials (ψ) | Value v(x, ϕ) |
|---|---|---|---|---|
| Cross Attention | Key nodes K, query nodes Q | Uniform | $x_i^T W_Q^T W_K x_j$ | $V x_j$ |
| Self Attention | K = Q, directed edges | Uniform | $x_i^T W_Q^T W_K x_j$ | $V x_j$ |
| Graph Attention, Sparse Attention | K = Q, directed edges | Uniform (restricted) | $x_i^T W_Q^T W_K x_j$ | $V x_j$ |
| Relative Positional Encodings | K = Q, directed edges | Categorical | $x_i^T W_Q^T W_K x_j$ | $V x_j$ |
| Absolute Positional Encodings | K = Q | Uniform | $\tilde x_i^T W_Q^T W_K \tilde x_j$, $\tilde x_i = x_i + e_i$ | $V x_j$ |
| Classification Layer | NN output $f_\theta(X)$, classes y | Uniform | $f_\theta(X)_i^T y_j$ | $y_j$ |
| MCHN | Observed nodes X, latent nodes Z | Uniform (observed) | $z_i^T W_Q^T W_K x_j$ | $\nabla F$ |
| Slot Attention | Observed nodes X, latent nodes Z | Uniform (latent) | $z_i^T W_Q^T W_K x_j$ | $\nabla F$ |
| Block-Slot Attention | Observed nodes X, latent nodes Z, memory nodes M | Uniform (latent) | $z_i^T W_Q^T W_K x_j$, $m_k^T W_Q^T W_K z_i$ | |
| PCN | Observed nodes X, multiple layers of latent nodes $\{Z^{(l)}\}_{l \le L}$ | Uniform (latent) | $z_i^T W_Q^T W_K x_j$ | $\nabla F$ |
| Multihop Attention | K = Q, directed edges | Uniform | $x_i^T W_Q^T W_K x_j$ | $V x_j + V x_k$ |
| Expanding Attention | K = Q, directed edges | Geometric × Beta | $x_i^T W_Q^T W_K x_j$ | $V x_j$ |

7 Appendix

Here we include some more detailed derivations of claims made in the paper, and list the hyperparameters used for the experiments.

7.1 Iterative Attention

In this section we provide a more detailed treatment of the Laplace approximation, and provide proper justification for invoking the CCCP. For both, the following lemma is useful:

Lemma 7.1. The function $\ln p(x) = \ln \sum_\phi p(x, \phi) = \ln \sum_\phi \exp E_\phi(x)$ has derivatives (i) $\nabla_x \ln p(x) = \mathbb{E}_{\phi|x}[\nabla_x E_\phi]$ and (ii) $\nabla^2_x \ln p(x) = \mathrm{Var}_{\phi|x}[\nabla_x E_\phi] + \mathbb{E}_{\phi|x}[\nabla^2_x E_\phi]$.

Proof. Let $E = (E_\phi)$ be the vector of possible energies, and $p = (p_\phi) = (p(\phi \mid x))_\phi$ the vector of conditional probabilities. Consider $\ln p(\phi \mid x)$ written in canonical form, $\ln p(\phi \mid x) = \langle E(x), 1_\phi \rangle - A[E(x)] + h(\phi)$, where $A[E(x)] = \ln Z(E)$ is the cumulant generating function. By well known properties of the cumulant, $\frac{\partial A}{\partial E_i} = p(\phi = i \mid x) = p_i$. Hence, by the chain rule for partial derivatives, $\nabla_x \ln p(x) = \sum_\phi \frac{\partial A}{\partial E_\phi}\nabla_x E_\phi = \mathbb{E}_{\phi|x}[\nabla_x E_\phi]$, which is (i). To find the second derivative we apply again the chain rule, $\frac{d^2}{dt^2} f(g(t)) = f''(g(t))g'(t)^2 + f'(g(t))g''(t)$. Again by properties of the cumulant, $\frac{\partial^2 A}{\partial E_i \partial E_j} = \mathrm{Cov}(1_i, 1_j) = [\mathrm{diag}(p) - p^T p]_{i,j} = V_{i,j}$. Hence the second derivative is $\nabla^2_x \ln p(x) = (\nabla_x E)^T V (\nabla_x E) + \sum_\phi p_\phi \nabla^2_x E_\phi = \mathrm{Var}_{\phi|x}[\nabla_x E_\phi] + \mathbb{E}_{\phi|x}[\nabla^2_x E_\phi]$.

Second order Laplace Approximation. With these derivatives in hand we can calculate the second-order Laplace approximation of the free energy $F = \mathbb{E}_q[\ln q_\lambda(z) - \ln p(x, z)]$:

$$F \approx -\mathbb{E}_q\big[\ln p(\mu, x) + \nabla_z \ln p(\mu, x)^T(z - \mu) + \tfrac{1}{2}(z - \mu)^T \nabla^2_z \ln p(\mu, x)(z - \mu)\big] - H[q]$$
$$= -\ln p(\mu, x) - \tfrac{1}{2}\mathrm{tr}\big(\Sigma_q\, \mathrm{Var}_{\phi|\mu,x}[\nabla_z E_\phi]\big) - \tfrac{1}{2}\mathrm{tr}\big(\Sigma_q\, \mathbb{E}_{\phi|\mu,x}[\nabla^2_z E_\phi]\big) - \tfrac{1}{2}\log|\Sigma_q| + C$$

We can see that optimising the first-order variational parameter in this approximation is independent of $\Sigma_q$; hence we can first find $\mu^*$ and then fill in our uncertainty, $\Sigma_q^{-1} = -\nabla^2_z \ln p(\mu^*, x) = -\big(\mathrm{Var}_{\phi|\mu^*,x}[\nabla_z E_\phi] + \mathbb{E}_{\phi|\mu^*,x}[\nabla^2_z E_\phi]\big)$. Finding this uncertainty can be costly in the general case where the Hessian of $E$ is not analytically available.
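Lemma 7.1 (i) is easy to verify numerically; the sketch below (ours) uses arbitrary quadratic-form energies purely for illustration and checks the analytic expectation against finite differences.

```python
import numpy as np

# Finite-difference check of Lemma 7.1 (i) in a toy setting (quadratic-form energies
# E_phi(x) = x^T A_phi x / 2 are an arbitrary illustrative choice):
# grad_x ln sum_phi exp E_phi(x) should equal E_{phi|x}[grad_x E_phi(x)].

rng = np.random.default_rng(6)
d, n_struct = 5, 4
A = rng.normal(size=(n_struct, d, d))
A = A + A.transpose(0, 2, 1)                       # symmetric A_phi, so grad E_phi = A_phi x
x = rng.normal(size=d)

def energies(x):
    return 0.5 * np.einsum('i,kij,j->k', x, A, x)  # E_phi(x) for each structure phi

def log_p(x):
    e = energies(x)
    return e.max() + np.log(np.exp(e - e.max()).sum())

e = energies(x)
post = np.exp(e - e.max()); post /= post.sum()     # p(phi | x) = softmax(E_phi)
analytic = np.einsum('k,kij,j->i', post, A, x)     # E_{phi|x}[grad_x E_phi]

eps = 1e-6
numeric = np.array([(log_p(x + eps * np.eye(d)[i]) - log_p(x - eps * np.eye(d)[i])) / (2 * eps)
                    for i in range(d)])
print(np.allclose(analytic, numeric, atol=1e-5))   # True: the identity holds
```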
As alluded to in the paper, iterative attention mechanisms can also be viewed as an alternating minimisation procedure, which may provide a route to more general inference schemes.

As Alternating Minimisation. Collapsed inference can also be seen as coordinate-wise variational inference [27]. Consider the family of distributions $Q = \{q(z; \lambda)q(\phi \mid z)\}$, where $q(z; \lambda)$ is parameterised but $q(\phi)$ is unconstrained. Then

$$F = \min_{q \in Q} \mathbb{E}_q[\ln q(z, \phi) - \ln p(x, z, \phi)] = \min_{q \in Q} \mathbb{E}_{q(z)}\big[\mathbb{E}_{q(\phi)}[\ln q(\phi) - \ln p(x, \phi \mid z)] + \ln q(z) - \ln p(z)\big].$$

The inner expectation is minimised by $q(\phi) = p(\phi \mid x, z)$, where it evaluates to $-\ln p(x \mid z)$, which recovers the marginalised objective

$$\min_{q \in Q} \mathbb{E}_{q(z)}\big[\ln q(z) - \ln \sum_\phi p(x, z, \phi)\big].$$

This motivates an alternate derivation of iterative attention as structural inference which is less reliant on the Laplace approximation. Consider optimising over the variational family $Q = \{q(z; \lambda)q(\phi)\}$ coordinate-wise:

$$\ln q_{t+1}(\phi) = \mathbb{E}_{q_t(z;\lambda_t)}[\ln p(\phi \mid x, z)] + C$$
$$\lambda_{t+1} = \arg\min_\lambda \mathbb{E}_{q_t(\phi)}\big[\mathbb{E}_{q(z;\lambda)}[\ln q(z) - \ln p(x, z \mid \phi)]\big] = \arg\min_\lambda \mathbb{E}_{q_t(\phi)}[F_\phi]$$

In the case of quadratic potentials, $q_{t+1}(\phi) = p(\phi \mid x, \lambda_t)$, hence the combined update step can be written

$$\arg\min_\lambda \mathbb{E}_{p(\phi|x,\lambda_t)}[F_\phi(\lambda)].$$

Each step necessarily reduces the free energy of the mean-field approximation, so this process converges. This derivation is independent of which approximation or estimation is used to minimise the typical variational free energy.

7.1.1 Convexity details for the CCCP

Given a pairwise pMRF with quadratic node potentials $\psi(x_i) = -\tfrac{1}{2}x_i^2$ and edge potentials of the form $\psi(x_i, x_j) = x_i W x_j$ with $W$ p.s.d., we have $\ln \sum_\phi p(x, \phi) = -\tfrac{1}{2}\sum_{v \in G} x_v^2 + \ln \sum_\phi \exp g_\phi(x) + C$, where $g_\phi(x) = \sum_{e \in \phi} \psi_e$. We need the following lemma to apply the CCCP:

Lemma 7.2. $\ln \sum_\phi \exp g_\phi(x)$ is convex in $x$.

Proof. We reapply Lemma 7.1 with $E_\phi = g_\phi(x)$, hence $\nabla^2_x \ln \sum_\phi \exp g_\phi(x) = \mathrm{Var}_{\phi|x}[\nabla_x g_\phi] + \mathbb{E}_{\phi|x}[\nabla^2_x g_\phi]$. The first matrix is a variance, so p.s.d. The second term, $\mathbb{E}_{\phi|x}[\sum_{e \in \phi} \nabla^2_x \psi_e]$, is a convex sum of p.s.d. matrices. Hence both terms are p.s.d., implying $\ln \sum_\phi \exp g_\phi(x)$ is indeed convex.

7.2 PCN Detailed Derivation

Here we go through the derivations for the equations presented in Section 4.4. PCNs typically assume a hierarchical model with Gaussian residuals:

$$z_0 \sim \mathcal{N}(\hat\mu_0, \Sigma_0), \qquad z_{l+1} \mid z_l \sim \mathcal{N}(f_l(z_l; \theta_l), \Sigma_l), \qquad y \mid z_N \sim \mathcal{N}(f_N(z_N; \theta_N), \Sigma_N).$$

Under these conditions, a delta approximation of the variational free energy is given by

$$F[p, q] = -\mathbb{E}_{q(z;\mu)}[\ln p(y, z)] - H[q] \approx \tfrac{1}{2}\sum_{l=0} \Sigma_l^{-1}\epsilon_l^2 + C,$$

where $\epsilon_l = \mu_{l+1} - f_l(\mu_l; \theta_l)$. The inference phase involves adjusting the parameters $\mu$ in the direction of the (negative) gradient of $F$, which for a given layer is:

$$\frac{\partial F}{\partial \mu_l} = \Sigma_{l-1}^{-1}\epsilon_{l-1} - \Sigma_l^{-1}\epsilon_l f'(\mu_l) \quad (10)$$

Here, for ease of comparison, we consider the case where the link functions are linear, $f_l(\cdot) = W_l(\cdot)$, and further that the precision matrices are diagonal, $\Sigma_l^{-1} = \mathrm{diag}(k_l)$. Under these conditions we can write the derivative component-wise as sums of errors over incoming and outgoing edges,

$$\frac{\partial F}{\partial \mu_i} = \sum_{\phi^-} k_\phi \epsilon_\phi - \sum_{\phi^+} k_\phi \epsilon_\phi w_\phi,$$

where $\phi^-$, $\phi^+$ represent the sets of incoming and outgoing edges respectively, and we redefine $\epsilon_\phi = (\mu_i - \mu_j w_{ij})$ for an edge $\phi = (z_i, z_j)$ and $k_\phi = K(z_j)$, the precision associated with the node at the terminus of $\phi$.
Now if we instead assume a uniform prior over incoming edges, or concretely,

$$z_0 \sim \mathcal{N}(\hat\mu_0, \Sigma_0), \qquad \phi^i_l \sim \mathrm{Uniform}\{(z^i_{l+1}, z^0_l), (z^i_{l+1}, z^1_l), \dots\}, \qquad z^i_{l+1} \mid z_l, \phi^i_l \sim \mathcal{N}(w^{ij}_l z^{\phi^i_l}_l, 1/k^i_l), \qquad y \mid z_N \sim \mathcal{N}(f_N(z_N; \theta_N), \Sigma_N),$$

the system becomes a pMRF with edge potentials given by the prediction errors. Recall, applying Eq. 4,

$$\frac{\partial F}{\partial \mu_j} = -\sum_i \sum_{\phi_i} \mathrm{softmax}(f_i(x, \mu, \phi_i))\,\frac{\partial f_i}{\partial \mu_j}.$$

A node in a given layer participates in one factor $\Phi^j_{l-1}$ and in all the factors $\Phi^k_{l+1}$ from the layer above, where every $f_i(x, \mu, \phi_i)$ is a (negative) squared prediction error corresponding to the given edge, $e^{ij}_l = k^{ij}_l (z^i_l - w^{ij}_l z^j_{l-1})^2$; hence

$$\frac{\partial F}{\partial z^j_l} = \sum_i \mathrm{softmax}_i\big({-(\epsilon^{ij}_{l-1})^2}\big)\,\epsilon^{ij}_{l-1} k^j_{l-1} - \sum_{k \in [l+1]} \mathrm{softmax}_{i'}\big({-(\epsilon^{i'k}_{l})^2}\big)\,\epsilon^{i'k}_{l}\, w^{i'k}_{l}\, \mathbb{1}(i' = j).$$

Here incoming signals (nodes $i$) compete through the softmax, whilst the outgoing signal competes with other outgoing signals from nodes (nodes $i'$) in the same layer for representation in the next layer (nodes $k$); see the block-slot attention diagram for intuition. By abuse of notation (re-indexing edges as $\phi$),

$$\frac{\partial F}{\partial \mu_i} = \sum_{\phi^-} \mathrm{softmax}(-\epsilon_\phi^2)\,k_\phi\epsilon_\phi - \sum_{\phi^+} \mathrm{softmax}(-\epsilon_\phi^2)\,k_\phi\epsilon_\phi w_\phi.$$

While we derived these equations for individual units to draw an easy comparison to standard predictive coding, we note it is likely more useful to consider blocks of units competing with each other for representation, similar to the multidimensional token representations in typical attention mechanisms. We also briefly note that the Hammersley-Clifford theorem indicates a deeper duality between attention as mediated by precision matrices and attention as structural inference.

7.3 New Designs

Multihop Derivation. In the calculation of $\mathbb{E}_{y|x,\phi}[y_i]$ in transformer attention, a linear transformation is applied to the most likely neighbour, $x_j$, of $x_i$. A natural extension is to include a two-hop neighbourhood, additionally using the most likely neighbour $x_k$ of $x_j$. Formally, the value function $v$ no longer neatly distributes over the partition $\Phi_i$; however, the attention mechanism then takes a different form:

$$\mathbb{E}_{p(\phi_j|\phi_i)p(\phi_i|x)}[V(x_{\phi_i} + x_{\phi_j})] = (P_\phi + P_\phi^2)VX,$$

where we use $\phi_j(i) = \phi_j$ to denote the edge set of the node at the end of $\phi_i$. To see this, note:

$$\mathbb{E}_{p(\phi|x)}[V(x_{\phi_i} + x_{\phi_j})] = \sum_{\phi} p(\phi \mid x)\,V(x_{\phi_i} + x_{\phi_j}) = \sum_{\phi} p(\phi \mid x)\,V x_{\phi_i} + \sum_{\phi} p(\phi \mid x)\,V x_{\phi_j}$$
$$= \sum_{\phi_i} p(\phi_i \mid x)\,V x_{\phi_i} + \sum_{\phi_i, \phi_j} p(\phi_i \mid x)\,p(\phi_j \mid x)\,V x_{\phi_j} \quad \text{(by independence properties)}.$$

Denoting the typical attention matrix $P$, where $p_{ij} = p(\phi_i = [j] \mid x)$, the $i$-th component is $\sum_j p_{ij} V x_j + \sum_{j,k} p_{ij} p_{jk} V x_k$, i.e. $(P_\phi + P_\phi^2)VX$.

Expanding Derivation. As in the main text, let $\phi \mid q \sim \mathrm{Geo}(q)$ and $q \sim \mathrm{Beta}(\alpha, \beta)$, such that we have the full model $p(x, \phi, q; \alpha, \beta) = p(x \mid \phi)\,p(\phi \mid q)\,p(q; \alpha, \beta)$. In order to find $p(\phi \mid x)$ we employ a truncated mean-field variational Bayes procedure [39], assuming a factorisation $p_t(\phi, q) = p_t(\phi)p_t(q)$ and using the updates:

$$\ln p_{t+1}(\phi) = \mathbb{E}_{p_t(q)}[\ln p(x \mid \phi) + \ln p(\phi \mid q)] + C_1$$
$$\ln p_{t+1}(q) = \mathbb{E}_{p_t(\phi)}[\ln p(\phi \mid q) + \ln p(q; \alpha, \beta)] + C_2$$

By conjugacy the second equation simplifies to a simple update of the Beta distribution,

$$p_{t+1}(q) = \mathrm{Beta}(\alpha_{t+1}, \beta_{t+1}), \qquad \alpha_{t+1} = \alpha_t + 1, \qquad \beta_{t+1} = \beta_t + \mathbb{E}_{p_t(\phi)}[\phi],$$

while the first update can be seen as calculating the posterior given $q_t = \mathbb{E}_{p_t(q)}[q]$:

$$\ln p_{t+1}(\phi) = \ln p(x \mid \phi) + \mathbb{E}_{p_t(q)}[\ln p(\phi \mid q)] + C_2 = \ln p(x \mid \phi) + \phi\,\mathbb{E}_{p_t(q)}[\ln q] + C_2' = \ln p(\phi \mid x, q_t).$$

Finally, we use a truncation to approximate the infinite sum $\mathbb{E}_{p_t(\phi)}[\phi] = \sum_k p_t(\phi = k)\,k$