MONARCH MIXER: A Simple Sub-Quadratic GEMM-Based Architecture

Daniel Y. Fu¹, Simran Arora∗,¹, Jessica Grogan∗,², Isys Johnson∗,², Sabri Eyuboglu∗,¹, Armin W. Thomas∗,³, Benjamin Spector¹, Michael Poli¹, Atri Rudra², Christopher Ré¹

∗Equal contribution. ¹Department of Computer Science, Stanford University. ²Department of Computer Science and Engineering, University at Buffalo, SUNY. ³Department of Psychology, Stanford University.

danfu@cs.stanford.edu, simarora@stanford.edu, {jrgrogan,isysjohn}@buffalo.edu, {eyuboglu,athms,bfs,poli}@stanford.edu, atri@buffalo.edu, chrismre@cs.stanford.edu

Abstract

Machine learning models are increasingly being scaled in both sequence length and model dimension to reach longer contexts and better performance. However, existing architectures such as Transformers scale quadratically along both these axes. We ask: are there performant architectures that can scale sub-quadratically along sequence length and model dimension? We introduce MONARCH MIXER (M2), a new architecture that uses the same sub-quadratic primitive along both sequence length and model dimension: Monarch matrices, a simple class of expressive structured matrices that captures many linear transforms, achieves high hardware efficiency on GPUs, and scales sub-quadratically. As a proof of concept, we explore the performance of M2 in three domains: non-causal BERT-style language modeling, ViT-style image classification, and causal GPT-style language modeling. For non-causal BERT-style modeling, M2 matches BERT-base and BERT-large in downstream GLUE quality with up to 27% fewer parameters, and achieves up to 9.1× higher throughput at sequence length 4K. On ImageNet, M2 outperforms ViT-b by 1% in accuracy, with only half the parameters. Causal GPT-style models introduce a technical challenge: enforcing causality via masking introduces a quadratic bottleneck. To alleviate this bottleneck, we develop a novel theoretical view of Monarch matrices based on multivariate polynomial evaluation and interpolation, which lets us parameterize M2 to be causal while remaining sub-quadratic. Using this parameterization, M2 matches GPT-style Transformers at 360M parameters in pretraining perplexity on The PILE, showing for the first time that it may be possible to match Transformer quality without attention or MLPs.¹

1 Introduction

Machine learning models in natural language processing and computer vision are being stretched to longer sequences and higher-dimensional representations to enable longer context and higher quality, respectively [6, 10, 62, 84]. However, existing architectures exhibit time and space complexities that grow quadratically in sequence length and/or model dimension, which limits context length and makes scaling expensive. For example, attention and the MLP in Transformers scale quadratically in sequence length and model dimension, respectively [15]. In this paper, we explore a natural question: can we find a performant architecture that is sub-quadratic in both sequence length and model dimension?

¹Code is available at https://github.com/HazyResearch/m2.

37th Conference on Neural Information Processing Systems (NeurIPS 2023).
[Figure 1, left to right: order-p Monarch matrices; pseudocode for a simple M2 layer; properties of M2 mixing: sub-quadratic (O(pN^{(p+1)/p})), hardware-efficient (GEMMs), and expressive (generalizes the FFT).]

    def M2_layer(X):
        # mix sequence
        Z = M @ (k * (M @ X))
        # mix channels
        Y = M @ σ(M @ Z.T)
        return Y

Figure 1: Monarch matrices are a simple, expressive, and hardware-efficient class of sub-quadratic structured matrices. MONARCH MIXER (M2) uses Monarch matrices to mix inputs first along the sequence dimension and then along the model dimension. See the Appendix for a PyTorch implementation of an M2 layer.

In our exploration, we seek a sub-quadratic primitive for both the sequence length and model dimension. Our framing takes inspiration from work such as MLP-Mixer [73] and ConvMixer [74], which observed that many machine learning models operate by repeatedly mixing information along the sequence and model dimension axes, and used a single operator for both axes.

Finding mixing operators that are expressive, sub-quadratic, and hardware-efficient is challenging. For example, the MLPs in MLP-Mixer and the convolutions in ConvMixer are expressive, but they both scale quadratically in their input dimension [73, 74]. Several recent studies have proposed sub-quadratic sequence mixing with long convolutions or state space models [27, 64, 77], both computed using the FFT, but these models have poor FLOP utilization (3–5% [28]) and maintain quadratic scaling in model dimension. Meanwhile, there has been promising work in sparsifying dense MLP layers without losing quality, but some of the resulting models can actually be slower than their dense counterparts, due to low hardware utilization [7, 8, 14, 26, 35].

We turn to an expressive class of sub-quadratic structured matrices called Monarch matrices [14] (Figure 1 left) to propose MONARCH MIXER (M2). Monarch matrices are a family of structured matrices that generalize the fast Fourier transform (FFT) and have been shown to capture a wide class of linear transforms, including Hadamard transforms, Toeplitz matrices [32], AFDF matrices [57], and convolutions. They are parameterized as products of block-diagonal matrices, called Monarch factors, interleaved with permutations. Their compute scales sub-quadratically: setting the number of factors to p results in a computational complexity of O(pN^{(p+1)/p}) in input length N, allowing the complexity to interpolate between O(N log N) at p = log N and O(N^{3/2}) at p = 2.²

M2 uses Monarch matrices to mix information along the sequence and model dimension axes. It is both simple to implement and hardware-efficient: the block-diagonal Monarch factors can be computed efficiently on modern hardware using GEMMs (generalized matrix multiply algorithms). Our proof-of-concept implementation of an M2 layer, written in less than 40 lines of pure PyTorch (including imports), relies only on matrix multiplication, transpose, reshape, and elementwise products (see pseudocode in Figure 1 middle), and achieves 25.6% FLOP utilization³ for inputs of size 64K on an A100 GPU. On newer architectures such as the RTX 4090, a simple CUDA implementation achieves 41.4% FLOP utilization at the same size.

Non-Causal Settings As a first proof of concept of M2, we evaluate how it compares to Transformers in terms of speed and quality in non-causal settings such as BERT-style masked language modeling [21] and ImageNet classification.
We introduce M2-BERT, which replaces the attention blocks in BERT with bidirectional gated convolutions implemented using Monarch matrices, and replaces the dense matrices in the MLP with Monarch matrices. M2-BERT reduces parameter count while maintaining quality, matching BERT-base and BERT-large in downstream GLUE score with 27% and 24% fewer parameters, respectively. Sub-quadratic scaling in sequence length enables high throughput at longer sequences: up to 9.1× higher throughput at sequence length 4K than Hugging Face BERT, and 3.1× higher throughput at sequence length 8K than BERT optimized with FlashAttention [15].

²Monarch matrices were originally [14] parameterized with p = 2, but the general p case is a natural extension.
³For context, the most optimized attention implementations achieve 25% FLOP utilization, while unoptimized implementations of attention can have as low as 10% FLOP utilization [15].

For image classification, we adapt HyenaViT-b [64], an attention-free vision transformer based on gated convolutions. We replace the convolution operation with M2 primitives and replace the MLP layers with an M2 block as well. These changes reduce the parameter count by a factor of 2 compared to a ViT-b [22] model with the same model width and depth. Surprisingly, despite this parameter reduction, we find that M2 slightly outperforms the ViT-b and HyenaViT-b baselines, achieving 1% higher accuracy on ImageNet [18].

Causal Settings Causal settings such as GPT-style [65] auto-regressive language modeling present a technical challenge: masking out the upper-triangular elements in an attention matrix (or equivalent structure) introduces a quadratic bottleneck. To alleviate this quadratic bottleneck with Monarch matrices, we develop new theory to characterize which parameterizations of Monarch matrices maintain causality. To do so, we take a view of order-p Monarch matrix multiplication as p-variate polynomial evaluation and interpolation (e.g., p = 2 factors correspond to bivariate polynomials, Figure 2 left). Using this view, we show that the M2 convolution shown in Figure 1 (middle) can be viewed as a manipulation of modular polynomial multiplication. This result allows us to develop conditions (Theorem 3) under which M2 is causal. We can use this causal parameterization to outperform GPT-style language models on causal language modeling by 0.2 PPL points on the PILE at model size 360M, without using either attention or MLP blocks.

Summary Overall, our results present a potential path to building machine learning models with sub-quadratic primitives. We hope our work can serve as a starting point to explore models that are more efficient in both sequence length and model dimension.

2 Preliminaries

In this section, we provide some background on the key components behind the cost of operations on GPUs, and then discuss the scaling characteristics of some common primitives used to mix information across the sequence dimension and model dimension in modern machine learning models.

GPU Accelerator Cost Model We provide a brief discussion of relevant factors affecting the runtime performance of deep learning operations on GPUs. Depending on the balance of computation and memory accesses, operations can be classified as either compute-bound or memory-bound [44]. In compute-bound operations, the time spent accessing GPU memory is relatively small compared to the time spent doing arithmetic operations.
Typical examples are matrix multiplies with large inner dimension, and short convolution kernels with a large number of channels. The speed of these operations is determined by the FLOP/s available on compute units, and the number of FLOPs necessary to complete the operation. In our paper, we exploit fast matrix multiply units such as tensor cores. On the A100, tensor cores can achieve 312 TFLOP/s in half-precision matrix multiply operations, while non-matrix-multiply operations are limited to 19 TFLOP/s [59]. This trend began with tensor cores in the V100 [58], and is continuing into the next-generation H100 chips [60].

Table 1: FLOP cost and utilization of various mixer layers, input dimension 64K on an RTX 4090.

Layer        FLOP Cost   Utilization
MLP          N^2         95.5%
Flash Attn   N^2         24.0%
FFT          N log N     3.0%
M2 Conv      N^{3/2}     41.4%

In memory-bound operations, the time taken by the operation is determined by the number of memory accesses, while time spent in computation is much smaller. Examples include most elementwise operations (e.g., activation, dropout) and reductions (e.g., sum, softmax, batch norm, layer norm). The runtime of memory-bound operations is determined by the memory bandwidth of different layers of the memory hierarchy. GPU memory is large but relatively slow: up to 80 GB on an A100, with a bandwidth of 2 TB/s [59]. Higher levels of the memory hierarchy, such as caches, are much smaller (20 MB) but an order of magnitude faster (19 TB/s).

Common Mixer Primitives To help contextualize our work, we provide scaling and hardware utilization characteristics for a few common operations that are used to mix information in machine learning models, summarized in Table 1. Transformers [75] use attention to mix information across the sequence dimension, and MLP blocks to mix information across the model dimension. Both of these blocks scale quadratically in input length. MLP layers are compute-bound, so they have high FLOP utilization out of the box. Attention blocks are memory-bound, so even the most optimized implementations such as FLASHATTENTION [15] have relatively lower FLOP utilization. Recent work has made progress towards attention-free models by replacing attention layers with long convolution layers, interleaved with elementwise gating [27, 28, 36, 54, 64, 67–69]. These layers are computed using FFT operations via the FFT convolution theorem: y = K ∗ X = FFT⁻¹(FFT(X) ⊙ FFT(K)). While the FFT scales asymptotically well in O(N log N), it is often memory-bound and thus has low FLOP utilization. In our work, we aim to construct a mixer that has both sub-quadratic scaling and high FLOP utilization.

3 MONARCH MIXER

In this section, we recall Monarch matrices, introduce how M2 uses Monarch matrices to mix along the sequence and model dimensions, and benchmark an M2 convolution in terms of hardware utilization.

3.1 Monarch Matrices

Monarch matrices [14] are a sub-quadratic class of structured matrices that are hardware-efficient and expressive. They can represent many linear transforms, including convolutions, Toeplitz-like transforms, low-displacement-rank transforms, and orthogonal polynomials. Directly implementing these different structured transforms on GPUs as dense matrices can be inefficient. In contrast, their Monarch decompositions can be computed by interleaving matrix multiplications with tensor permutations. A Monarch matrix M ∈ R^{N×N} of order p is defined by the following:

M = (∏_{i=1}^{p} P_i B_i) P_0,   (1)

where each P_i is related to the base-N^{1/p} variant of the bit-reversal permutation, and B_i is a block-diagonal matrix with block size b.
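For intuition about this definition, the sketch below is our own illustration (not code from the paper or the Monarch library): it materializes a small order-2 Monarch matrix M = P L P R P from random block-diagonal factors and checks that the dense matvec matches the structured computation, which uses only two batched block multiplies (GEMMs) and three permutations.

```python
import numpy as np

b = 4
N = b * b
rng = np.random.default_rng(0)

# block-diagonal factors L and R: b blocks of size b x b each
L_blocks = rng.standard_normal((b, b, b))
R_blocks = rng.standard_normal((b, b, b))
L, R = np.zeros((N, N)), np.zeros((N, N))
for i in range(b):
    L[i * b:(i + 1) * b, i * b:(i + 1) * b] = L_blocks[i]
    R[i * b:(i + 1) * b, i * b:(i + 1) * b] = R_blocks[i]

# P views a length-N vector as a (b, b) array, transposes it, and flattens it back
P = np.eye(N)[np.arange(N).reshape(b, b).T.reshape(-1)]

M = P @ L @ P @ R @ P                   # dense order-2 Monarch matrix

# the same matvec via two batched block multiplies and three permutations
u = rng.standard_normal(N)
perm = lambda x: x.reshape(b, b).T.reshape(-1)
v = perm(np.einsum("kij,kj->ki", R_blocks, perm(u).reshape(b, b)).reshape(-1))
w = perm(np.einsum("kij,kj->ki", L_blocks, v.reshape(b, b)).reshape(-1))
assert np.allclose(M @ u, w)
```

The structured path uses roughly 2b·b² = 2N^{3/2} multiplies versus N² for the dense matvec, which is the source of the sub-quadratic scaling discussed below.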
Setting b = N^{1/p} achieves sub-quadratic compute cost. For example, for p = 2 and b = √N, Monarch matrices require O(N^{3/2}) compute in sequence length N. In this paper, we use Monarch matrices to construct architectures that are sub-quadratic in both sequence length N and model dimension d. We will often parameterize order-2 Monarch matrices, written as M = PLPRP, where L and R are block-diagonal matrices (for "left" and "right"), and P = P_2 = P_1 = P_0 is the permutation that reshapes the input to 2D, transposes it, and flattens it back to 1D. A common case is to set L = R = (I_{√N} ⊗ F_{√N}), where F_{√N} is the √N × √N DFT matrix and ⊗ is the Kronecker product.

3.2 MONARCH MIXER Architecture

We describe how MONARCH MIXER uses Monarch matrices and elementwise operations to construct sub-quadratic architectures (Figure 1 middle). We take a mixer view of model architectures, where each layer is a sequence of mixing operations across the sequence and the model dimension axes. Each layer takes as input a sequence of embeddings X ∈ R^{N×d} and outputs a sequence Y ∈ R^{N×d}, where N is the sequence length and d is the model dimension. For simplicity, we show the order-2 case here, though we can use higher-order blocks to scale to longer sequences and larger model dimensions.

Let M_1, M_2 ∈ R^{N×N} and M_3, M_4 ∈ R^{d×d} be order-2 Monarch matrices, let K_1 ∈ R^{N×d} be a learned kernel, let σ be an optional point-wise non-linearity (e.g., ReLU), and let ⊙ denote elementwise multiplication. M2 uses Monarch matrices to construct expressive architectures. For example, a convolutional block with a sparse MLP can be expressed as follows:

1. Mix along the sequence axis:   X̃ = M_2 (K_1 ⊙ M_1 X)   (2)
2. Mix along the embedding axis:  Y = M_4 σ(M_3 X̃)   (3)

When M_1 is set to the DFT and M_2 is set to the inverse DFT, Equation 2 exactly corresponds to a convolution with kernel K_1 parameterized in frequency space. Equation 3 corresponds to an MLP with the dense matrices replaced by Monarch matrices. More expressive layers are also easily expressible; for example, replacing Equation 2 with V ⊙ M_2(K_1 ⊙ M_1(Q ⊙ K)), where Q, K, V are linear projections of X, reproduces a gated convolution block, as in [27, 28, 64]. The basic M2 layer is simple to implement; pseudocode is shown in Figure 1 (middle), and the Appendix gives an efficient implementation of M2 in under 40 lines of pure PyTorch (including imports). The convolution case with Monarch matrices fixed to DFT and inverse DFT matrices also admits implementations based on FFT algorithms [11].

3.3 Architecture Benchmarks

We benchmark the efficiency of the M_2(K_1 ⊙ M_1 X) convolution operator (Equation 2), implemented in a simple CUDA kernel (calling standard cuBLAS sub-routines [61]), as the dimension N increases. Equation 3 scales similarly, as dimension d increases. We keep the block size b fixed to √N. Table 2 shows the FLOP cost and utilization of an M2 operator as a function of the input size on an A100 as well as on an RTX 4090.

Table 2: FLOP cost and utilization of M2 compared to a dense MLP at different input sizes N, with block size √N, on an A100 and an RTX 4090.

N                               4K      16K     64K     256K
Dense Matmul TFLOP Cost         0.025   0.412   6.60    106.0
M2 TFLOP Cost                   0.002   0.013   0.103   0.824
Dense FLOP Utilization (A100)   63.0%   78.0%   80.0%   OOM
M2 FLOP Utilization (A100)      4.78%   12.7%   25.6%   42.8%
Wall-Clock Speedup (A100)       1.2×    5.1×    20.6×   >55.0×
Dense FLOP Utilization (4090)   74.6%   96.7%   98.0%   OOM
M2 FLOP Utilization (4090)      11.1%   32.1%   41.4%   53.7%
Wall-Clock Speedup (4090)       2.2×    10.5×   27.0×   >69.1×
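As a concrete companion to the pseudocode in Figure 1 and Equations (2)–(3), here is a minimal pure-PyTorch sketch of an order-2 M2 layer, assuming N and d are perfect squares and using randomly initialized block-diagonal factors. The function names and factor shapes are our own; the paper's reference implementation is given in its Appendix.

```python
import torch

def blockdiag_matmul(x, blocks):
    """Multiply x (..., N) by a block-diagonal matrix given as (nblocks, b, b) blocks."""
    nblocks, b, _ = blocks.shape
    x = x.reshape(*x.shape[:-1], nblocks, b)
    y = torch.einsum("kij,...kj->...ki", blocks, x)   # one batched GEMM
    return y.reshape(*y.shape[:-2], nblocks * b)

def permute(x, b):
    """The permutation P in M = P L P R P: view the last axis as (b, N/b), transpose, flatten."""
    *lead, n = x.shape
    return x.reshape(*lead, b, n // b).transpose(-1, -2).reshape(*lead, n)

def monarch(x, L, R):
    """Order-2 Monarch multiply on the last axis: M x = P L P R P x."""
    b = R.shape[0]
    return permute(blockdiag_matmul(permute(blockdiag_matmul(permute(x, b), R), b), L), b)

def m2_layer(x, seq_monarchs, dim_monarchs, kernel):
    """x: (N, d). seq_monarchs/dim_monarchs: two (L, R) factor pairs each; kernel: (N, d)."""
    (L1, R1), (L2, R2) = seq_monarchs
    (L3, R3), (L4, R4) = dim_monarchs
    # mix along the sequence axis: X~ = M2 (K ⊙ M1 X), applied over the length-N axis
    xt = x.transpose(0, 1)                                                    # (d, N)
    z = monarch(kernel.transpose(0, 1) * monarch(xt, L1, R1), L2, R2).transpose(0, 1)
    # mix along the model dimension: Y = M4 σ(M3 X~), applied over the length-d axis
    return monarch(torch.relu(monarch(z, L3, R3)), L4, R4)

N, d = 256, 64                                                                # 16**2, 8**2
factors = lambda b: (torch.randn(b, b, b) / b, torch.randn(b, b, b) / b)
y = m2_layer(torch.randn(N, d),
             seq_monarchs=(factors(16), factors(16)),
             dim_monarchs=(factors(8), factors(8)),
             kernel=torch.randn(N, d))
print(y.shape)  # torch.Size([256, 64])
```

Each Monarch multiply is two batched GEMMs plus three cheap permutations, which is what lets the layer map onto tensor cores in the benchmarks above.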
On the A100, the operator in Table 2 is dominated more by the data movement costs of the permutation operations (see the Appendix for a roofline analysis). For longer inputs, the sub-quadratic scaling allows MONARCH MIXER to outperform dense matrix multiplication. On the RTX 4090, which has a larger and faster L2 cache than the A100, we can manually optimize an implementation to amortize data movement costs.

4 Theoretical Analysis: M2 as Polynomial Multiplication

In this section, we develop theory to make the M2 layer causal in the input X, i.e., ensure that an output Y_i of the M2 layer depends only on X_1, ..., X_i. Our approach involves interpreting Monarch matrix multiplication as multivariate polynomial evaluation and interpolation. We then show that an M2 convolution is equivalent to modular polynomial manipulation in a univariate basis. The challenge is controlling the degrees of the resulting univariate polynomials, to prevent underflow under modular multiplication (see Figure 2 for an overview). Our key result is deriving sufficient conditions on the degrees of the bivariate polynomials defining the Monarch factors to prevent such underflow. We focus on the bivariate case (order p = 2) in the body, and give the general multivariate case in the Appendix. We present proof sketches in the main body, and leave proofs and additional results for the Appendix.

[Figure 2: the Monarch factors of M⁻¹(Mu ⊙ Mk) are defined by bivariate basis polynomials ℓ(X, Y) and r(Y); conditions on the bivariate polynomial degrees translate, via univariate multiplication f(Z)·g(Z) mod Z^N with deg(f), deg(g) < N/2, into a causal parameterization.]

Figure 2: Monarch multiplication can be interpreted as polynomial evaluation and interpolation. We derive sufficient conditions on the polynomial formulation of Monarch matrices for M2 to be causal.

Monarch Multiplication as Polynomial Evaluation First, we show that order-2 Monarch matrix-vector multiplication Mu is equivalent to bivariate polynomial evaluation. Fix a Monarch matrix M ∈ R^{N×N}, M = PLPRP, for two block-diagonal matrices L and R with blocks of size b = √N. We can interpret Monarch matrices as bivariate polynomial evaluation by setting A = {ω_0, ..., ω_{b−1}} as a set of evaluation points (e.g., the b-th roots of unity), and letting {ℓ_0(X, Y), ..., ℓ_{b−1}(X, Y)} and {r_0(Y), ..., r_{N−1}(Y)} be sets of basis polynomials with individual degrees in X and Y less than √N. The values of {ℓ_0(X, Y), ..., ℓ_{b−1}(X, Y)} evaluated on A² determine the entries of L, and the values of {r_0(Y), ..., r_{N−1}(Y)} evaluated on A determine the entries of R. We give the mapping from ℓ, r, and A to L and R in the Appendix. Then, matrix-vector multiplication between M and a vector u is equivalent to polynomial evaluation of the basis functions ℓ, r on the evaluation points A²:

Theorem 1. Let m(j) = j mod √N. For any vector u ∈ R^N, Mu is a bivariate polynomial u(X, Y) evaluated at A², with u(X, Y) = ∑_{j=0}^{N−1} u_j f_j(X, Y), where f_j(X, Y) = ℓ_{m(j)}(X, Y) r_j(Y).

Monarch Inverse as Polynomial Interpolation Next, we exploit the fact that Monarch inverse multiplication M⁻¹u is equivalent to polynomial interpolation in the basis polynomials of M.

Theorem 2. Let M_0, M_1, M_2 be Monarch matrices, and let A be the set of √N-th roots of unity. Then, the operation

f = M_0⁻¹((M_1 k) ⊙ (M_2 u))   (4)

is equivalent to representing the polynomial h(X, Y) = k(X, Y) u(X, Y) mod (X^{√N} − 1, Y^{√N} − 1) in terms of the basis polynomials ℓ, r corresponding to M_0, where k(X, Y) and u(X, Y) are the polynomials corresponding to M_1 k and M_2 u, respectively.
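The univariate special case of this evaluation/interpolation view is the familiar one behind the FFT: multiplying by the DFT matrix evaluates u(Z) = ∑_j u_j Z^j at the N-th roots of unity, and the inverse DFT interpolates the coefficients back from those values. The same picture explains the causality condition previewed in Figure 2: if the two polynomials being multiplied keep their degrees below N/2, their product never wraps around mod Z^N − 1, and the low-order outputs form a causal map. The sketch below checks both facts numerically; it is our own univariate illustration, not the paper's bivariate Monarch construction (the mapping from ℓ, r, A to L and R is in the paper's Appendix).

```python
import numpy as np

N = 16
rng = np.random.default_rng(0)

# (1) DFT = evaluation of u(Z) at the N-th roots of unity; inverse DFT = interpolation
u = rng.standard_normal(N)
omega = np.exp(-2j * np.pi * np.arange(N) / N)                  # evaluation points
evals = np.array([sum(u[j] * w**j for j in range(N)) for w in omega])
assert np.allclose(np.fft.fft(u), evals)
assert np.allclose(np.fft.ifft(evals), u)

# (2) degree restriction => causality: for k, x supported on degrees < N/2, the product
# k(Z) x(Z) has degree < N, so nothing wraps around mod Z^N - 1 and output i depends
# only on inputs 0..i
n = N // 2
k, x = rng.standard_normal(n), rng.standard_normal(n)
conv = lambda v: np.fft.ifft(np.fft.fft(k, N) * np.fft.fft(v, N)).real[:n]
y = conv(x)
for i in range(n):                                              # perturb x[i] ...
    x2 = x.copy(); x2[i] += 1.0
    assert np.allclose(conv(x2)[:i], y[:i])                     # ... y[0:i] is unchanged
```

Theorem 3 below enforces analogous degree conditions on the basis polynomials of the Monarch factors so that the map in Equation (6) inherits this causality.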
The above follows from Theorem 1 and the fact that Monarch matrix-vector multiplication with an inverse Monarch matrix is equivalent to polynomial interpolation in a given basis. The mod part comes from the fact that A is the set of roots of the polynomial Z^{√N} − 1.

Causal Monarch Maps Now, we give a class of Monarch matrices from which we can build a causal map. First, we define a polynomial with minimum degree j:

Definition 1. A polynomial of minimum degree j (and maximum degree N − 1) is defined as q_j(Z) = ∑_{a=j}^{N−1} q_j[a] Z^a.

To ensure causality, we first convert the bivariate polynomial basis into a univariate basis, and then we expand the degree of the univariate polynomial. The resulting univariate polynomial multiplication is naturally causal (exploiting similar properties as the causal FFT convolution). We use the Kronecker substitution (X → Z, Y → Z^{√N}) to convert the bivariate polynomial basis into a univariate basis:

q_j(Z) = ℓ_{m(j)}(Z, Z^{√N}) r_j(Z^{√N}),

where m(j) is defined as in Theorem 1. Then, the following class of Monarch matrices (with the conversion to a univariate polynomial basis as above) forms a causal map:

Theorem 3. Let u, k ∈ R^n, where n < N/2. Let m(j) be as in Theorem 1, and let k(j) = ⌊j/√N⌋. Then define the basis polynomials ℓ_{m(j)} to have minimum degree m(j), the basis polynomials r_j to have minimum degree k(j), and all polynomials q_j(Z) to have maximum degree < N/2 for all j < N/2 and maximum degree N − 1 for N/2 ≤ j < N. Let M_N be defined by such basis polynomials via

M_N[i, j] = q_j(ω_N^i),   (5)

where ω_N is a primitive N-th root of unity, i.e., the evaluation points are now the N-th roots of unity. Then, we have that

u ↦ M_N⁻¹(M_N(k, 0_{N−n}) ⊙ M_N(u, 0_{N−n}))[0 : n − 1]   (6)

gives a causal map in u.

Theorem 3 gives a causal map that can be computed entirely using Monarch matrices, enforcing causality with sub-quadratic scaling. The main technical ingredient in proving the above result is that the product q_j(Z) q_{j′}(Z) can be written as a linear combination of q_a(Z) for j + j′ ≤ a < N (this uses the above specified properties on the minimum and maximum degrees of q_j(Z)). This in turn implies that the term k_j u_{j′} q_j(Z) q_{j′}(Z) only contributes to the coefficients of higher-order basis polynomials q_a(Z) for a ≥ j + j′ in the product k(Z)u(Z), which is needed for causality. Figure 2 gives an example of restricted polynomials generating a causal map.

5 Experiments

Figure 3: M2-BERT uses Monarch matrices to create a bidirectional gated long convolution in the sequence mixer, and uses Monarch matrices to replace the linear layers in the dimension mixer.

We compare MONARCH MIXER to Transformers on three tasks where Transformers have been dominant: BERT-style non-causal masked language modeling, ViT-style image classification, and GPT-style causal language modeling. In each, we show that we can match Transformers in quality using neither attention nor MLPs. We additionally evaluate wall-clock speedups against strong Transformer baselines in the BERT setting. Additional experiments on speech and alternative architectures are given in Appendix B, and experimental details are given in Appendix C.

5.1 Non-Causal Language Modeling

We introduce M2-BERT, an M2-based architecture for non-causal language modeling. M2-BERT acts as a drop-in replacement for BERT-style language models [21], which are a workhorse application of the Transformer architecture [1, 39, 40, 45, 48, 49, 52, 56, 85, 89]. We train M2-BERT using masked language modeling over C4 [66] with the bert-base-uncased tokenizer.
M2-BERT starts with a Transformer backbone and replaces the attention and MLPs with M2 layers, shown in Figure 3. In the sequence mixer, we replace attention with bidirectional gated convolutions with a residual convolution (Figure 3 left). To recover convolutions, we set the Monarch matrices to DFT and inverse DFT matrices. Following [27, 64], we also add short depthwise convolutions after the projections. In the dimension mixer, we replace the two dense matrices in the MLPs with learned block-diagonal matrices (a Monarch matrix of order 1, b = 4).

We pretrain two M2-BERT-base models, at 80M and 110M parameters, and two M2-BERT-large models, at 260M and 341M parameters. These are equivalent to BERT-base and BERT-large, respectively.

Downstream GLUE Scores First, we evaluate M2-BERT models on downstream fine-tuning compared to BERT-base and BERT-large from [20]. We take the pretrained models and fine-tune them on GLUE, following the procedure in [38]. Table 3 shows performance for BERT-base-equivalent models, and Table 4 shows performance for BERT-large-equivalent models. M2-BERT-base can match BERT-base in GLUE quality with 27% fewer parameters, or outperform BERT-base in quality by 1.3 points when parameter-matched. M2-BERT-large matches BERT-large with 24% fewer parameters, and outperforms by 0.7 points when parameter-matched.

Table 3: Average GLUE score for M2-BERT-base compared to BERT-base [20], along with change in parameters and GLUE score.

Model                  GLUE Score   Δ Params   Δ GLUE Score
BERT-base (110M)       79.6         -0%        +0.0
M2-BERT-base (80M)     79.9         -27%       +0.3
M2-BERT-base (110M)    80.9         -0%        +1.3

Table 4: Average GLUE score for M2-BERT-large compared to BERT-large [20], along with change in parameters and GLUE score.

Model                  GLUE Score   Δ Params   Δ GLUE Score
BERT-large (340M)      82.1         -0%        +0.0
M2-BERT-large (260M)   82.2         -24%       +0.1
M2-BERT-large (341M)   82.8         +0.2%      +0.7

GPU Throughput by Sequence Length Next, we evaluate throughput of M2-BERT models by sequence length, compared to Hugging Face implementations of BERT, as well as optimized implementations of BERT running FlashAttention [15]. Table 5 shows forward throughput for BERT-base-equivalent models, and the Appendix shows throughput for BERT-large (where the performance trends are similar). Inference throughput is reported in tokens/ms on an A100-40GB GPU. M2-BERT-base achieves higher throughput than even highly-optimized BERT models, and up to 9.1× higher throughput than a standard Hugging Face implementation at sequence length 4K.

Table 5: Throughput in tokens/ms by context length for M2-BERT-base (80M) compared to BERT-base.

Model                                  512     1024    2048    4096    8192
HF BERT-base (110M)                    206.1   130.8   71.3    39.0    OOM
FlashAttention BERT-base (110M)        367.4   350.1   257.2   179.1   102.4
M2-BERT-base (80M)                     386.3   380.7   378.9   353.9   320.1
M2 Speedup over HF BERT-base (110M)    1.9×    2.9×    5.2×    9.1×    -

CPU Inference Latency Finally, we report CPU inference latency for M2-BERT-base (80M) compared to BERT-base, running direct PyTorch implementations for both. At short sequences, the impacts of data locality still dominate the FLOP reduction, and operations that are not present in BERT, such as filter generation, pay a higher cost. Starting at sequences of 1K and longer, M2-BERT-base begins to show speedups over BERT-base, up to 6.5× at sequence length 8K. We believe further optimization and applying IO-aware principles can further improve CPU performance.

Table 6: CPU inference latency in milliseconds with a batch size of 1 at varied input sequence lengths. Measurements averaged over 10 examples on a 48 vCPU, 96 GB RAM instance from the GCP n2-standard-48 series, which runs Intel Cascade Lake processors. This is based on the protocol in [29].

Model                512    1024   2048   4096   8192
BERT-base (110M)     182    389    918    2660   11820
M2-BERT-base (80M)   289    361    651    948    1820
Speedup              0.6×   1.1×   1.4×   2.8×   6.5×
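The dimension mixer described above for M2-BERT swaps each dense MLP projection for a learned block-diagonal matrix, applied with a single batched GEMM. Below is a minimal PyTorch sketch of such a block-diagonal MLP. It is our own illustration, not the released M2-BERT code: the class names are ours, and using 4 rectangular blocks per projection is our reading of the b = 4 setting mentioned above.

```python
import torch
import torch.nn as nn

class BlockDiagonalLinear(nn.Module):
    """A learned block-diagonal projection d_in -> d_out with `nblocks` rectangular blocks."""
    def __init__(self, d_in, d_out, nblocks):
        super().__init__()
        assert d_in % nblocks == 0 and d_out % nblocks == 0
        self.nblocks = nblocks
        self.b_in, self.b_out = d_in // nblocks, d_out // nblocks
        self.weight = nn.Parameter(torch.randn(nblocks, self.b_out, self.b_in) / self.b_in**0.5)

    def forward(self, x):                                        # x: (..., d_in)
        x = x.reshape(*x.shape[:-1], self.nblocks, self.b_in)
        y = torch.einsum("koi,...ki->...ko", self.weight, x)     # one batched GEMM
        return y.reshape(*y.shape[:-2], self.nblocks * self.b_out)

class M2DimensionMixer(nn.Module):
    """MLP-style dimension mixer with its two dense matrices swapped for block-diagonal ones."""
    def __init__(self, d, expansion=4, nblocks=4):
        super().__init__()
        self.up = BlockDiagonalLinear(d, expansion * d, nblocks)
        self.down = BlockDiagonalLinear(expansion * d, d, nblocks)

    def forward(self, x):
        return self.down(torch.nn.functional.gelu(self.up(x)))

x = torch.randn(2, 128, 768)                   # (batch, sequence, d)
print(M2DimensionMixer(768)(x).shape)          # torch.Size([2, 128, 768])
```

With 4 blocks, each projection uses roughly a quarter of the parameters and FLOPs of a dense layer of the same shape.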
5.2 Image Classification

To validate that our methods generalize to images as well as language for non-causal modeling, we next evaluate M2 on image classification. We compare M2 to ViT-style models and recent work, HyenaViT-b [64], which uses gated long convolutions to replace the attention layers in ViT-b. In our work, M2-ViT builds off HyenaViT-b and replaces the long convolutions with the M2 operator in Equation 2 (again setting the Monarch matrices to the DFT and inverse DFT). We replace the MLP blocks in HyenaViT-b with block-diagonal matrices, similarly to M2-BERT. Appendix B additionally compares M2 to the Swin family of architectures [50, 51].

Table 7 shows the performance of MONARCH MIXER against ViT-b, HyenaViT-b, and ViT-b + Monarch (which replaces the MLP blocks of standard ViT-b with Monarch matrices) on ImageNet-1k. MONARCH MIXER outperforms the other models with only half the parameters of the original ViT-b model. Surprisingly, MONARCH MIXER also outperforms ResNet-152 with fewer parameters, even though the latter was explicitly designed for ImageNet performance.

Table 7: Accuracy on ImageNet-1k. ResNet-152 provided for reference.

Model                   Top-1%   Top-5%   Description
ResNet-152 (60M)        78.6     94.3     ConvNet, MLP
ViT-b (87M)             78.5     93.6     Attention, MLP
ViT-b + Monarch (33M)   78.9     94.2     Attention, MLP-Free
HyenaViT-b (88M)        78.5     93.6     Attention-Free, MLP
M2-ViT-b (45M)          79.5     94.5     Attention-Free, MLP-Free

5.3 Causal Language Modeling

GPT-style causal language modeling is a critical application for Transformers [6, 31, 43]. We introduce M2-GPT, an M2-based architecture for causal language modeling. For the sequence mixer, M2-GPT combines the convolutional filter from Hyena [64], the state-of-the-art attention-free language model, with parameter sharing across multiple heads from H3 [27]. We use the causal parameterization of Equation 2 to replace the FFT in these architectures, and we remove the MLP layers entirely. The resulting architecture is entirely attention- and MLP-free.

We pretrain M2-GPT on the PILE, a standard dataset for causal language modeling. Following prior work [28, 64], we train models at two model sizes, with varying amounts of training data, decaying the learning rate appropriately for each experiment. Table 8 shows the results. Even though our model is attention- and MLP-free, it outperforms both Transformers and Hyena in perplexity on pretraining. These results suggest that radically different architectures than Transformers may be performant on causal language modeling.

Table 8: Perplexity on the PILE when trained for different numbers of tokens.

Model                5B     10B    15B    Description
Transformer (125M)   13.3   11.9   11.2   Attention, MLP
Hyena (155M)         13.1   11.8   11.1   Attention-Free, MLP
M2-GPT (145M)        12.9   11.6   10.9   Attention-Free, MLP-Free
Transformer (355M)   11.4   9.8    9.1    Attention, MLP
Hyena (360M)         11.3   9.8    9.2    Attention-Free, MLP
M2-GPT (360M)        11.0   9.6    9.0    Attention-Free, MLP-Free

6 Related Work

Long Convolutions Recent work proposes to use long convolution layers as a replacement for the Transformer attention layers in sequence modeling [28, 64, 67–69]. Many of these models rely on the FFT convolution theorem to compute the long convolutions. We build on the insights in many of these architectures in constructing our M2 architectures, and additionally replace the FFT operations with Monarch matrices. Our work is also related to a rich literature on convolutions in other bases, such as Chebyshev bases [82] or orthogonal polynomial bases [34]. These approaches have analogues in our multivariate analysis; replacing the basis polynomials of the Monarch matrices in MONARCH MIXER may make it possible to approximate some of these operations. An interesting question for future work would be to study how well our techniques and concerns about causality and hardware utilization translate to these alternative convolution bases.
Optimization of deep learning primitives There is a rich history of the optimization of deep learning primitives, as accelerating their performance can yield substantial savings in compute and cost for large models. There are many approaches to speeding up these operations, but they usually either reduce data movement or reduce compute.

Reducing data movement: In many applications, the major bottleneck is the storage and movement of large amounts of memory. One popular approach to reducing data movement is checkpointing, wherein one stores fewer intermediate results and recomputes the others on-the-fly where they are needed, trading additional compute for memory [46, 78]. Another approach is kernel fusion, wherein algorithms initially described as sequential steps can often be fused in ways that improve their properties. For example, it is generally faster to implement a dot-product through a multiply-accumulate rather than first multiplying and then accumulating. Recently, libraries such as PyTorch 2.0 [63] have added kernel fusion capabilities, although the very best performance usually still arises from handwritten kernels. Third, in order to better exploit memory locality, it is often fastest to load small blocks of memory, do intensive computation on them, and then write the results a tile at a time [83]. Finally, many algorithms also have hand-optimizations that can remove unnecessary computation or memory accesses [55]. Efficient algorithms usually make use of a combination of these techniques. For example, FlashAttention [15] uses all four to dramatically decrease both the latency and memory consumption of multi-head attention. Though we have made a modest effort to implement MONARCH MIXER efficiently, we think it likely that MONARCH MIXER could be further optimized by these techniques.

Reducing FLOPs: A first target for optimization is the multi-layer perceptron (MLP), owing to its ubiquity. A variety of structured sparse factorizations exist, many of which we draw on in this work [7, 11, 14, 16, 17, 19, 26, 91]. Attention is also a popular target for optimization. Recently, a plethora of sub-quadratic approximations of attention have emerged that aim to reduce its quadratic complexity. Some methods rely on sparsification, exploiting the fact that the attention matrix is extremely sparse at long sequence lengths [2, 23, 24, 42, 53]. Others use low-rank approximations of the attention matrix [13, 79, 91] or kernel methods [9, 41]. A subset use a combination of these techniques, such as [8, 72]. Finally, a third category of methods [27, 64] aims to replace attention entirely, relying on state-space models [33].
7 Discussion and Conclusion We explore MONARCH MIXER (M2), a new architecture that is sub-quadratic in both sequence length and model dimension and is hardware-efficient on modern accelerators. We motivate M2 from both theoretical and systems performance perspectives and conduct a preliminary proof-ofconcept investigation into performance on masked language modeling, image classification, and causal language modeling. While our initial results are promising, our work is only a first step in this direction. The M2 layer can likely be further optimized with systems optimization techniques such as kernel fusion. Our work has also not been optimized for inference like more well-established models such as Transformers, or even more recent models such as state space models. It also remains to be seen whether M2 layers can have as widespread applicability as Transformers. We hope that these can be fruitful directions for future work. A discussion of broader impacts can be found in the Appendix. Acknowledgments We gratefully acknowledge the support of DARPA under Nos. FA86501827865 (SDH) and FA86501827882 (ASED); NIH under No. U54EB020405 (Mobilize), NSF under Nos. CCF1763315 (Beyond Sparsity), CCF1563078 (Volume to Velocity), and 1937301 (RTML); ONR under No. N000141712266 (Unifying Weak Supervision); the Moore Foundation, NXP, Xilinx, LETI-CEA, Intel, IBM, Microsoft, NEC, Toshiba, TSMC, ARM, Hitachi, BASF, Accenture, Ericsson, Qualcomm, Analog Devices, the Okawa Foundation, American Family Insurance, Google Cloud, Microsoft Azure, Swiss Re, Brown Institute for Media Innovation, Department of Defense (Do D) through the National Defense Science and Engineering Graduate Fellowship (NDSEG) Program, Fannie and John Hertz Foundation, National Science Foundation Graduate Research Fellowship Program, Texas Instruments Stanford Graduate Fellowship in Science and Engineering, and members of the Stanford DAWN project: Teradata, Facebook, Google, Ant Financial, NEC, VMWare, and Infosys. The U.S. Government is authorized to reproduce and distribute reprints for Governmental purposes notwithstanding any copyright notation thereon. Any opinions, findings, and conclusions or recommendations expressed in this material are those of the authors and do not necessarily reflect the views, policies, or endorsements, either expressed or implied, of DARPA, NIH, ONR, or the U.S. Government. JG and AR s work is supported by NSF grant# CCF-2247014. IJ s work is supported by an NSF Graduate Fellowship. [1] Iz Beltagy, Kyle Lo, and Arman Cohan. Scibert: A pretrained language model for scientific text. ar Xiv preprint ar Xiv:1903.10676, 2019. [2] Iz Beltagy, Matthew E Peters, and Arman Cohan. Longformer: The long-document transformer. ar Xiv preprint ar Xiv:2004.05150, 2020. [3] Emily M. Bender, Timnit Gebru, Angelina Mc Millan-Major, and Shmargaret Shmitchell. On the dangers of stochastic parrots: Can language models be too big? In Proceedings of the 2021 ACM Conference on Fairness, Accountability, and Transparency, FAcc T 21, page 610 623, New York, NY, USA, 2021. Association for Computing Machinery. [4] Lucas Beyer, Xiaohua Zhai, and Alexander Kolesnikov. Better plain vit baselines for imagenet1k. ar Xiv preprint ar Xiv:2205.01580, 2022. [5] Rishi Bommasani, Drew A Hudson, Ehsan Adeli, Russ Altman, Simran Arora, Sydney von Arx, Michael S Bernstein, Jeannette Bohg, Antoine Bosselut, Emma Brunskill, et al. On the opportunities and risks of foundation models. ar Xiv preprint ar Xiv:2108.07258, 2021. 
[6] Tom Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared D Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, et al. Language models are few-shot learners. Advances in neural information processing systems, 33:1877 1901, 2020. [7] Beidi Chen, Tri Dao, Kaizhao Liang, Jiaming Yang, Zhao Song, Atri Rudra, and Christopher R e. Pixelated butterfly: Simple and efficient sparse training for neural network models. 2021. [8] Beidi Chen, Tri Dao, Eric Winsor, Zhao Song, Atri Rudra, and Christopher R e. Scatterbrain: Unifying sparse and low-rank attention. In Advances in Neural Information Processing Systems (Neur IPS), 2021. [9] Krzysztof Choromanski, Valerii Likhosherstov, David Dohan, Xingyou Song, Andreea Gane, Tamas Sarlos, Peter Hawkins, Jared Davis, Afroz Mohiuddin, Lukasz Kaiser, et al. Rethinking attention with performers. ar Xiv preprint ar Xiv:2009.14794, 2020. [10] Aakanksha Chowdhery, Sharan Narang, Jacob Devlin, Maarten Bosma, Gaurav Mishra, Adam Roberts, Paul Barham, Hyung Won Chung, Charles Sutton, Sebastian Gehrmann, et al. Palm: Scaling language modeling with pathways. ar Xiv preprint ar Xiv:2204.02311, 2022. [11] James W Cooley and John W Tukey. An algorithm for the machine calculation of complex fourier series. Mathematics of computation, 19(90):297 301, 1965. [12] Ekin D Cubuk, Barret Zoph, Jonathon Shlens, and Quoc V Le. Randaugment: Practical automated data augmentation with a reduced search space. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition workshops, pages 702 703, 2020. [13] Zihang Dai, Guokun Lai, Yiming Yang, and Quoc Le. Funnel-transformer: Filtering out sequential redundancy for efficient language processing. Advances in neural information processing systems, 33:4271 4282, 2020. [14] Tri Dao, Beidi Chen, Nimit S Sohoni, Arjun Desai, Michael Poli, Jessica Grogan, Alexander Liu, Aniruddh Rao, Atri Rudra, and Christopher R e. Monarch: Expressive structured matrices for efficient and accurate training. In International Conference on Machine Learning. PMLR, 2022. [15] Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, and Christopher R e. Flash Attention: Fast and memory-efficient exact attention with IO-awareness. In Advances in Neural Information Processing Systems, 2022. [16] Tri Dao, Albert Gu, Matthew Eichhorn, Atri Rudra, and Christopher R e. Learning fast algorithms for linear transforms using butterfly factorizations, 2020. [17] Tri Dao, Nimit S. Sohoni, Albert Gu, Matthew Eichhorn, Amit Blonder, Megan Leszczynski, Atri Rudra, and Christopher R e. Kaleidoscope: An efficient, learnable representation for all structured linear maps, 2021. [18] Jia Deng, Wei Dong, Richard Socher, Li-Jia Li, Kai Li, and Li Fei-Fei. Imagenet: A largescale hierarchical image database. In 2009 IEEE conference on computer vision and pattern recognition, pages 248 255. Ieee, 2009. [19] Tim Dettmers and Luke Zettlemoyer. Sparse networks from scratch: Faster training without losing performance. ar Xiv preprint ar Xiv:1907.04840, 2019. [20] Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. Bert: Pre-training of deep bidirectional transformers for language understanding. ar Xiv preprint ar Xiv:1810.04805, 2018. [21] Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. Bert: Pre-training of deep bidirectional transformers for language understanding. In ar Xiv:1810.04805, 2019. 
[22] Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, et al. An image is worth 16x16 words: Transformers for image recognition at scale. ar Xiv preprint ar Xiv:2010.11929, 2020. [23] Nan Du, Yanping Huang, Andrew M Dai, Simon Tong, Dmitry Lepikhin, Yuanzhong Xu, Maxim Krikun, Yanqi Zhou, Adams Wei Yu, Orhan Firat, et al. Glam: Efficient scaling of language models with mixture-of-experts. In International Conference on Machine Learning, pages 5547 5569. PMLR, 2022. [24] William Fedus, Barret Zoph, and Noam Shazeer. Switch transformers: Scaling to trillion parameter models with simple and efficient sparsity. The Journal of Machine Learning Research, 23(1):5232 5270, 2022. [25] Wikimedia Foundation. Wikimedia downloads. [26] Jonathan Frankle and Michael Carbin. The lottery ticket hypothesis: Finding sparse, trainable neural networks. ar Xiv preprint ar Xiv:1803.03635, 2018. [27] Daniel Y Fu, Tri Dao, Khaled K Saab, Armin W Thomas, Atri Rudra, and Christopher R e. Hungry hungry hippos: Towards language modeling with state space models. International Conference on Learning Representations, 2023. [28] Daniel Y. Fu, Elliot L. Epstein, Eric Nguyen, Armin W. Thomas, Michael Zhang, Tri Dao, Atri Rudra, and Christopher R e. Simple hardware-efficient long convolutions for sequence modeling. International Conference on Machine Learning, 2023. [29] Morgan Funtowicz. Scaling up bert-like model inference on modern cpu - part 1, 2021. [30] Jonas Geiping and Tom Goldstein. Cramming: Training a language model on a single gpu in one day. ar Xiv:2212.14034v1, 2022. [31] Google. Bard, https://bard.google.com/. 2023. [32] Robert M Gray et al. Toeplitz and circulant matrices: A review. Foundations and Trends in Communications and Information Theory, 2(3):155 239, 2006. [33] Albert Gu, Karan Goel, and Christopher R e. Efficiently modeling long sequences with structured state spaces. ar Xiv preprint ar Xiv:2111.00396, 2021. [34] Nicholas Hale and Alex Townsend. An algorithm for the convolution of legendre series. SIAM Journal on Scientific Computing, 36(3):A1207 A1220, 2014. [35] Song Han, Huizi Mao, and William J Dally. Deep compression: Compressing deep neural networks with pruning, trained quantization and huffman coding. ar Xiv preprint ar Xiv:1510.00149, 2015. [36] Ramin Hasani, Mathias Lechner, Tsun-Hsuan Wang, Makram Chahine, Alexander Amini, and Daniela Rus. Liquid structural state-space models. ar Xiv preprint ar Xiv:2209.12951, 2022. [37] Dan Hendrycks, Norman Mu, Ekin D Cubuk, Barret Zoph, Justin Gilmer, and Balaji Lakshminarayanan. Augmix: A simple data processing method to improve robustness and uncertainty. ar Xiv preprint ar Xiv:1912.02781, 2019. [38] Peter Izsak, Moshe Berchansky, and Omer Levy. How to train bert with an academic budget. In Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing, pages 10644 10652, 2021. [39] Di Jin, Zhijing Jin, Joey Tianyi Zhou, and Peter Szolovits. Is bert really robust? a strong baseline for natural language attack on text classification and entailment. In Proceedings of the AAAI conference on artificial intelligence, volume 34, pages 8018 8025, 2020. [40] Mandar Joshi, Danqi Chen, Yinhan Liu, Daniel S Weld, Luke Zettlemoyer, and Omer Levy. Spanbert: Improving pre-training by representing and predicting spans. Transactions of the Association for Computational Linguistics, 8:64 77, 2020. 
[41] Angelos Katharopoulos, Apoorv Vyas, Nikolaos Pappas, and Franc ois Fleuret. Transformers are rnns: Fast autoregressive transformers with linear attention. In International Conference on Machine Learning, pages 5156 5165. PMLR, 2020. [42] Nikita Kitaev, Łukasz Kaiser, and Anselm Levskaya. Reformer: The efficient transformer. ar Xiv preprint ar Xiv:2001.04451, 2020. [43] Jan Koco n, Igor Cichecki, Oliwier Kaszyca, Mateusz Kochanek, Dominika Szydło, Joanna Baran, Julita Bielaniewicz, Marcin Gruza, Arkadiusz Janz, Kamil Kanclerz, et al. Chatgpt: Jack of all trades, master of none. ar Xiv preprint ar Xiv:2302.10724, 2023. [44] Elias Konstantinidis and Yiannis Cotronis. A practical performance model for compute and memory bound gpu kernels. In 2015 23rd Euromicro International Conference on Parallel, Distributed, and Network-Based Processing, pages 651 658. IEEE, 2015. [45] MV Koroteev. Bert: a review of applications in natural language processing and understanding. ar Xiv preprint ar Xiv:2103.11943, 2021. [46] Mitsuru Kusumoto, Takuya Inoue, Gentaro Watanabe, Takuya Akiba, and Masanori Koyama. A graph theoretic framework of recomputation algorithms for memory-efficient backpropagation. Advances in Neural Information Processing Systems, 32, 2019. [47] Lagrange polynomial. Lagrange polynomial Wikipedia, the free encyclopedia, 2005. https: //en.wikipedia.org/wiki/Lagrange_polynomial. [48] Jinhyuk Lee, Wonjin Yoon, Sungdong Kim, Donghyeon Kim, Sunkyu Kim, Chan Ho So, and Jaewoo Kang. Biobert: a pre-trained biomedical language representation model for biomedical text mining. Bioinformatics, 36(4):1234 1240, 2020. [49] Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, and Veselin Stoyanov. Roberta: A robustly optimized bert pretraining approach. ar Xiv preprint ar Xiv:1907.11692, 2019. [50] Ze Liu, Han Hu, Yutong Lin, Zhuliang Yao, Zhenda Xie, Yixuan Wei, Jia Ning, Yue Cao, Zheng Zhang, Li Dong, Furu Wei, and Baining Guo. Swin transformer v2: Scaling up capacity and resolution. In International Conference on Computer Vision and Pattern Recognition (CVPR), 2022. [51] Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, and Baining Guo. Swin transformer: Hierarchical vision transformer using shifted windows. In Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV), 2021. [52] Xiaofei Ma, Zhiguo Wang, Patrick Ng, Ramesh Nallapati, and Bing Xiang. Universal text representation from bert: An empirical study. ar Xiv preprint ar Xiv:1910.07973, 2019. [53] Xuezhe Ma, Xiang Kong, Sinong Wang, Chunting Zhou, Jonathan May, Hao Ma, and Luke Zettlemoyer. Luna: Linear unified nested attention. Advances in Neural Information Processing Systems, 34:2441 2453, 2021. [54] Xuezhe Ma, Chunting Zhou, Xiang Kong, Junxian He, Liangke Gui, Graham Neubig, Jonathan May, and Luke Zettlemoyer. Mega: moving average equipped gated attention. ar Xiv preprint ar Xiv:2209.10655, 2022. [55] Maxim Milakov and Natalia Gimelshein. Online normalizer calculation for softmax. ar Xiv preprint ar Xiv:1805.02867, 2018. [56] Derek Miller. Leveraging bert for extractive text summarization on lectures. ar Xiv preprint ar Xiv:1906.04165, 2019. [57] Marcin Moczulski, Misha Denil, Jeremy Appleyard, and Nando de Freitas. Acdc: A structured efficient linear layer. ar Xiv preprint ar Xiv:1511.05946, 2015. [58] NVIDIA. Nvidia Tesla V100 GPU architecture, 2017. [59] NVIDIA. Nvidia A100 tensor core GPU architecture, 2020. [60] NVIDIA. 
Nvidia H100 tensor core GPU architecture, 2022. [61] NVIDIA. cu BLAS, 2023. [62] Open AI. Gpt-4 technical report, 2023. [63] Adam Paszke, Sam Gross, Francisco Massa, Adam Lerer, James Bradbury, Gregory Chanan, Trevor Killeen, Zeming Lin, Natalia Gimelshein, Luca Antiga, et al. Pytorch: An imperative style, high-performance deep learning library. Advances in neural information processing systems, 32, 2019. [64] Michael Poli, Stefano Massaroli, Eric Nguyen, Daniel Y Fu, Tri Dao, Stephen Baccus, Yoshua Bengio, Stefano Ermon, and Christopher R e. Hyena hierarchy: Towards larger convolutional language models. International Conference on Machine Learning, 2023. [65] Alec Radford, Karthik Narasimhan, Tim Salimans, and Ilya Sutskever. Improving language understanding by generative pre-training. 2018. [66] Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena, Yanqi Zhou, Wei Li, and Peter J. Liu. Exploring the limits of transfer learning with a unified text-to-text transformer. ar Xiv preprint ar Xiv:1910.10683, 2019. [67] David W Romero, R Bruintjes, Erik J Bekkers, Jakub M Tomczak, Mark Hoogendoorn, and JC van Gemert. Flexconv: Continuous kernel convolutions with differentiable kernel sizes. In 10th International Conference on Learning Representations, 2022. [68] David W Romero, David M Knigge, Albert Gu, Erik J Bekkers, Efstratios Gavves, Jakub M Tomczak, and Mark Hoogendoorn. Towards a general purpose cnn for long range dependencies in {N} d. ar Xiv preprint ar Xiv:2206.03398, 2022. [69] David W Romero, Anna Kuzina, Erik J Bekkers, Jakub Mikolaj Tomczak, and Mark Hoogendoorn. Ckconv: Continuous kernel convolution for sequential data. In International Conference on Learning Representations, 2021. [70] Andreas Steiner, Alexander Kolesnikov, Xiaohua Zhai, Ross Wightman, Jakob Uszkoreit, and Lucas Beyer. How to train your vit? data, augmentation, and regularization in vision transformers. ar Xiv preprint ar Xiv:2106.10270, 2021. [71] G. Szeg o. Orthogonal Polynomials. Number v.23 in American Mathematical Society colloquium publications. American Mathematical Society, 1967. [72] Yi Tay, Mostafa Dehghani, Dara Bahri, and Donald Metzler. Efficient transformers: A survey. ACM Computing Surveys, 55(6):1 28, 2022. [73] Ilya O Tolstikhin, Neil Houlsby, Alexander Kolesnikov, Lucas Beyer, Xiaohua Zhai, Thomas Unterthiner, Jessica Yung, Andreas Steiner, Daniel Keysers, Jakob Uszkoreit, et al. Mlpmixer: An all-mlp architecture for vision. Advances in neural information processing systems, 34:24261 24272, 2021. [74] Asher Trockman and J Zico Kolter. Patches are all you need? Transactions on Machine Learning Research, 2023. [75] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin. Attention is all you need. volume 30, 2017. [76] Alex Wang, Amanpreet Singh, Julian Michael, Felix Hill, Omer Levy, and Samuel R Bowman. Glue: A multi-task benchmark and analysis platform for natural language understanding. ar Xiv:1804.07461, 2018. [77] Junxiong Wang, Jing Nathan Yan, Albert Gu, and Alexander M Rush. Pretraining without attention. ar Xiv preprint ar Xiv:2212.10544, 2022. [78] Qipeng Wang, Mengwei Xu, Chao Jin, Xinran Dong, Jinliang Yuan, Xin Jin, Gang Huang, Yunxin Liu, and Xuanzhe Liu. Melon: Breaking the memory wall for resource-efficient ondevice machine learning. In Proceedings of the 20th Annual International Conference on Mobile Systems, Applications and Services, pages 450 463, 2022. 
[79] Sinong Wang, Belinda Z Li, Madian Khabsa, Han Fang, and Hao Ma. Linformer: Self-attention with linear complexity. ar Xiv preprint ar Xiv:2006.04768, 2020. [80] Laura Weidinger, John Mellor, Maribeth Rauh, Conor Griffin, Jonathan Uesato, Po-Sen Huang, Myra Cheng, Mia Glaese, Borja Balle, Atoosa Kasirzadeh, et al. Ethical and social risks of harm from language models. ar Xiv preprint ar Xiv:2112.04359, 2021. [81] Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Clement Delangue, Anthony Moi, Perric Cistac, Clara Ma, Yacine Jernite, Julien Plu, Canwen Xu, Teven Le Scao, Sylvain Gugger, Mariama Drame, Quentin Lhoest, and Alexander M. Rush. Transformers: State-of-theart natural language processing. In Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing: System Demonstrations, 2020. [82] Kuan Xu and Ana F. Loureiro. Spectral approximation of convolution operators. SIAM Journal on Scientific Computing, 40(4):A2336 A2355, 2018. [83] Yufan Xu, Saurabh Raje, Atanas Rountev, Gerald Sabin, Aravind Sukumaran-Rajam, and P Sadayappan. Training of deep learning pipelines on memory-constrained gpus via segmented fused-tiled execution. In Proceedings of the 31st ACM SIGPLAN International Conference on Compiler Construction, pages 104 116, 2022. [84] Lili Yu, D aniel Simig, Colin Flaherty, Armen Aghajanyan, Luke Zettlemoyer, and Mike Lewis. Megabyte: Predicting million-byte sequences with multiscale transformers, 2023. [85] Shanshan Yu, Jindian Su, and Da Luo. Improving bert-based text classification with auxiliary sentence and domain knowledge. IEEE Access, 7:176600 176612, 2019. [86] Li Yuan, Yunpeng Chen, Tao Wang, Weihao Yu, Yujun Shi, Zi-Hang Jiang, Francis EH Tay, Jiashi Feng, and Shuicheng Yan. Tokens-to-token vit: Training vision transformers from scratch on imagenet. In Proceedings of the IEEE/CVF international conference on computer vision, pages 558 567, 2021. [87] Sangdoo Yun, Dongyoon Han, Seong Joon Oh, Sanghyuk Chun, Junsuk Choe, and Youngjoon Yoo. Cutmix: Regularization strategy to train strong classifiers with localizable features. In Proceedings of the IEEE/CVF international conference on computer vision, pages 6023 6032, 2019. [88] Hongyi Zhang, Moustapha Cisse, Yann N Dauphin, and David Lopez-Paz. mixup: Beyond empirical risk minimization. ar Xiv preprint ar Xiv:1710.09412, 2017. [89] Tianyi Zhang, Varsha Kishore, Felix Wu, Kilian Q Weinberger, and Yoav Artzi. Bertscore: Evaluating text generation with bert. ar Xiv preprint ar Xiv:1904.09675, 2019. [90] Zhun Zhong, Liang Zheng, Guoliang Kang, Shaozi Li, and Yi Yang. Random erasing data augmentation. In Proceedings of the AAAI conference on artificial intelligence, volume 34, pages 13001 13008, 2020. [91] Chen Zhu, Wei Ping, Chaowei Xiao, Mohammad Shoeybi, Tom Goldstein, Anima Anandkumar, and Bryan Catanzaro. Long-short transformer: Efficient transformers for language and vision. Advances in Neural Information Processing Systems, 34:17723 17736, 2021. [92] Yukun Zhu, Ryan Kiros, Rich Zemel, Ruslan Salakhutdinov, Raquel Urtasun, Antonio Torralba, and Sanja Fidler. Aligning books and movies: Towards story-like visual explanations by watching movies and reading books. In The IEEE International Conference on Computer Vision (ICCV), December 2015. Author Contributions D.Y.F. Conceptualized the research; coordinated collaborations; developed M2 architectures; led experimental and implementation efforts; assisted in development of theoretical results; coordinated writing. S.A. 
Worked on developing M2 architectures; worked on implementing and conducting BERT experiments; worked on implementing and performing CPU experiments; assisted in writing and framing the work. J.G. Led development of theory and causal algorithms; wrote Appendix D. I.J. Led development of theory and causal algorithms; wrote Appendix D. S.E. Worked on writing and framing the work; assisted in development of M2 architectures; assisted in the optimized M2 implementation; conducted mixer benchmarks; assisted with BERT experiments; conducted Swin ImageNet experiments. A.W.T. Conducted ViT experiments; assisted in writing. B.S. Assisted in optimized M2 implementation; conducted mixer benchmarks; assisted in writing. M.P. Assisted in development of M2-GPT architecture. A.R. Supervised theory development; developed proofs; reviewed manuscript. C.R. Supervised research; reviewed manuscript.

Simran Arora, Jessica Grogan, Isys Johnson, Sabri Eyuboglu, and Armin Thomas contributed equally to this work.

Appendix A discusses broader impacts of our work. Appendix B presents additional experiments. Appendix C gives details for the experiments, including model architectures and hyperparameters. Appendix D gives missing details and proofs for the theoretical analysis, as well as generalizations to broader results. Appendix E gives a PyTorch code listing of an M2 layer.

A Broader Impacts

Our work seeks to understand the fundamental capabilities and limitations of newly-emerging model architectures. As the amount of data and model size grows, we also seek to understand how to make training these models more efficient, both in terms of the amount of training context and the model size. This potentially connects to energy savings during model development and deployment, as well as making machine learning models accessible to a larger population of people. However, as with any machine learning models, developing new techniques may impact a wide range of applications, each with potential benefits and harms. Making language model training cheaper and longer-context may make it cheaper to spread disinformation. Similarly, improving the efficiency of model training may not reduce the overall environmental footprint of training, since the same resources may be used to train more models, or to train the same models for longer. While our work makes partial progress on the fronts of efficiency and understanding, it does not explicitly address the issues of fairness and bias in language models. In addition, our work demonstrates a proof-of-concept; it has error modes, and we recognize the inherent risks of training and using machine learning models, including language models. Detailed discussions of these risks are in [3, 5, 80].

B Additional Experiments

B.1 Per-Task GLUE Numbers

We report full GLUE numbers for M2-BERT-base and M2-BERT-large in Table 9.

B.2 Additional Throughput Results

We report the throughput of M2-BERT-base (80M) compared to BERT models of the same size (BERT-base with fewer parameters), as well as the throughput of M2-BERT-large (260M) compared to BERT-large.

Table 9: Fine-tuning performance on GLUE [76]. We report the standard metrics: F1 scores for QQP and MRPC, Matthews correlation for CoLA, Spearman's correlation for STS-B, and accuracy for the remaining tasks, following the procedure from [38].
Model MNLI (m / mm) RTE QNLI QQP SST-2 STS-B CoLA MRPC Average
M2-BERT-base (80M) 78.4 / 78.6 68.5 84.6 86.7 92.0 86.3 53.0 89.8 79.9
M2-BERT-base (110M) 79.6 / 80.5 69.3 86.0 87.0 92.3 86.9 56.0 89.2 80.9
M2-BERT-large (260M) 81.7 / 81.9 72.8 84.7 87.8 93.3 88.0 59.2 90.0 82.2
M2-BERT-large (341M) 82.2 / 82.3 75.0 87.0 87.7 92.4 88.3 59.6 90.1 82.8

Table 10: Throughput in tokens/ms by context length for M2-BERT-base (80M) compared to 80M BERT models.
Model 512 1024 2048 4096 8192
HF BERT (79M) 248.4 157.3 86.0 46.8 OOM
FlashAttention BERT (79M) 433.3 425.1 335.2 217.4 122.6
M2-BERT-base (80M) 386.3 380.7 378.9 353.9 320.1
M2 speedup over HF BERT (79M) 1.6× 2.4× 4.4× 7.5× –

Table 10 compares the performance of M2-BERT-base (80M) to BERT models parameter-matched to 80M parameters. M2 is slower than FlashAttention for sequence lengths 512 and 1K, but outperforms FlashAttention starting at sequence length 2K. We believe further optimization of the M2 kernel can close the gap to FlashAttention for short sequences. Table 11 compares M2-BERT-large (260M) to BERT-large. The trends are mostly similar to the comparisons against BERT-base; M2 nearly matches FlashAttention at sequence length 512 and outperforms it for sequence lengths 1K and longer. We also see up to a 4.3× speedup over Hugging Face BERT-large at sequence length 2K.

B.3 ImageNet Comparison against Swin
Table 12 reports the results of replacing attention and MLP in Swin-V2 with M2 as a drop-in replacement. Surprisingly, Swin-M2 outperforms Swin-MLP-B, is competitive with Swin-V1-B, and comes within 1 point of Swin-V2-B, even without any hyperparameter tuning or architecture adjustment from the ViT formula. We expect that performance may improve further with hyperparameter tuning specific to M2.

B.4 Speech Applications
Table 13 presents the performance of M2 on Speech Commands-10, a speech classification task over raw 1-second clips sampled at 16 kHz. M2 is competitive with state-of-the-art architectures on this task.

B.5 CIFAR-10
Table 14 shows the performance of MONARCH MIXER on CIFAR-10. The trends are largely the same as on ImageNet.

B.6 Learnable Monarch Matrices in the Sequence Mixer
In most of our models, we have used fixed Monarch matrices for the sequence mixer and learnable Monarch matrices for the dimension mixer. Table 15 presents an experiment evaluating learnable Monarch matrices for the sequence mixer on the sequential CIFAR task. We use a non-gated convolutional architecture based on long convolutions, as presented in [28]. Learning the Monarch matrices in the sequence mixer yields 1.5 points of lift.

Table 11: Throughput in tokens/ms by context length for M2-BERT-large (260M) compared to BERT-large.
Model 512 1024 2048 4096 8192
HF BERT-large (340M) 75.4 47.1 25.2 OOM OOM
FlashAttention BERT-large (340M) 125.0 111.9 91.6 54.5 OOM
M2-BERT-large (260M) 122.5 118.6 109.4 94.5 75.0
M2 speedup over HF BERT-large (340M) 1.6× 2.5× 4.3× – –

Table 12: ImageNet accuracy of Swin models.
Model ImageNet (acc@1) ImageNet (acc@5)
Swin-MLP-B 81.3 95.3
Swin-V1-B 83.5 96.5
Swin-V2-B 84.2 96.9
M2-Swin-B 83.5 96.7

B.7 Roofline Analysis
Figure 4 shows a roofline analysis of a simple PyTorch implementation of a single M2 operator M⁻¹((M u) ⊙ (M k)) on an A100 GPU, with 4K input length. The operation is dominated by data movement, which helps explain why performance is higher on newer architectures such as the RTX 4090 (which has a faster and larger L2 cache).
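To make the shape of this operator concrete, the following is a minimal PyTorch sketch of y = M⁻¹((M u) ⊙ (M k)), with M instantiated as the DFT (a special case of a Monarch matrix) so that M⁻¹ is simply the inverse FFT. This is an illustration of the operator's structure, not the GEMM-based kernel that was benchmarked; the function name m2_operator_fft is ours.

import torch

# Illustrative only: the benchmarked operator is y = M^-1((M u) ⊙ (M k)) for a
# Monarch matrix M computed with GEMMs; here M is taken to be the DFT so that
# M^-1 is the inverse FFT.
def m2_operator_fft(u: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    U = torch.fft.fft(u, dim=-1)               # "evaluate" u (multiply by M)
    K = torch.fft.fft(k, dim=-1)               # "evaluate" k (multiply by M)
    return torch.fft.ifft(U * K, dim=-1).real  # elementwise product, then M^-1

u = torch.randn(8, 4096)   # batch of 8 sequences at the 4K input length used in Figure 4
k = torch.randn(8, 4096)
y = m2_operator_fft(u, k)  # shape (8, 4096)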
B.8 Associative Recall
In Table 16, we present a simple experiment demonstrating the causal parameterization of M2 on associative recall, a synthetic language modeling task designed to test in-context learning. The model demonstrates in-context learning abilities on sequences up to 128K tokens, whereas Transformers do not scale past 8K.

B.9 BERT Experiments with an Alternative Architecture
Here, we report results using an older version of the M2-BERT architecture, which uses non-gated convolutions and is trained on English Wikipedia [25] and English Bookcorpus [92]. For clarity, we refer to this model as M1-BERT. We found that M1-BERT could match Transformers on MLM quality but underperformed on downstream fine-tuning. We attribute this gap in performance to sub-optimal training hyperparameters (optimized for throughput using NVIDIA MLPerf hyperparameters) as well as a sub-optimal architecture. We report results here for completeness, but refer to the gated convolution architecture in the main body as the proper M2-BERT model.

These models followed the reference implementations and hyperparameters from the Hugging Face Transformers examples [81] and the NVIDIA Deep Learning Examples (https://github.com/NVIDIA/DeepLearningExamples). In particular, we use the LAMB optimizer with a learning rate of 5e-3. For each sequence length, we use as large a minibatch size as fits on the GPU (A100 80GB in Table 17 and V100 in Table 18). We set the gradient accumulation to reach a global batch size of 65,536 sequences. To investigate the effect of sequence length, each model is trained at a fixed sequence length in a single phase of training (in contrast to some training protocols, which train the model in multiple phases, each at a different sequence length).

Time to a Fixed Pretraining Quality on 8x A100 We compare time to a fixed pretraining quality, training M1-BERT-base on English Wikipedia [25] and English Bookcorpus [92]. We compare against BERT-base trained with FlashAttention [15], as well as the Monarch-BERT-base implementation from the original Monarch paper [14]. We measure wall-clock time for M1-BERT and the base Transformer to reach 50% masked language modeling accuracy on 8x A100 NVIDIA GPUs with 80GB of memory each. Table 17 summarizes the results. At short sequence lengths, M1-BERT is comparable to FlashAttention, even without using a heavily optimized fused kernel. At longer sequence lengths, the FLOP savings make M1-BERT more efficient, up to 2.4× faster than BERT with FlashAttention at sequence length 4096.

Table 13: Accuracy on Speech Commands-10. An "x" means that the model did not fit in memory.
Model M2 S4 WaveGAN-D Transformer Performer CKConv
Accuracy 97.9 97.5 96.3 x 30.8 71.7

Table 14: Accuracy on CIFAR-10.
Model Top-1% Description
ViT (1.2M) 78.6 Attention + MLP
ViT + Monarch (607K) 79.0 Attention, MLP-free
Hyena ViT (1.3M) 80.6 Attention-free + MLP
Hyena ViT-M2 (741K) 80.8 Attention-free + MLP-free

BERT in Half a Day Inspired by recent work focusing on training under limited resource constraints [30], we measure how far we can get when training on a single V100 GPU for 12 hours. In Table 18, we report the masked language modeling accuracy achieved by the same set of models and sequence lengths (except for the FlashAttention baseline, which is not supported on the V100). We observe that M1-BERT both achieves higher accuracy within the time limit and can be trained at longer sequence lengths than the baseline architectures.
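As a rough illustration of the accumulation setup described in this subsection (per-device minibatches accumulated until the global batch reaches 65,536 sequences), here is a minimal sketch. It assumes a Hugging Face-style model interface that returns a loss, and substitutes AdamW for LAMB, which is not shipped with core PyTorch; names such as train_with_accumulation are ours.

import torch

# Sketch only: the paper uses LAMB (lr 5e-3); AdamW stands in here since LAMB
# is not in core PyTorch (implementations exist in, e.g., NVIDIA Apex).
# Assumes `model(input_ids, labels=...)` returns an object with a `.loss`.
def train_with_accumulation(model, batches, global_batch=65_536, per_device_batch=128, world_size=8):
    accum_steps = global_batch // (per_device_batch * world_size)
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-3)
    optimizer.zero_grad()
    for step, (input_ids, labels) in enumerate(batches):
        loss = model(input_ids, labels=labels).loss / accum_steps  # scale so gradients average
        loss.backward()
        if (step + 1) % accum_steps == 0:  # one optimizer step per global batch
            optimizer.step()
            optimizer.zero_grad()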
Downstream Fine-Tuning We evaluate the quality of M1-BERT-base models on the GLUE benchmark [76]. Table 19 shows fine-tuning performance on the GLUE tasks, using the same hyperparameters and 5 epochs for all tasks and both models. M1-BERT-base is competitive with Transformers trained using MLPerf hyperparameters on Bookcorpus and Wikipedia, but underperforms fully trained Transformers and M2-BERT-base.

C Experiment Details

C.1 Model Architectures In this section, we describe the exact model architectures we used for each task, including the design of the block (residuals and gating). We additionally release our code for reproducibility.

BERT Language Modeling The M2-BERT architectures use a standard BERT backbone, but replace attention with bidirectional gated convolutions and replace the linear layers in the MLPs with block-diagonal matrices. All the M2-BERT architectures use an expansion factor of four. M2-BERT-base (80M) has a model width of 768 and 12 layers; M2-BERT-base (110M) has a model width of 960 and 12 layers; M2-BERT-large (260M) has a model width of 1536 and 12 layers; and M2-BERT-large (341M) has a model width of 1792 and 12 layers. We train all these models on C4 for 70,000 steps, with sequence length 128 and a global batch size of 4096 sequences. For all the models, we use decoupled AdamW with learning rate 8e-4 and decoupled weight decay 1e-5. We use linear learning rate decay with a warmup over 6% of the steps, and we use an MLM masking percentage of 30%.

For GLUE fine-tuning, we do a small search over learning rate, weight decay, and number of epochs. Following [38], we fine-tune RTE, MRPC, and STS-B from the MNLI checkpoint. We fine-tune all tasks with sequence length 128. For some tasks, we also pool the embeddings of all the non-padding tokens instead of using the CLS token. The final hyperparameters for M2-BERT-base (80M) are decoupled AdamW with learning rate 5e-5 and weight decay 5e-6 for 3 epochs for MNLI; AdamW with learning rate 5e-5 and weight decay 0.01 for 6 epochs for RTE; AdamW with learning rate 3e-5 and weight decay 0.01 for 10 epochs for QQP; AdamW with learning rate 5e-5 and weight decay 1e-5 for 10 epochs with average pooling for QNLI; decoupled AdamW with learning rate 3e-5 and weight decay 3e-6 for 3 epochs for SST-2; AdamW with learning rate 7e-5 and weight decay 0.01 for 10 epochs for STS-B; AdamW with learning rate 5e-5 and weight decay 0.01 for 10 epochs for MRPC; and decoupled AdamW with learning rate 5e-5 and weight decay 5e-6 for 10 epochs for CoLA.

For M2-BERT-base (110M), the hyperparameters are decoupled AdamW with learning rate 5e-5 and weight decay 5e-6 for 3 epochs for MNLI; decoupled AdamW with learning rate 1e-5 and weight decay 1e-6 for 3 epochs for RTE; decoupled AdamW with learning rate 3e-5 and weight decay 3e-6 for 5 epochs for QQP; decoupled AdamW with learning rate 5e-5 and weight decay 1e-5 for 10 epochs with average pooling for QNLI; decoupled AdamW with learning rate 3e-5 and weight decay 3e-6 for 3 epochs for SST-2; decoupled AdamW with learning rate 8e-5 and weight decay 3e-6 for 10 epochs for STS-B; decoupled AdamW with learning rate 8e-5 and weight decay 8e-5 for 10 epochs for MRPC; and AdamW with learning rate 8e-5 and weight decay 5e-6 for 10 epochs for CoLA.

Table 15: Accuracy on sequential CIFAR for fixed vs. learnable Monarch matrices in the sequence mixer.
Model sCIFAR Accuracy
M2, fixed Monarch 91.0
M2, learnable Monarch 92.5

Figure 4: Roofline plot of a PyTorch implementation of a single M2 operator M⁻¹((M u) ⊙ (M k)).
For M2-BERT-large (260M), the hyperparameters are decoupled AdamW with learning rate 5e-5 and weight decay 5e-6 for 3 epochs for MNLI; decoupled AdamW with learning rate 1e-5 and weight decay 1e-6 for 3 epochs for RTE; decoupled AdamW with learning rate 3e-5 and weight decay 3e-6 for 5 epochs for QQP; decoupled AdamW with learning rate 5e-5 and weight decay 1e-5 for 10 epochs for QNLI; decoupled AdamW with learning rate 3e-5 and weight decay 3e-6 for 3 epochs for SST-2; decoupled AdamW with learning rate 7e-5 and weight decay 3e-6 for 10 epochs for STS-B; decoupled AdamW with learning rate 8e-5 and weight decay 8e-6 for 10 epochs for MRPC; and AdamW with learning rate 5e-5 and weight decay 5e-6 for 10 epochs for CoLA.

For M2-BERT-large (341M), the hyperparameters are decoupled AdamW with learning rate 5e-5 and weight decay 5e-6 for 3 epochs for MNLI; AdamW with learning rate 5e-5 and weight decay 1e-6 for 2 epochs for RTE; decoupled AdamW with learning rate 3e-5 and weight decay 3e-6 for 5 epochs for QQP; decoupled AdamW with learning rate 5e-5 and weight decay 1e-6 for 10 epochs for QNLI; decoupled AdamW with learning rate 3e-5 and weight decay 3e-6 for 3 epochs for SST-2; decoupled AdamW with learning rate 8e-5 and weight decay 3e-5 for 8 epochs for STS-B; decoupled AdamW with learning rate 8e-5 and weight decay 8e-6 for 10 epochs for MRPC; and decoupled AdamW with learning rate 5e-5 and weight decay 1e-6 for 10 epochs for CoLA.

ViT We use a standard ViT model architecture as the base [22]. In line with recent improvements to the ViT architecture [4, 70, 86], we use sinusoidal position embeddings and global average pooling (GAP) instead of a class token.

Table 16: In-context learning performance on associative recall at various sequence lengths, vocab size 20. A dash (-) indicates the Transformer did not finish within a week.
Model 0.5K 2K 8K 32K 128K
Transformer 100.0 100.0 100.0 - -
MONARCH MIXER 98.7 99.4 99.4 99.4 99.4

Table 17: Time in hours to reach 50% masked language modeling validation accuracy on 8x A100 with different sequence lengths.
Model 512 1024 2048 4096 Architecture Details
BERT-base-FlashAttention (110M) 2.7 3.8 5.7 13.2 Attention, MLP
BERT-base-HuggingFace (110M) 3.3 5.6 13.1 26.7 Attention, MLP
BERT-Monarch-base (80M) 3.1 4.7 10.3 22.1 Attention, MLP-free
M1-BERT-base (55M) 2.5 3.5 4.0 5.5 Attention-free, MLP-free
Speedup 1.1× 1.1× 1.3× 2.4×

We adapt the ViT architecture by replacing its MLP and/or attention components with Monarch matrices (similar to our adaptation of BERT): we replace the MLP with randomly initialized Monarch matrices of the same dimension as the dense matrices of the MLP and learn those matrices during training, setting the number of blocks in the block-diagonal matrices to 4; a short sketch of this block-diagonal replacement is given after the ImageNet-1k architecture details below. We replace attention with the recently introduced Hyena operator [64]. The Hyena operator is a recurrence of two efficient sub-quadratic primitives, an implicit long convolution and multiplicative element-wise gating of the projected input. Hyena operators apply the FFT algorithm to compute long convolutions in sub-quadratic time. We further adapt the Hyena operator by replacing its long convolutions with the M2 operator and setting the Monarch matrices to the DFT and inverse DFT.

ViT for ImageNet-1k In line with other work [4, 14, 64, 70], we use a ViT-base architecture with 12 layers, a hidden size of 768, 12 attention heads per layer, an intermediate MLP projection size of 3,072, and a patch size of 16×16 pixels.
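The sketch below (ours; class names BlockDiagLinear and MonarchMLP are illustrative, not the released implementation) shows the flavor of the MLP replacement described above: each dense projection is swapped for a block-diagonal matrix with 4 blocks, computed as a batched GEMM. For brevity a single block-diagonal factor per projection is shown; the full Monarch replacement composes two such factors with permutations, as in the MonarchMatrix module of Listing 1.

import torch
from torch import nn

class BlockDiagLinear(nn.Module):
    # One block-diagonal factor: the input is split into `nblocks` chunks and
    # each chunk gets its own dense matrix, so the matmul is a batched GEMM.
    def __init__(self, dim_in: int, dim_out: int, nblocks: int = 4):
        super().__init__()
        assert dim_in % nblocks == 0 and dim_out % nblocks == 0
        self.nblocks = nblocks
        self.weight = nn.Parameter(
            torch.randn(nblocks, dim_in // nblocks, dim_out // nblocks) / dim_in ** 0.5
        )

    def forward(self, x):  # x: (..., dim_in)
        x = x.reshape(*x.shape[:-1], self.nblocks, -1)          # split into blocks
        y = torch.einsum("...bi,bio->...bo", x, self.weight)    # per-block matmul
        return y.reshape(*y.shape[:-2], -1)

class MonarchMLP(nn.Module):
    # Drop-in replacement for a ViT MLP with expansion factor 4.
    def __init__(self, dim: int, expansion: int = 4, nblocks: int = 4):
        super().__init__()
        self.fc1 = BlockDiagLinear(dim, expansion * dim, nblocks)
        self.fc2 = BlockDiagLinear(expansion * dim, dim, nblocks)
        self.act = nn.GELU()

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))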
For optimization, we follow the training procedure of T2T-ViT [86], including augmentations such as RandAugment [12] (magnitude = 9, magnitude-std = 0.5, layers = 2), Mixup [88] (α = 0.8), CutMix [87] (α = 1.0), Random Erasing [90] (rate = 0.25), and AugMix [37]. See Table 20 for all other training settings.

ViT for CIFAR-10 We use a ViT architecture with 6 layers, a hidden size of 128, 8 attention heads per layer, an intermediate MLP projection size of 512, and a patch size of 4×4 pixels. We further tune weight decay (0 or 0.1), stochastic depth rate (0 or 0.1), and base learning rate (1e-4, 3e-4, or 1e-3), and report the test performance of the model variant that achieved the highest accuracy on a separate held-out validation dataset (a randomly selected 10% of the training data). We also apply an early stopping rule such that training is stopped if the model's validation loss does not improve for 10 training epochs. See Table 20 for all other training settings.

GPT Causal Language Modeling As in our ViT approach, we also replace attention with the Hyena operator, using the same architecture as in [64] as a starting point. The Hyena architecture has two convolutions, which can be computed using the FFT convolution theorem. In our architecture, we additionally replace these FFT operations with causal Monarch matrices. In addition, we re-use the heads extension from the H3 architecture [27]. The heads extension groups the model dimension into heads, ties together the long convolution parameters in each head, and then computes the outer product between different input projections. An algorithmic listing adapted from the H3 paper [27] is provided in Algorithm 1, with updates to replace the SSM layers with Hyena convolutions. We use a head dimension of 16. Setting the head dimension to 1 and replacing the Monarch matrices with the FFT recovers the Hyena layer.

Table 18: Masked language modeling validation accuracy achieved on a single V100 in 12 hours with different sequence lengths. A dash (-) indicates the model does not fit on the device with a batch size of 1.
Model 512 1024 2048 4096 8192 Architecture Details
BERT-base (110M) 11.5 7.8 6.8 - - Attention, MLP
BERT-Monarch-base 6.9 8.5 6.8 - - Attention, MLP-free
M1-BERT-base 20.2 20.2 20.1 17.1 12.9 Attention-free, MLP-free

Table 19: Fine-tuning performance on the GLUE benchmark [76], after pretraining on Wikipedia and Bookcorpus. We report the standard metrics: F1 score for QQP and MRPC, Matthew's correlation for CoLA, Spearman's correlation for STS-B, and accuracy for the remaining tasks [21].
Model MNLI (m / mm) RTE QNLI QQP SST-2 STS-B CoLA MRPC Architecture Details
BERT no pretrain 34.1 / 34.1 47.3 50.0 68.6 79.9 17.8 0.0 77.9 Attention, MLP
BERT-base 74.5 / 74.7 55.6 69.3 81.8 83.9 19.8 12.1 74.2 Attention, MLP
M1-BERT-base 69.9 / 70.5 53.1 73.2 81.4 85.2 68.1 33.6 75.4 Attention-free, MLP-free

Algorithm 1 M2 Hyena Layer with Heads
Input: Input sequence u ∈ R^{N×d} from the previous layer, weight matrices W_{X1}, W_{X2}, W_V, W_O ∈ R^{d×d}, causal Monarch matrix M, short convolution kernels K_1, K_2, K_3, a Hyena convolution kernel K_long, head dimension d_h.
Output: Output sequence y ∈ R^{N×d}.
Compute X_1 = u W_{X1}, X_2 = u W_{X2}, V = u W_V ∈ R^{N×d}.
Pass X_1, X_2, V each through a short convolution using the causal Monarch matrices: X_1, X_2, V = M^{-1}(M X_1 ⊙ M K_1), M^{-1}(M X_2 ⊙ M K_2), M^{-1}(M V ⊙ M K_3).
Split X_1, X_2, V into H heads (X_1^{(h)}, X_2^{(h)}, V^{(h)} for h = 1, ..., H), each a sequence of N vectors of size d_h = d/H.
For 1 ≤ h ≤ H do:
Take the batched outer product X_2^{(h)} (V^{(h)})^⊤ ∈ R^{N×d_h×d_h} (batched in the N dimension) and pass it through the long convolution using the causal Monarch matrices: XV^{(h)} = M^{-1}(M X_2^{(h)} (V^{(h)})^⊤ ⊙ M K_long) ∈ R^{N×d_h×d_h}.
Batch-multiply by X_1: O^{(h)} = [X_1^{(h)}_1 XV^{(h)}_1, ..., X_1^{(h)}_N XV^{(h)}_N] ∈ R^{N×d_h} (batched in the N dimension).
Concatenate the outputs O^{(h)} of all heads, and multiply by the output projection matrix W_O ∈ R^{d×d}.

Finally, we remove the MLP layers entirely (equivalent to replacing the layer with an identity) and make the model wider to compensate (the depths match the equivalent Hyena models). The small model has a model width of 1160 with 18 layers and uses a learning rate of 0.0006, and the medium model has a model width of 1344 with 40 layers and uses a learning rate of 0.0008. All other hyperparameters match the Hyena models [64].

Table 20: ViT training settings (ImageNet-1k / CIFAR-10).
Optimizer: AdamW
Optimizer momentum: β1, β2 = 0.9, 0.999
Learning rate schedule: cosine decay with linear warmup
Dropout rate: 0
Label smoothing: 0.1
Image size: 224×224 / 32×32
Base learning rate: 1e-3 / {1e-4, 3e-4, 1e-3}
Batch size: 1024 / 512
Training epochs: 300 / up to 500
Warmup epochs: 10 / 5
Stochastic depth rate: 0.1 / {0, 0.1}
Weight decay: 0.05 / {0, 0.1}

D Missing details from Section 4

This section contains all the missing details (including proofs) from Section 4. In Appendix D.1, we review some definitions and results on multivariate polynomials and set some notation needed for this section. In Appendix D.2, we explicitly connect Monarch matrices for p = 2 and bivariate polynomial evaluation. Specifically, we prove Theorem 1 and Theorem 2. Then, in Appendix D.3, we show how to instantiate the bivariate basis polynomials so that we get a causal map. This includes converting the bivariate polynomials to univariate polynomials (with evaluations over the Nth roots of unity), which proves Theorem 3. We then show how this causal map can be implemented using only GEMMs (and O(N^{3/2}) FLOPs) in Appendix D.4.

Next, we note that while our evaluation points are over the complex numbers, the input and output of the Monarch convolution layers are over the reals. Hence, it is natural to wonder if we can implement the entire layer with operations over real numbers only. One potential advantage is that we theoretically only have to keep N real numbers for intermediate results (instead of the 2N real numbers needed to keep track of vectors in C^N). This can reduce data movement costs. Further, multiplication of two complex numbers requires six operations over real numbers (four multiplications and two additions). Thus, moving to an implementation that only uses real numbers could potentially lead to wall-clock speedups. We propose one such scheme in Appendix D.5, which proves a version of Theorem 3 just over the reals by moving to the Chebyshev basis (instead of the standard monomial basis). This creates new technical challenges, which we also address. Finally, we generalize our results to arbitrary p ≥ 2 in Appendix D.6.

We would like to point out that to get a causal map (in Theorem 17) we need to embed input vectors of size n into vectors of size N = 2^p · n + O(n^{1-1/p}). For p = 2, we avoided the blowup of 2^2 = 4 with a blowup of 2 instead (via Theorem 3). Whether this is possible (i.e., a blowup of 2 instead of 2^p) for p > 2 is an interesting direction for future work. Further, the matrices that lead to a causal map can be represented with O(p · N^{2/p}) parameters, while the matrices in Theorem 3 use more parameters.
Extending the causal map to p > 2 with O(N^{1+1/p}) parameters is an exciting direction for future work.

D.1 Background and Notation

We collect known facts and definitions about multivariate polynomials in Appendix D.1.1 and recall some notation from [14] in Appendix D.1.2. These will be needed throughout this appendix.

D.1.1 Multivariate Polynomials

Basic Definitions Let p ≥ 1 be an integer. We recall some definitions on p-variate polynomials (over R) in the variables X_0, ..., X_{p-1}. When p ∈ {1, 2}, we will use variables in {X, Y, Z} for notational simplicity. We will use X to denote the vector of variables (X_0, ..., X_{p-1}). Further, for j ∈ Z_{≥0}^p with j = (j_0, ..., j_{p-1}), we use the notation X^j = ∏_{a=0}^{p-1} X_a^{j_a}; X^j is a (standard basis) monomial. A generic p-variate polynomial is defined (in the standard monomial representation) as q(X) = Σ_j q_j · X^j, where the coefficients q_j ∈ R (and only finitely many q_j are nonzero).

We will need the following notion of degrees:

Definition 2 (Degree). Let 0 ≤ a < p. The degree of X_a in X^j (with j = (j_0, ..., j_{p-1})) is j_a. The degree of X_a in q(X), denoted deg_{X_a}(q), is the maximum degree of X_a over all monomials X^j with q_j ≠ 0.

Note that for p = 1 the above coincides with the usual notion of degree of a univariate polynomial q(Z), in which case we just use deg(q(Z)) to denote deg_Z(q(Z)).

We will need the notion of taking the mod of a p-variate polynomial with respect to a p-tuple of polynomials. The notion of mod is well defined for univariate polynomials (which we take as given below), but for arbitrary p-variate polynomials q(X) and q'(X), the operation q(X) mod q'(X) is not well defined in general. However, we will only need the following restricted operation:

Definition 3. Let p ≥ 1. Fix a p-tuple of polynomials R_0(X_0), ..., R_{p-1}(X_{p-1}). Then for any j ∈ Z_{≥0}^p, we define
X^j mod (R_0(X_0), ..., R_{p-1}(X_{p-1})) = ∏_{a=0}^{p-1} (X_a^{j_a} mod R_a(X_a)).
For a general polynomial q(X), q(X) mod (R_0(X_0), ..., R_{p-1}(X_{p-1})) is defined by extending the definition for X^j by linearity.

Polynomial Evaluation Given a p-variate polynomial q(X) and a point a ∈ R^p, the evaluation of q at a, denoted q(a), is the evaluation of q as a function at a. Given subsets S_a ⊆ C, we define q(X) evaluated at S_0 × ··· × S_{p-1} as the vector of values q(a) over all a ∈ S_0 × ··· × S_{p-1}. In this paper, we will in many cases evaluate polynomials at appropriate roots of unity. Specifically, for an integer N, we define ω_N = e^{2πι/N} and note that the Nth roots of unity form the set {ω_N^i | 0 ≤ i < N}.

Polynomial Interpolation We now recall univariate and bivariate polynomial interpolation results (proved via the Lagrange basis), which we will use in later subsections.

Theorem 4. Let D ≥ 1 be an integer. Given values y_i for 0 ≤ i < D and distinct points α_i for 0 ≤ i < D, there exists a unique univariate polynomial P(X) with deg(P) < D such that for all 0 ≤ i < D,
P(α_i) = y_i. (7)

Proof. This proof is based on the Wikipedia entry for Lagrange polynomials [47]. Given a sequence of points α_i for 0 ≤ i < D such that α_i ≠ α_j for i ≠ j, the Lagrange basis for polynomials of degree < D for these points is the set of polynomials {p_0(X), p_1(X), ..., p_{D-1}(X)}, each of degree D − 1. Each basis polynomial is defined as
p_i(X) = (X − α_0)/(α_i − α_0) · ··· · (X − α_{i-1})/(α_i − α_{i-1}) · (X − α_{i+1})/(α_i − α_{i+1}) · ··· · (X − α_{D-1})/(α_i − α_{D-1}) = ∏_{j ≠ i} (X − α_j)/(α_i − α_j). (8)
By definition,
p_i(α_j) = 1 if j = i, and 0 otherwise. (9)
The Lagrange interpolating polynomial for these nodes through the corresponding values y_i for 0 ≤ i < D is the linear combination
P(X) = Σ_{i=0}^{D-1} y_i · p_i(X). (10)
By (9), for all 0 ≤ i < D:
P(α_i) = y_i. (11)
Finally, the interpolating polynomial is unique.
Assume there is another polynomial M(X) of degree < D such that M(α_i) = y_i for all 0 ≤ i < D. Then the difference M(X) − P(X) is 0 at the D distinct points α_i for 0 ≤ i < D, and the only polynomial of degree < D with more than D − 1 roots is the zero polynomial. So M(X) = P(X).

Theorem 5. Let D_X, D_Y ≥ 1 be integers. Given values y_{ij} for 0 ≤ i < D_X, 0 ≤ j < D_Y, D_X distinct points (α_0, ..., α_{D_X−1}), and D_Y distinct points (β_0, ..., β_{D_Y−1}), there exists a unique bivariate polynomial P(X, Y) with deg_X(P) < D_X and deg_Y(P) < D_Y such that for all 0 ≤ i < D_X, 0 ≤ j < D_Y:
P(α_i, β_j) = y_{ij}. (12)

Proof. Define
P(X, Y) = Σ_{i=0}^{D_X−1} Σ_{j=0}^{D_Y−1} y_{ij} · p_i(X) · p_j(Y), (13)
where p_i and p_j are Lagrange basis polynomials defined as in the proof of Theorem 4, such that for 0 ≤ i, k < D_X,
p_i(α_k) = 1 if i = k, and 0 otherwise, (14)
and for 0 ≤ j, ℓ < D_Y,
p_j(β_ℓ) = 1 if j = ℓ, and 0 otherwise. (15)
From the above, we have for all i, j, k, ℓ:
p_i(α_k) · p_j(β_ℓ) = 1 if i = k and j = ℓ, and 0 otherwise. (16)
Then, for all 0 ≤ i < D_X, 0 ≤ j < D_Y:
P(α_i, β_j) = y_{ij}. (17)
By the definition of the Lagrange basis polynomials, deg_X(P) < D_X and deg_Y(P) < D_Y. Finally, the interpolating polynomial is unique. Assume there is another polynomial M(X, Y) with deg_X(M) < D_X and deg_Y(M) < D_Y such that M(α_i, β_j) = y_{ij} for all 0 ≤ i < D_X and 0 ≤ j < D_Y. Then the difference M(X, Y) − P(X, Y) is 0 at the D_X · D_Y distinct points (α_i, β_j) for 0 ≤ i < D_X, 0 ≤ j < D_Y, and the only polynomial with deg_X < D_X and deg_Y < D_Y that has D_X · D_Y such roots is the zero polynomial.

D.1.2 Notation

Here we recall notation we will use from [14].
1. The class of Monarch matrices is defined in Appendix C of [14] as M(b,N): N × N matrices with block size b, for any integer 0 < b ≤ N that divides N. When b = √N we drop b from the notation, writing (i_1, i_0) and (j_1, j_0); for example, this is used in the proof of Corollary 2.
2. A row index i can be represented as (i_1, i_0)_b, which gives i = i_1 · b + i_0.
3. Similarly, a column index j can be represented as (j_1, j_0)_b, which gives j = j_1 · b + j_0. Note that when b = √N, j_1 = k(j) and j_0 = m(j). We choose to use the (j_1, j_0) notation here since it is easier to generalize to p > 2.
4. L ∈ DB(b,N) is an N × N matrix with b × b blocks that are all diagonal matrices.
5. R ∈ BD(b,N), meaning it is a block-diagonal N × N matrix with block size b × b.
6. We have a class of permutation matrices defined by σ_(b,N)(i) = i_0 · (N/b) + i_1. This can be denoted by an N × N matrix P(b,N) whose i-th row is e_{σ_(b,N)(i)}.
7. We will use i, or the pair notation (i_1, i_0)_b, to denote rows, and j, or the pair notation (j_1, j_0)_b, to denote columns. It should be clear from context which one we are using.

For any 0 ≤ j_1 < √N, let ℓ_{j_1}(X, Y) be an arbitrary bivariate polynomial with deg_X(ℓ_{j_1}), deg_Y(ℓ_{j_1}) < √N. For any 0 ≤ j_1, j_0 < √N, let r_{j_1,j_0}(Y) be an arbitrary univariate polynomial of degree < √N. Let A = (α_0, ..., α_{√N−1}) and B = (β_0, ..., β_{√N−1}) each be a sequence of √N distinct evaluation points. Note that A and B need not be disjoint.

From the proof of Theorem 3 in Appendix C of the Monarch paper [14] we get
L' = P(b,N) · L · P(b,N)^⊤ = P(b,N) · L · P(N/b,N), and equivalently L = P(b,N)^⊤ · L' · P(b,N) = P(N/b,N) · L' · P(b,N).

Define DB(√N,N) and BD(√N,N) as the sets of all such L and R matrices over R^{N×N}, where
L_{i_1,k_1}[i_0, k_0] := L[(i_1, i_0)_{√N}, (k_1, k_0)_{√N}] = 0 if i_0 ≠ k_0, (18)
and
R_{k_1,j_1}[k_0, j_0] := R[(k_1, k_0)_{√N}, (j_1, j_0)_{√N}] = 0 if k_1 ≠ j_1. (19)
Pictorially, L consists of √N × √N blocks that are each diagonal, and R is block diagonal with dense √N × √N blocks. In [14], Monarch matrices with block size b = √N are written M = L · R, and thus for all 0 ≤ i_1, i_0, j_1, j_0 < √N:
M[(i_1, i_0)_{√N}, (j_1, j_0)_{√N}] = L_{i_1,j_1}[i_0, i_0] · R_{j_1,j_1}[i_0, j_0]. (20)
We note that our definition of Monarch matrices in Section 3 is slightly different: there, M' = M · P, with M as defined in [14].

D.2 Monarch Matrices and Bivariate Polynomial Evaluation

Given polynomials ℓ_{j_1}(X, Y) for 0 ≤ j_1 < √N, polynomials r_{j_1,j_0}(Y) for 0 ≤ j_1, j_0 < √N, and evaluation points A = (α_0, ..., α_{√N−1}), B = (β_0, ..., β_{√N−1}) (as in Appendix D.1.2), define the matrices L ∈ DB(√N,N) and R ∈ BD(√N,N) as follows. For every 0 ≤ j_1, i_1, i_0 < √N,
L_{i_1,j_1}[i_0, i_0] = ℓ_{j_1}(α_{i_1}, β_{i_0}). (21)
For every 0 ≤ j_1, j_0, i_0 < √N,
R_{j_1,j_1}[i_0, j_0] = r_{j_1,j_0}(β_{i_0}). (22)
Note that all entries of L and R not specified above are set to 0.

Let f be the above function that maps the coefficients of the ℓ_{j_1}(X, Y) (which are coefficients of the monomials X^{i_1} Y^{i_0} for all 0 ≤ i_1, i_0 < √N, and hence are represented by a matrix in R^{√N×√N}) and the coefficients of the r_{j_1,j_0}(Y) (which are coefficients of the monomials Y^{i_0} for all 0 ≤ i_0 < √N, and hence are represented by a vector in R^{√N}), for all 0 ≤ j_1, j_0 < √N, to pairs of matrices in DB(√N,N) × BD(√N,N).

Theorem 6. Let f be as defined above. Then f is a bijection.

Proof. To prove that f is a bijection we must show that f is one-to-one and that f^{-1} exists and is one-to-one. That f is one-to-one means that each set of polynomial coefficients given to f produces a unique pair of matrices (L, R) ∈ DB(√N,N) × BD(√N,N). This follows from (21), (22), and the fact that polynomial evaluation is a function.

Now, to show that f^{-1} exists and is one-to-one, we must show that there is a map from any pair (L, R) ∈ DB(√N,N) × BD(√N,N) to unique sets of polynomials ℓ, r with parameters as defined in Appendix D.1.2. Further, we need
L_{i_1,j_1}[i_0, i_0] = ℓ_{j_1}(α_{i_1}, β_{i_0}) (23)
and
R_{j_1,j_1}[i_0, j_0] = r_{j_1,j_0}(β_{i_0}). (24)
We will use Theorems 5 and 4 to show the existence of the ℓ_{j_1} and r_{j_1,j_0} polynomials, giving us the mapping from the matrices to unique polynomials.

We first show the existence of the unique polynomials in (24). Fix 0 ≤ j_1, j_0 < √N. Then consider the values, for 0 ≤ i_0 < √N,
y_{i_0} = R_{j_1,j_1}[i_0, j_0]. (25)
By Theorem 4, there exists a unique polynomial of degree < √N (call it r_{j_1,j_0}(Y)) such that for all 0 ≤ i_0 < √N, r_{j_1,j_0}(β_{i_0}) = y_{i_0}, which by (25) shows (24).

Next we show the existence of the unique polynomials in (23). Fix 0 ≤ j_1 < √N. Consider the values, for 0 ≤ i_1, i_0 < √N,
y_{i_1,i_0} = L_{i_1,j_1}[i_0, i_0]. (26)
By Theorem 5, there exists a unique bivariate polynomial with deg_X < √N and deg_Y < √N (call it ℓ_{j_1}(X, Y)) such that for all 0 ≤ i_1, i_0 < √N,
ℓ_{j_1}(α_{i_1}, β_{i_0}) = y_{i_1,i_0}, (27)
which by (26) shows (23). Therefore f is a bijection.

We can now conclude:

Corollary 1. For every matrix M as defined in (20), there exist unique polynomials ℓ_{j_1}(X, Y) and r_{j_1,j_0}(Y) such that for all 0 ≤ i_1, i_0, j_1, j_0 < √N,
M[(i_1, i_0)_{√N}, (j_1, j_0)_{√N}] = ℓ_{j_1}(α_{i_1}, β_{i_0}) · r_{j_1,j_0}(β_{i_0}). (28)

Proof. Follows from (20), (21), (22), and Theorem 6.

D.2.1 Proof of Theorem 1

We begin with an immediate consequence of Corollary 1:

Corollary 2. Let A, B ⊂ C be such that |A| = |B| = √N. Then the j-th column of M, for j = (j_1, j_0)_{√N}, is the evaluation of the polynomial ℓ_{j_1}(X, Y) · r_{j_1,j_0}(Y) over A × B.

Proof. Observe that for fixed j_1, j_0 the right-hand side of (28) is ℓ_{j_1}(X, Y) · r_{j_1,j_0}(Y) evaluated at all (α, β) ∈ A × B. Thus, the (j_1, j_0) column is the evaluation of ℓ_{j_1}(X, Y) · r_{j_1,j_0}(Y) over the points in A × B, as desired.

Next, we state a generalization of Theorem 1 that follows from Corollary 2:

Corollary 3. Let A and B be as in Corollary 2. For any vector u, M u is u(X, Y) evaluated at A × B, where
u(X, Y) = Σ_{0 ≤ j_1, j_0 < √N} u_{j_1,j_0} · ℓ_{j_1}(X, Y) · r_{j_1,j_0}(Y), (29)
and ℓ and r are defined by M as in Corollary 1.

Proof. Follows from Corollary 2 and the definition of matrix-vector multiplication.
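As a quick numerical sanity check of Corollary 3 (our own illustration, not from the released code), the snippet below materializes the dense Monarch matrix for one particular basis choice, the monomials ℓ_{j_1}(X, Y) = X^{j_1} and r_{j_1,j_0}(Y) = Y^{j_0}, with A = B taken to be the fourth roots of unity, and checks that M u equals the evaluation of u(X, Y) on the grid A × B.

import torch

# Basis choice: l_{j1}(X, Y) = X^{j1}, r_{j1,j0}(Y) = Y^{j0}; A = B = 4th roots of unity.
# Then M[(i1, i0), (j1, j0)] = alpha_{i1}^{j1} * beta_{i0}^{j0}, so M @ u evaluates
# u(X, Y) = sum_{j1, j0} u_{j1, j0} X^{j1} Y^{j0} on the grid A x B (Corollary 3).
sqrt_n = 4
pts = torch.exp(2j * torch.pi * torch.arange(sqrt_n, dtype=torch.float32) / sqrt_n)

M = torch.zeros(sqrt_n ** 2, sqrt_n ** 2, dtype=torch.complex64)
for i1 in range(sqrt_n):
    for i0 in range(sqrt_n):
        for j1 in range(sqrt_n):
            for j0 in range(sqrt_n):
                M[i1 * sqrt_n + i0, j1 * sqrt_n + j0] = pts[i1] ** j1 * pts[i0] ** j0

u = torch.randn(sqrt_n, sqrt_n, dtype=torch.complex64)  # coefficients u_{j1, j0}

lhs = M @ u.flatten()  # Monarch matrix-vector product
rhs = torch.stack([
    sum(u[j1, j0] * pts[i1] ** j1 * pts[i0] ** j0
        for j1 in range(sqrt_n) for j0 in range(sqrt_n))
    for i1 in range(sqrt_n) for i0 in range(sqrt_n)
])  # direct evaluation of u(X, Y) over A x B
assert torch.allclose(lhs, rhs, atol=1e-4)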
In Theorem 1 and the following sections, we consider polynomials evaluated over the basis polynomials defined by M' = M · P.

Corollary 4. Let A and B be as in Corollary 2. For any vector u, M' u is u(X, Y) evaluated at A × B, where
u(X, Y) = Σ_{0 ≤ j_1, j_0 < √N} u_{j_0,j_1} · ℓ_{j_0}(X, Y) · r_{j_0,j_1}(Y), (30)
and ℓ and r are defined by M as in Corollary 1.

Proof. Follows from Corollary 3 and the definition of M'.

Specifically, Theorem 1 is the special case of Corollary 4 in which ℓ_{j_0}(X, Y) = ℓ_{m(j)}(X, Y) and r_j(Y) = r_{j_0,j_1}(Y).

D.2.2 Proof of Theorem 2

By Corollary 4, for all 0 ≤ j_1, j_0 < √N, let ℓ_{j_0}(X, Y), r_{j_0,j_1}(Y) and ℓ'_{j_0}(X, Y), r'_{j_0,j_1}(Y) denote the basis polynomials corresponding to M_1 and M_2, respectively. For the coefficient vector k = (k_{j_1,j_0})_{0 ≤ j_1,j_0 < √N}, and similarly for u = (u_{j_1,j_0})_{0 ≤ j_1,j_0 < √N}, we can construct two polynomials
k(X, Y) = Σ_{0 ≤ j_1, j_0 < √N} k_{j_1,j_0} · ℓ_{j_0}(X, Y) · r_{j_0,j_1}(Y),
u(X, Y) = Σ_{0 ≤ j_1, j_0 < √N} u_{j_1,j_0} · ℓ'_{j_0}(X, Y) · r'_{j_0,j_1}(Y),
whose evaluations over (α_{i_1}, β_{i_0}) = (ω^{i_1}, ω^{i_0}), where (recall Appendix D.1.1) ω = ω_{√N} = e^{2πι/√N}, are, by Theorem 1, equal to the products M_1 k and M_2 u, respectively. Taking the component-wise product y = (M_1 k) ⊙ (M_2 u), the entry at i = (i_1, i_0) is given by
y[(i_1, i_0)] = k(ω^{i_1}, ω^{i_0}) · u(ω^{i_1}, ω^{i_0}).
Noting that the elements of A, i.e., the √N-th roots of unity, satisfy Z^{√N} = 1 means that the above are evaluations of
h(X, Y) = k(X, Y) · u(X, Y) mod (X^{√N} − 1, Y^{√N} − 1)
at A × A. Finally, Theorem 1 and the fact that M_0^{-1} exists imply that M_0^{-1} y is polynomial interpolation into the basis polynomials corresponding to M_0. (Here we use the well-known fact that polynomial interpolation is the inverse of polynomial evaluation.)

D.3 Proof of Theorem 3

We review some concepts in Appendix D.3.1. In Appendix D.3.2, we discuss square matrices and causality in terms of operations on univariate polynomials. This allows us to define a general class of operators for causal 1D convolution. In Appendix D.3.3, we give a class of matrices suitable for performing causal Monarch convolution. Specifically, we prove Theorem 3.

D.3.1 Review

Consider the linear operation y = A u on an input vector u. We say that the map is causal if the entry y[i] depends only on u[0], u[1], ..., u[i]. This is the case when A is a lower triangular matrix (we index the top-left entry of A as (0, 0)). When A is a lower triangular Toeplitz matrix with entries corresponding to some coefficient vector k, this operation is exactly the 1D convolution
y = k ∗ u = F_{2n}^{-1}((F_{2n} k') ⊙ (F_{2n} u'))[0 : n − 1],
where k' = (k, 0^n), u' = (u, 0^n), and F_n is the n × n DFT matrix. A short PyTorch sketch of this identity is given at the end of this subsection.

Definition 4. For a matrix M ∈ R^{n×n}, we define the map
y = M^{-1}((M k) ⊙ (M u)) (31)
as matrix convolution. When M is a Monarch matrix, (31) is called Monarch convolution.

In this section, we are interested in determining large subclasses of matrices M such that, for any coefficient vector k, (31) is causal in u. We provide a class of matrices for which Monarch convolution is causal. We note that for a general Monarch matrix M, (31) is not causal in u: by Theorem 2, we have
y(X, Y) = k(X, Y) · u(X, Y) mod (X^{√N} − 1, Y^{√N} − 1).
This is not causal because the reduction mod (X^{√N} − 1, Y^{√N} − 1) folds higher-order terms into lower-order terms, so y[i] does not depend only on the inputs up to position i.
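The following is a minimal sketch (ours, for illustration) of the zero-padded FFT convolution recalled above; padding to length 2n avoids the wrap-around of circular convolution, which is exactly what makes the map causal. The function name causal_conv_fft is illustrative.

import torch

# y = k * u = F_{2n}^{-1}((F_{2n} k') ⊙ (F_{2n} u'))[0 : n-1], with k', u' the
# zero-padded inputs. With 2n-point transforms, the linear (causal) convolution
# is recovered exactly: y[i] depends only on u[0..i].
def causal_conv_fft(u: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    n = u.shape[-1]
    U = torch.fft.rfft(u, n=2 * n, dim=-1)  # implicit zero-padding to length 2n
    K = torch.fft.rfft(k, n=2 * n, dim=-1)
    return torch.fft.irfft(U * K, n=2 * n, dim=-1)[..., :n]

# Causality check: perturbing u at position 32 leaves y[:32] unchanged.
u, k = torch.randn(64), torch.randn(64)
y = causal_conv_fft(u, k)
u2 = u.clone(); u2[32] += 1.0
y2 = causal_conv_fft(u2, k)
assert torch.allclose(y[:32], y2[:32], atol=1e-4)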
D.3.2 Univariate Matrix Convolutions

We start with a couple of notational assumptions.

Assumption 1. N is a perfect square.

Assumption 2. We will not use pair notation in this subsection; throughout, i = i_1 · √N + i_0 and j = j_1 · √N + j_0.

In order to discuss square matrices in terms of univariate polynomials, we give univariate analogs of Theorem 1 and Theorem 2 for a general univariate basis. With an eye toward performing causal convolution, we restrict our analysis to certain classes of univariate polynomials. We first define matrices whose j-th column is the evaluation of a polynomial of minimum degree j (and maximum degree N − 1); recall Definition 1. We generalize Theorem 3 to such matrices.

Lemma 1. For the sequence of points A = {1, ω_N, ..., ω_N^{N−1}}, where ω_N is an N-th root of unity, let M be defined by
M[i, j] = q_j(ω_N^i), (32)
where q_j(Z) is defined as in Definition 1. Then for any vector v ∈ R^N, M v is equivalent to evaluating the polynomial
v(Z) = Σ_{j=0}^{N−1} v_j · q_j(Z) (33)
at {1, ω_N, ..., ω_N^{N−1}}.

Proof. By our definition of M, the column M[:, j] is exactly the evaluation of the polynomial q_j(Z) at each point in A. The claimed result follows from the definition of matrix-vector multiplication and (33).

Note that M, and the other matrices M in this subsection, are not necessarily Monarch matrices.

Next, we state the following intermediate result:

Proposition 1. Let A be the set of the N-th roots of unity. Then for M_1, M_2 defined as in (32),
y = (M_1 k) ⊙ (M_2 u)
is the same as evaluating the polynomial p(Z) := k(Z) · u(Z) mod (Z^N − 1) over A, where k(Z), u(Z) are of the form (33), corresponding to M_1 and M_2, respectively. In other words, for any 0 ≤ i < N, y[i] = p(ω_N^i).

Proof. This result follows from Lemma 1 and the definition of the Hadamard product.

Next, we state a reinterpretation of M^{-1} y:

Proposition 2. Let M be a full-rank matrix whose columns are the evaluations of the basis polynomials q_j(Z) from Definition 1 for 0 ≤ j < N, and let y ∈ R^N be an arbitrary vector. If u = M^{-1} y, then for all 0 ≤ i < N,
y[i] = u(ω_N^i),
where u(Z) is as in Lemma 1 for M. In other words, M^{-1} y solves the polynomial interpolation problem for the polynomial basis q_j(Z), 0 ≤ j < N.

Proof. This follows from Lemma 1 and the fact that M is invertible.

From Propositions 1 and 2, we get the following generalization of Theorem 2:

Theorem 7. For matrices M_0, M_1, M_2 as defined above, the operation
f = M_0^{-1}((M_1 k) ⊙ (M_2 u))
is equivalent to representing the polynomial f(Z) = k(Z) · u(Z) mod (Z^N − 1) in terms of the basis polynomials q̂_j(Z) for j = 0, ..., N − 1, where k(Z), u(Z) are defined as in Lemma 1 in terms of the basis polynomials corresponding to M_1 and M_2, respectively, and (q̂_j(Z))_{0 ≤ j < N} are the basis polynomials corresponding to M_0.

E PyTorch Code Listing of an M2 Layer

import torch
from torch import nn
from einops import rearrange

def blockdiag_matmul(x, w):
    return torch.einsum(
        "bnm,...bm->...bn", w, x.view(*x.shape[:-1], w.shape[0], w.shape[-1])
    ).reshape(*x.shape)

class MonarchMatrix(nn.Module):

    def __init__(self, sqrt_n: int):
        super().__init__()
        self.sqrt_n = sqrt_n
        self.L = nn.Parameter(torch.randn((sqrt_n, sqrt_n, sqrt_n)))
        self.R = nn.Parameter(torch.randn((sqrt_n, sqrt_n, sqrt_n)))

    def forward(self, x):
        x = rearrange(x, "... (m n) -> ... (n m)", n=self.sqrt_n)
        x = blockdiag_matmul(x, self.L)
        x = rearrange(x, "... (m n) -> ... (n m)", n=self.sqrt_n)
        x = blockdiag_matmul(x, self.R)
        return rearrange(x, "... (m n) -> ... (n m)", n=self.sqrt_n)
class MonarchMixerLayer(nn.Module):

    def __init__(self, sqrt_n: int, sqrt_d: int):
        super().__init__()
        self.m1 = MonarchMatrix(sqrt_n)
        self.m2 = MonarchMatrix(sqrt_n)
        self.m3 = MonarchMatrix(sqrt_d)
        self.m4 = MonarchMatrix(sqrt_d)

        self.n_kernel = nn.Parameter(torch.randn(sqrt_d ** 2, sqrt_n ** 2))
        self.d_kernel = nn.Parameter(torch.randn(1, sqrt_d ** 2))
        self.layer_norm = nn.LayerNorm(sqrt_d ** 2)

    def forward(self, x: torch.Tensor):  # x.shape = (b, n, d)
        x_tilde = self.m2(self.n_kernel * self.m1(x.transpose(-1, -2))).transpose(-1, -2)  # mix sequence
        y = self.m4(torch.relu(self.d_kernel * self.m3(x_tilde)))  # mix features
        return self.layer_norm(y + x_tilde)  # skip connection

Listing 1: A basic implementation of the M2 layer.
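As a usage sketch (ours), a forward pass through the layer in Listing 1 looks as follows; it assumes MonarchMixerLayer from the listing is in scope, and that both the sequence length and the model dimension are perfect squares, as the listing requires.

import torch

# Sequence length n = 64**2 = 4096 and model dimension d = 16**2 = 256 are
# chosen so that both are perfect squares.
layer = MonarchMixerLayer(sqrt_n=64, sqrt_d=16)
x = torch.randn(2, 64 ** 2, 16 ** 2)  # (batch, sequence, model dim)
y = layer(x)                          # same shape: (2, 4096, 256)
assert y.shape == x.shape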