# OmniNet: Omnidirectional Representations from Transformers

Yi Tay*¹, Mostafa Dehghani*², Vamsi Aribandi¹ ³, Jai Gupta¹, Philip Pham¹, Zhen Qin¹, Dara Bahri¹, Da-Cheng Juan¹, Donald Metzler¹

## Abstract

This paper proposes Omnidirectional Representations from Transformers (OMNINET). In OmniNet, instead of maintaining a strictly horizontal receptive field, each token is allowed to attend to all tokens in the entire network. This process can also be interpreted as a form of extreme or intensive attention mechanism that has the receptive field of the entire width and depth of the network. To this end, the omnidirectional attention is learned via a meta-learner, which is essentially another self-attention based model. In order to mitigate the computationally expensive costs of full receptive field attention, we leverage efficient self-attention models such as kernel-based attention (Choromanski et al., 2020), low-rank attention (Wang et al., 2020), and/or BigBird (Zaheer et al., 2020) as the meta-learner. Extensive experiments are conducted on autoregressive language modeling (LM1B, C4), Machine Translation, Long Range Arena (LRA), and Image Recognition. The experiments show that OmniNet achieves considerable improvements across these tasks, including achieving state-of-the-art performance on LM1B, WMT'14 En-De/En-Fr, and Long Range Arena. Moreover, using omnidirectional representations in Vision Transformers leads to significant improvements on image recognition tasks in both few-shot learning and fine-tuning setups.

*Equal contribution. ¹Google Research, Mountain View. ²Google Brain Team, Amsterdam. ³Google AI Resident. Correspondence to: Yi Tay, Mostafa Dehghani. Proceedings of the 38th International Conference on Machine Learning, PMLR 139, 2021. Copyright 2021 by the author(s).

## 1. Introduction

Transformers (Vaswani et al., 2017), characterized by stacked self-attention modules and feed-forward transformations, have become a staple in modern deep learning, natural language processing (Devlin et al., 2018; Raffel et al., 2019), and even computer vision (Dosovitskiy et al., 2020). One key defining characteristic of the self-attention mechanism is its global receptive field, in which each token is accessible to every other token in the sequence, serving as an enabler for learning global contextual representations.

This paper proposes learning omnidirectional representations from transformers. The key idea is to move beyond horizontally global receptive fields and explore the possibility of omnidirectional receptive fields. In short, we allow each token to attend not only to all other tokens in the same layer, but also to all tokens in all layers of the network. This global access enables tokens to have a full view of the network and, as a result, access the knowledge and intermediate representations of every token at each layer. By modeling the relationships amongst tokens of different hierarchical levels, we are also able to capture patterns pertaining to the propagation of representations across time. Finally, this approach can also be interpreted as a form of dense residual connection (Huang et al., 2017), which has been shown to be beneficial by aiding gradient flow.

Learning omnidirectional receptive fields is non-trivial for two key reasons.
Firstly, given the quadratic complexity of scaled dot-product attention, the complexity of designing such a receptive field increases from $N^2 L$ to $(NL)^2$, where $L$ is the depth of the network and $N$ is the sequence length. We postulate that this is one major challenge that has prohibited this type of architecture from being explored in the past. Secondly, simply enabling omnidirectional attention from the get-go would easily cause the transformer to degenerate into a flat network, losing much of the representational power that comes from sequentially refining its representations across network layers.

To mitigate these issues, our omnidirectional attention is implemented as a form of meta-learner that acts upon a standard transformer model. The meta-learner is yet another self-attention model that accepts all hidden representations across all layers as input and refines them based on all the available information. In order to mitigate the prohibitive memory and computational costs of omnidirectional attention, we explore and evaluate multiple efficient parameterizations of the meta-learner, including fast attention via generalizable kernel attention (Choromanski et al., 2020), low-rank self-attention (Wang et al., 2020), and/or block-based sparsity (Zaheer et al., 2020). Additionally, we hypothesize that employing methods that try to learn the low-rank factorized structure of the entire network can lead to improved generalization capabilities, as demonstrated in our few-shot learning experiments.

Aside from varying the parameterization of the meta-learner, we also introduce partitioned variants of OmniNet in which the meta-learner is applied to blocks of $p$ consecutive layers. In short, this partitioning strategy groups the full network of $L$ layers into $L/p$ partitions. After computing each partition, the meta-learner learns the omnidirectional attention over all nodes across all layers in the partition.

Via extensive experiments, we show that OmniNet achieves very promising results on a myriad of language, vision, and logic tasks. Specifically, we report strong experimental results on autoregressive language modeling (Chelba et al., 2013; Raffel et al., 2019), five collections of WMT machine translation, Long Range Arena (Tay et al., 2020a), and image recognition using Vision Transformers (Dosovitskiy et al., 2020). Moreover, we systematically evaluate OmniNets through the lens of the performance-compute trade-off and show that they are Pareto-optimal in this regard. On machine translation, OmniNet outperforms ADMIN (Liu et al., 2020), the current state-of-the-art 60-layer deep transformer model, on two well-established machine translation collections (WMT'14 English-German and WMT'14 English-French). On the One Billion Word language modeling benchmark, OmniNet outperforms existing state-of-the-art models such as Transformer-XL (Dai et al., 2019). On LRA, OmniNet improves aggregate performance over Performers (Choromanski et al., 2020) by +8.9% and vanilla Transformers by +2.6%. On image recognition tasks, OmniNet demonstrates stellar few-shot learning and finetuning performance, outperforming ViT (Dosovitskiy et al., 2020) by up to +3% in both finetuning and few-shot learning experiments.
## 2. Related Work

Over the past several years, attention mechanisms (Bahdanau et al., 2014) have made a significant impact on machine learning research (Vaswani et al., 2017; Devlin et al., 2018; Dosovitskiy et al., 2020; Raffel et al., 2019; Brown et al., 2020; Dehghani et al., 2018). Simply speaking, these parameterized pooling mechanisms learn to align representations and route information based on the notion of relative importance. While early work in this area was mainly concerned with learning an alignment function between two or more sequences (Bahdanau et al., 2014; Parikh et al., 2016), the recent research climate has focused more on self-alignment, i.e., self-attention (Vaswani et al., 2017). Attention mechanisms are generally applied layer-wise and operate across a one-dimensional sequence. Attention is generally bidirectional, or unidirectional in the case where a token is to be denied access to future tokens.

There have been early attempts to mix information across layers in pursuit of improving gradient flow and model trainability. For example, Bapna et al. (2018) proposed transparent attention, in which the decoder gains access to all encoder layers. He et al. (2018) proposed layer-wise coordination between encoder and decoder for machine translation. Tay et al. (2018) proposed to densely connect the attention across stacked RNN encoder layers for language understanding tasks. The recent Realformer (residual attention) (He et al., 2020) proposed connecting the attention activations in a residual fashion. We believe there is sufficient evidence in the literature to suggest that mixing representations across layers is beneficial. This is further supported by fundamental work in this area such as ResNets (He et al., 2016), highway networks (Srivastava et al., 2015), and DenseNets (Huang et al., 2017).

In this paper, we are mainly interested in methods for efficiently learning omnidirectional attention, i.e., attention over the entire width and depth of the network. To this end, we leverage recent advances in making transformers fast and efficient (Zaheer et al., 2020; Choromanski et al., 2020; Wang et al., 2020). Many of these approaches learn an approximation via low-rank projection, kernels, or block-based sparsity. An overview and extensive empirical comparison can be found in (Tay et al., 2020b;a). The proposed approach leverages these recent advances to make what was previously not possible: by leveraging fast and efficient self-attention, we enable scalable and powerful omnidirectional attention.

## 3. The Proposed Method

This section introduces OmniNet. We first begin by reviewing the standard Transformer architecture.

### 3.1. Transformer Architectures

This section provides a brief background for the Transformer architecture. The Transformer block accepts an $N \times d$ input, where $N$ denotes the number of tokens in the sequence and $d$ denotes the size of the representation. Each Transformer block is characterized by a self-attention block and a two-layer feed-forward network with ReLU activations in between, applied position-wise.

#### 3.1.1. Self-Attention

The self-attention mechanism first projects each input $X$ into $Q, K, V$ representations using linear transformations, corresponding to queries, keys, and values. The self-attention mechanism is typically multi-headed, where multiple such linear projections are executed in parallel.
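Before the formal equations that follow, the snippet below gives a minimal JAX sketch of a single multi-headed self-attention block as just described. The function and parameter names (`multi_head_self_attention`, `wq`, `wk`, `wv`, `wo`, `num_heads`) are illustrative stand-ins and not taken from the paper's implementation.

```python
# Minimal multi-head scaled dot-product self-attention in JAX.
# This is an illustrative sketch, not the authors' Flax implementation.
import jax
import jax.numpy as jnp


def multi_head_self_attention(x, wq, wk, wv, wo, num_heads):
    """x: [N, d_model]; wq/wk/wv/wo: [d_model, d_model] projection matrices."""
    n, d_model = x.shape
    d_head = d_model // num_heads

    # Linear projections to queries, keys, and values, split into heads.
    def split_heads(t):
        return t.reshape(n, num_heads, d_head).transpose(1, 0, 2)  # [H, N, d_head]

    q, k, v = split_heads(x @ wq), split_heads(x @ wk), split_heads(x @ wv)

    # Scaled dot-product attention per head.
    scores = q @ k.transpose(0, 2, 1) / jnp.sqrt(d_head)   # [H, N, N]
    weights = jax.nn.softmax(scores, axis=-1)
    heads = weights @ v                                     # [H, N, d_head]

    # Concatenate heads and project back to d_model.
    concat = heads.transpose(1, 0, 2).reshape(n, d_model)
    return concat @ wo


# Toy usage with random weights: 8 tokens, d_model = 64, 4 heads.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (8, 64))
wq, wk, wv, wo = (jax.random.normal(k, (64, 64)) * 0.02
                  for k in jax.random.split(key, 4))
y = multi_head_self_attention(x, wq, wk, wv, wo, num_heads=4)
print(y.shape)  # (8, 64)
```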
The output of each self-attention head $h$ at layer $l$ is written as:

$$y_{h,l} = \text{softmax}\left(\frac{Q_{h,l} K_{h,l}^{\top}}{\sqrt{d_k}}\right) V_{h,l},$$

where $y_{h,l}$ is the output of head $h$ at layer $l$ and $d_k$ is the size of each head. The outputs from the multiple heads are then concatenated and passed through another linear transformation $W_{o,l}$, which projects the concatenation of all heads down to $d_m$. This is wrapped in a layer normalization followed by a residual connection and can be written as:

$$\text{LayerNorm}\left(W_{o,l}\,\text{concat}([y_{1,l}, \ldots, y_{H,l}])\right) + x_{l-1},$$

which is the final output of the self-attention module.

**Feed-Forward Layers** The FFN block of the Transformer block performs a two-layer transformation defined as:

$$z_l = \text{LayerNorm}\left(W_{1,l}\,\text{ReLU}(W_{2,l}(Y))\right) + z_{l-1}, \quad (2)$$

where $W_1, W_2$ are trainable parameters (weight transforms) of the FFN layer. Bias parameters are omitted for clarity.

### 3.2. OmniNet

The proposed OmniNet method operates on an arbitrary multi-layered architecture that accepts sequential inputs. In our description, this typically refers to a stacked X-former architecture. Note that while this is typically a Transformer model, it can also be an arbitrary variant (Choromanski et al., 2020; Wang et al., 2020). Figure 1 illustrates a brief overview of the proposed OmniNet architecture.

*Figure 1. Overview of OmniNet. In the diagram, the omnidirectional module, when the partition size is $P = L$, combines the information across all positions ($1{:}N$) and across all layers ($1{:}L-1$), and for each position selects the best of all layers via a pooling operation to generate the final representations.*

#### 3.2.1. Omnidirectional Representations

In a stacked network of $L$ layers, each layer exposes a sequence of $N$ vectors of $d$ dimensions each. Specifically, OmniNet operates across all layers and connects the multi-layered network architecture in a grid-like fashion. We describe the network as xformer, which accepts $X$ as input and returns a tensor of $L \times N \times d$ dimensions:

$$\text{xformer}(X) = X_1, X_2, \ldots, X_L, \quad (3)$$

where $X_i \in \mathbb{R}^{N \times d}$. Let $X_{i,j}$ be the representation of $X$ at layer $i$ and position $j$ of the sequence. The OmniNet mechanism can be written as:

$$O = \text{Attend}(\text{IndexSort}(X_1, X_2, \ldots, X_L)), \quad (4)$$

where Attend denotes an arbitrary self-attention block. The IndexSort operation takes $X_1, X_2, \ldots, X_L$ and sorts¹ tokens within each matrix by index such that the tokens adjacent to the $i$-th token of layer $l$ are the $i$-th tokens of layers $l-1$ and $l+1$, respectively. Next, given that the input sequence length is $LN$, it is advantageous for Attend to be as efficient as possible. We describe three variants of OmniNet's core linear-time self-attention mechanism in subsequent sections. Given $O \in \mathbb{R}^{(L \cdot N) \times d}$, the output of the omnidirectional attention, we apply a pooling operator $P(\cdot)$. While there are many choices of pooling operators, parameterized or otherwise, we adopt a simple pooling function: a max pooling of stride $L$,

$$O' = \text{MaxPool1D}(O), \quad (5)$$

where $O' \in \mathbb{R}^{N \times d}$. Given $O'$, the final representation of an OmniNet-augmented network is defined as:

$$\text{OmniNet}(X) = \text{xformer}(X)_L + O'. \quad (6)$$

The OmniNet and main transformer model are trained together in an end-to-end fashion, i.e., gradients flow to both networks concurrently at each backward pass.
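The following is a minimal JAX sketch of the omnidirectional layer described by Equations (3)-(6): index-sort the per-layer hidden states, attend over the resulting $L \cdot N$ tokens, max-pool with stride $L$, and add the result to the last backbone layer. The `attend` function here is a plain quadratic softmax attention used purely as a placeholder for one of the efficient variants discussed next, and all names are illustrative rather than the authors' code.

```python
# Sketch of the omnidirectional layer (Eqs. 3-6): index-sort hidden states
# from all L layers, attend over the L*N tokens, then max-pool with stride L.
import jax
import jax.numpy as jnp


def attend(x):
    # Placeholder attention over the [L*N, d] sequence (quadratic; in OmniNet
    # this would be a kernel-, low-rank-, or block-sparsity-based variant).
    scores = x @ x.T / jnp.sqrt(x.shape[-1])
    return jax.nn.softmax(scores, axis=-1) @ x


def index_sort(layer_outputs):
    # layer_outputs: list of L arrays of shape [N, d].
    # Stack to [N, L, d] so tokens with the same sequence index are adjacent,
    # then flatten to the [L*N, d] input of the omnidirectional attention.
    stacked = jnp.stack(layer_outputs, axis=1)          # [N, L, d]
    n, l, d = stacked.shape
    return stacked.reshape(n * l, d)


def omnidirectional_layer(layer_outputs):
    o = attend(index_sort(layer_outputs))               # [L*N, d]
    l = len(layer_outputs)
    n, d = layer_outputs[0].shape
    # Max-pool with stride L: for each position, keep the element-wise max
    # over all layers (Eq. 5).
    o_pooled = o.reshape(n, l, d).max(axis=1)           # [N, d]
    # Residual combination with the last layer of the backbone (Eq. 6).
    return layer_outputs[-1] + o_pooled


# Toy usage: a 4-layer stack of hidden states for an 8-token sequence.
key = jax.random.PRNGKey(0)
hidden = [jax.random.normal(k, (8, 32)) for k in jax.random.split(key, 4)]
print(omnidirectional_layer(hidden).shape)  # (8, 32)
```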
#### 3.2.2. Maintaining Causality and Autoregressive Decoding

A key point to note about IndexSort is that this ordering enables us to apply a causal mask to the Attend function: if tokens are sorted by sequence index first, as opposed to layer first, then it is easy to apply a causal mask $M$, where $M[i,j] = 0$ when $i \leq j$ and $-\infty$ when $i > j$. This enables OmniNet to be used in autoregressive settings.

¹ Since attention is permutation invariant, this sorting simply makes it easier to (1) compute causal masks and (2) aggregate representations index-wise.

#### 3.2.3. Efficient Transformers

We describe several choices of linear-time self-attention mechanisms that are used in OmniNet's omnidirectional attention. Generally, Attend refers to an attention block with an attention function and a two-layer position-wise FFN, with a structure similar to the transformer backbone. For the sake of brevity, we only describe the core attention mechanism here. Our choice of efficient transformer is informed by (Tay et al., 2020a), selecting models that perform well on the compute-performance trade-off. For a list of potential variants to adopt, we refer readers to (Tay et al., 2020b).

**Kernel-based** This variant uses generalizable kernel attention, the fast attention mechanism proposed in Performer (Choromanski et al., 2020). Specifically, this is written as:

$$o = W_o\,\text{concat}\left(\hat{D}_h^{-1}\left(\phi(Q_h)\,(\phi(K_h))^{\top} V_h\right)\right), \quad \text{where } \hat{D}_h = \text{diag}\left(\phi(Q_h)\left((\phi(K_h))^{\top} 1_L\right)\right)$$

and $\phi(\cdot)$ is a random feature map that projects $\mathbb{R}^d$ to $\mathbb{R}^r$. We refer readers to (Choromanski et al., 2020) for more details.

**Low-rank** Inspired by Linformer's (Wang et al., 2020) self-attention mechanism, we set Attend to be:

$$o = W_o\,\text{concat}\left(\text{softmax}\left(\frac{Q_h (W^{\top} K_h)^{\top}}{\sqrt{d_k}}\right)(W^{\top} V_h)\right),$$

where $W \in \mathbb{R}^{N \times k}$ is a low-rank projection transformation that is shared across heads and across keys and values. The complexity of this self-attention mechanism is $Nk$ instead of $N^2$, where $k \ll N$.

**Block- and Memory-based** Lastly, we also explore a block- and memory-based variant of efficient Transformers based on BigBird (Zaheer et al., 2020). In short, this is a combination of windowed attention, global attention, and sparse attention. The output for token $i$ is defined as:

$$o_i = \sum_{h=1}^{H} \text{softmax}\left(Q_{h,i}\, K_{h,N(i)}^{\top}\right) V_{h,N(i)},$$

where $N(i)$ is the neighborhood function that denotes the out-neighbors of node $i$, $H$ is the total number of heads, and $h$ indexes a head. The neighborhood function is mainly dependent on the width of the windowed attention. We refer the reader to (Zaheer et al., 2020) for more details.
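As a concrete illustration of the kernel-based choice of Attend, the sketch below implements a Performer-style linear attention. Note that the feature map `phi` here is a simple ReLU map chosen for brevity; the actual FAVOR+ mechanism of Choromanski et al. (2020) uses random (orthogonal) features, so this is only a schematic approximation.

```python
# Sketch of kernel-based linear attention in the spirit of Performer.
# phi() below is a simple non-negative feature map used for illustration only.
import jax
import jax.numpy as jnp


def phi(x):
    # Non-negative feature map, R^d -> R^r (here r = d for simplicity).
    return jax.nn.relu(x) + 1e-6


def kernel_attention(q, k, v):
    """q, k: [N, d_k], v: [N, d_v]; runs in O(N * d_k * d_v) rather than O(N^2)."""
    q_prime, k_prime = phi(q), phi(k)              # [N, r]
    kv = k_prime.T @ v                             # [r, d_v], computed once
    normalizer = q_prime @ k_prime.sum(axis=0)     # [N], the D-hat term
    return (q_prime @ kv) / normalizer[:, None]


# Toy usage: 16 tokens with 32-dimensional queries, keys, and values.
key = jax.random.PRNGKey(0)
q, k, v = (jax.random.normal(s, (16, 32)) for s in jax.random.split(key, 3))
print(kernel_attention(q, k, v).shape)  # (16, 32)
```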
#### 3.2.4. Partitioned OmniNet

This section describes the partitioning variants that we explore in OmniNet. When $L$ is large, the eventual representation input to OmniNet can be extremely large.²

Let $P$ be an integer-valued hyperparameter that determines the partition size. For an $L$-layer transformer network, whenever $\ell \bmod P = 0$, we insert a meta-learner block:

$$X_\ell = \begin{cases} \text{Attend}(X_{\ell-P}, \ldots, X_{\ell-1}), & \text{if } \ell \bmod P = 0 \\ \text{xformer}(X_{\ell-1}), & \text{otherwise.} \end{cases}$$

In short, whenever $\ell \bmod P = 0$, we activate an omnidirectional attention layer, aggregating representations all the way from the previous partition boundary at layer $\ell - P$ up to layer $\ell - 1$. In this case, we skip the original xformer layer, hence maintaining approximately the same parameter size of the network.

² A sequence length of 1K would result in an 11K input sequence length for a 12-layer Transformer model when using an omnidirectional layer as the final layer.
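A minimal sketch of this partitioning schedule is given below. `xformer_layer` and `omni_layer` are hypothetical placeholders for a backbone Transformer layer and the omnidirectional (Attend + pooling) layer of Section 3.2.1; the point is only to show which layers are replaced when $\ell \bmod P = 0$.

```python
# Sketch of the partitioned schedule: every P-th layer is replaced by an
# omnidirectional layer over the previous partition's outputs.
import jax
import jax.numpy as jnp


def xformer_layer(x, key):
    # Stand-in for a backbone Transformer layer (here: a random residual map).
    w = jax.random.normal(key, (x.shape[-1], x.shape[-1])) * 0.02
    return x + x @ w


def omni_layer(partition_outputs):
    # Stand-in for Attend + max-pool over the partition (see Eqs. 4-5);
    # an element-wise max over layers keeps the sketch short.
    return jnp.stack(partition_outputs, axis=0).max(axis=0)


def partitioned_omninet(x, num_layers, partition_size, key):
    keys = jax.random.split(key, num_layers)
    partition = [x]                     # X_{l-P}, ..., X_{l-1} for the current partition
    for layer_idx in range(1, num_layers + 1):
        if layer_idx % partition_size == 0:
            # Replace this layer with an omnidirectional layer over the
            # partition X_{l-P} .. X_{l-1}; the original xformer layer is skipped.
            x = omni_layer(partition)
            partition = [x]             # start the next partition from the omni output
        else:
            x = xformer_layer(x, keys[layer_idx - 1])
            partition.append(x)
    return x


# Toy usage: a 12-layer network with partition size 4 (omni layers at 4, 8, 12).
x = jnp.ones((8, 32))
out = partitioned_omninet(x, num_layers=12, partition_size=4, key=jax.random.PRNGKey(0))
print(out.shape)  # (8, 32)
```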
## 4. Experiments

We conduct experiments on autoregressive language modeling, machine translation, long-range sequence modeling, and a series of image recognition tasks. Our implementation uses Flax (Heek et al., 2020) and JAX (Bradbury et al., 2018).

### 4.1. Autoregressive Language Modeling

We run experiments on large-scale autoregressive (unidirectional) language modeling. We use two large-scale datasets: the One Billion Word benchmark (LM1B) (Chelba et al., 2013) and the Colossal Clean Crawled Corpus (C4) (Raffel et al., 2019).

#### 4.1.1. Experimental Setup

For both tasks, we use a maximum length of 256 subword tokens per example and report subword perplexity on the validation set. In the first ablative experiment, we train all models for 30K steps on LM1B and 100K steps on C4 using 16 TPU-v3 chips. Models are of base size and have an embedding dimension of 512, 8 heads, 6 layers, and a hidden (MLP) dimension of 2048. For strong baselines, we compare with Transformers, Performers (Choromanski et al., 2020), and BigBird (Zaheer et al., 2020). We also add the recent Realformer (residual attention Transformer) (He et al., 2020) as a strong baseline. For OmniNet, we tune the partition size amongst {2, 3, 6}. All models have approximately 50M parameters. Next, we are interested in (1) how OmniNet scales to large sizes and (2) comparing with other published works (Dai et al., 2019). Hence, we implement a larger OmniNet with an MLP size of 8K and a head size of 2K.

#### 4.1.2. Results on Language Modeling

Table 1 reports results on language modeling. We observe that OmniNet_P and OmniNet_T outperform all baselines, by about +9.1% on LM1B and +4.2% on C4. We also outperform strong baselines such as Realformer, BigBird, and vanilla Transformers on both corpora. We also observe that OmniNet_P performs reasonably close to OmniNet_T, which suggests that using an efficient Transformer may be sufficient for omnidirectional attention.

*Table 1. Experimental results (quality, i.e., perplexity scores at 30K and 100K steps respectively) on autoregressive language modeling. All models are approximately 50M parameters.*

| Model | LM1B | C4 |
| --- | --- | --- |
| Transformer | 33.14 | 34.86 |
| Realformer | 32.95 | 35.63 |
| Performer | 34.33 | 35.68 |
| BigBird | 32.90 | 38.36 |
| OmniNet_B | 33.69 (-1.7%) | 34.73 (+0.4%) |
| OmniNet_P | 30.19 (+9.0%) | 33.97 (+2.6%) |
| OmniNet_T | 30.12 (+9.1%) | 33.39 (+4.2%) |

On the other hand, Table 2 reports a comparison with other published works on LM1B. Notably, OmniNet_P and OmniNet_T (large) outperform Transformer-XL and achieve state-of-the-art performance.

*Table 2. Comparison with existing state-of-the-art and published works on the One Billion Word language modeling benchmark (Chelba et al., 2013).*

| Model | #Params | PPL |
| --- | --- | --- |
| Adaptive Input (Baevski & Auli) | 0.5B | 24.1 |
| Adaptive Input (Baevski & Auli) | 1.0B | 23.7 |
| Transformer-XL (Dai et al.) | 0.5B | 23.5 |
| Transformer-XL (Dai et al.) | 0.8B | 21.8 |
| OmniNet_P (Large) | 0.1B | 21.6 |
| OmniNet_B (Large) | 0.1B | 22.0 |
| OmniNet_T (Large) | 0.1B | 21.5 |

### 4.2. Neural Machine Translation

We conduct experiments on machine translation, a sequence-to-sequence task for evaluating Transformer models.

#### 4.2.1. Experimental Setup

We use five collections/datasets from WMT'17,³ namely En-De (English-German), En-Cs (English-Czech), En-Fi (English-Finnish), En-Fr (English-French), and En-Ru (English-Russian). We train all models using 16 TPU-v3 chips with a batch size of 256. Our base Transformer model has 6 layers, a hidden size of 4096, an embedding size of 1024, and a head size of 1024. The number of heads is 16. We use a SentencePiece (Kudo & Richardson, 2018) vocabulary of 32K built specifically for each language. More details can be found in the appendix.

³ http://www.statmt.org/wmt17/translation-task.html

#### 4.2.2. Results on WMT'17 Machine Translation

Table 3 reports results on all five collections of WMT'17. Overall, OmniNet_P outperforms the vanilla Transformer model on all five collections, with up to a +3.6% performance improvement. Similar to language modeling, we find that the Performer variant works the best and the BigBird variant works the worst.

*Table 3. Results on five collections from the WMT'17 machine translation task.*

| Model | En-De | En-Fi | Cs-En | En-Fr | Ru-En |
| --- | --- | --- | --- | --- | --- |
| Transformer | 28.6 | 20.5 | 22.2 | 43.0 | 35.8 |
| OmniNet_L | 28.8 (+0.7%) | 20.8 (+1.5%) | 22.8 (+2.7%) | 43.3 (+0.7%) | 36.2 (+1.1%) |
| OmniNet_B | 28.8 (+0.7%) | 20.9 (+2.0%) | 22.6 (+1.8%) | 43.2 (+0.5%) | 34.2 (-4.5%) |
| OmniNet_P | 29.0 (+1.4%) | 20.9 (+2.0%) | 23.0 (+3.6%) | 43.1 (+0.2%) | 36.2 (+1.1%) |

#### 4.2.3. Comparisons Against State-of-the-Art

We train a large OmniNet model and compare it with state-of-the-art approaches. We compare with ADMIN (Liu et al., 2020), a very deep (60-layer) Transformer model that achieves state-of-the-art performance on the WMT'14 En-De dataset. We use an 8-layer OmniNet model with an MLP dimension of 4096 and hidden and embedding dimensions of 1024. We compare models using sacrebleu (Post, 2018). For OmniNet, given the strong performance of the Performer variant on the WMT'17 collections, we only train a single OmniNet_P variant for comparison with SOTA models. Further details of the setup can be found in the appendix.

Table 4 reports results on WMT'14 En-De and En-Fr. Our results show that OmniNet outperforms the existing state-of-the-art ADMIN model (Liu et al., 2020), a 60-layer deep transformer model.

*Table 4. Comparisons with the state-of-the-art on WMT'14 En-De and WMT'14 En-Fr. OmniNet outperforms ADMIN (Liu et al., 2020), the current state-of-the-art deep transformer model for MT.*

| Model | En-De | En-Fr |
| --- | --- | --- |
| Evolved Trans. (So et al., 2019) | 29.2 | n/a |
| Large Trans. (Ott et al., 2018) | 28.6 | 41.4 |
| 60L Trans. (Liu et al., 2020) | 29.5 | 41.8 |
| OmniNet_P | 29.8 | 42.6 |

### 4.3. Long Range Arena

We conduct experiments on the recently proposed Long Range Arena benchmark (Tay et al., 2020a). The goal of this experiment is to show that OmniNet improves long-range sequence modeling. A dual goal is to show that it is possible to combine different inductive biases to obtain a better efficient Transformer model that is versatile across different types of data.

#### 4.3.1. Experimental Setup

We run two key experiments using Transformer and Performer as the main backbone model and vary the meta-learner in OmniNet, using Linformer and Performer variants of the OmniNet meta-learner. The goal is to demonstrate that OmniNet yields backbone-agnostic improvements. We run OmniNet experiments using the LRA codebase, with the same hyperparameters as the results reported in (Tay et al., 2020a).
Note that LRA comprises five benchmarks; however, we omit the Image and Pathfinder experiments since the best hyperparameters on these tasks turn out to be shallow (1-2 layer) models. We report the average of the Text, Retrieval, and ListOps tasks.

#### 4.3.2. Results on LRA

Table 5 reports the results of our LRA experiments. Firstly, we observe that OmniNet makes substantial improvements to the base model, regardless of whether it is a Transformer or a Performer. Notably, with OmniNet_L, the Linformer meta-learner, the Performer model is improved by almost 6 to 7 absolute percentage points. An interesting observation can be made on the ListOps task, where OmniNet_P (the Performer variant) does not result in much improvement over the base Performer; however, the performance doubles with OmniNet_L. Since the base Linformer model does quite well on this task, we postulate that this is due to OmniNet_L providing a Linformer-like inductive bias to the Performer model. Finally, we observe that OmniNet improves the vanilla Transformer in both cases (P or L), improving the average score by about +2.6 absolute percentage points.

*Table 5. Results on Long Range Arena (Tay et al., 2020a).*

| Model | Text | Retrieval | ListOps | Avg |
| --- | --- | --- | --- | --- |
| Linformer | 53.9 | 52.3 | 35.7 | 47.3 |
| BigBird | 64.0 | 54.7 | 36.1 | 51.6 |
| Performer | 65.4 | 53.8 | 18.0 | 45.7 |
| +OmniNet_P | 65.6 | 60.9 | 18.2 | 48.2 |
| +OmniNet_L | 63.1 | 63.7 | 37.1 | 54.6 |
| Transformer | 62.5 | 57.5 | 36.4 | 52.1 |
| +OmniNet_P | 65.1 | 58.8 | 37.2 | 53.7 |
| +OmniNet_L | 63.1 | 63.8 | 37.2 | 54.7 |

### 4.4. Image Recognition

Transformer-based models have started showing competitive performance on different vision tasks such as classification, object detection, and segmentation (Chen et al., 2020; Dosovitskiy et al., 2020; Carion et al., 2020; Kumar et al., 2021). To showcase the power of omnidirectional representations on yet another task, we incorporate the omnidirectional representation into the Vision Transformer (ViT) (Dosovitskiy et al., 2020), pre-trained on a large amount of data in a supervised fashion and evaluated on downstream image recognition tasks, either through few-shot learning or fine-tuning.

#### 4.4.1. Vision Transformer

Vision Transformers (ViT) (Dosovitskiy et al., 2020) have recently shown impressive results on image classification compared to state-of-the-art convolutional networks, while requiring significantly fewer computational resources to train. ViT is a standard Transformer applied directly to images. To do so, the input images are first split into non-overlapping patches and embedded using a linear projection. The patch embeddings are provided as a sequence of tokens to a Transformer. When pre-trained on large datasets (14M-300M images) at a sufficient scale, ViT shows excellent results that are transferable to tasks with fewer data points.

#### 4.4.2. Experimental Setup

Similar to the ViT setup, we pre-train our OmniNet models on the JFT dataset (Sun et al., 2017), with 18k classes and 303M images, for 7 epochs. We evaluate our models in the transfer setup (few-shot and fine-tuning) on several downstream tasks: ImageNet, CIFAR-10, CIFAR-100 (Krizhevsky et al., 2009), Oxford-IIIT Pets (Parkhi et al., 2012), and Oxford Flowers-102 (Nilsback & Zisserman, 2008). We follow the pre-processing from (Kolesnikov et al., 2019) on both upstream and downstream datasets, which is used in the original ViT experiments. In our experiments, we train and evaluate OmniNet-B/32 and OmniNet-B/16, which are based on ViT-B/32 and ViT-B/16.⁴ Similar to ViT-B/32 and ViT-B/16, OmniNet-B/32 and OmniNet-B/16 are both base models, adopted from BERT, and use patch sizes of 32×32 and 16×16, respectively. In our OmniNet models, we replace the final layer of ViT with an omnidirectional layer; in other words, we set the partition size P = 12.
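The sketch below illustrates, under simplifying assumptions, how images enter the model and where the omnidirectional layer sits when the partition size equals the network depth (P = 12): patches are embedded linearly, the first eleven blocks are standard layers, and the final block is an omnidirectional layer over all previous layers' outputs. Patch size, depth, and all function names are illustrative placeholders, not the actual OmniNet-B configuration code.

```python
# Sketch of the ViT-style input pipeline with an omnidirectional final layer.
import jax
import jax.numpy as jnp


def patchify(image, patch_size):
    """image: [H, W, C] -> sequence of flattened non-overlapping patches [N, P*P*C]."""
    h, w, c = image.shape
    patches = image.reshape(h // patch_size, patch_size, w // patch_size, patch_size, c)
    patches = patches.transpose(0, 2, 1, 3, 4)
    return patches.reshape(-1, patch_size * patch_size * c)


def vit_with_omni_final_layer(image, w_embed, layer_fns, omni_fn, patch_size=32):
    x = patchify(image, patch_size) @ w_embed       # linear patch embedding, [N, d]
    layer_outputs = []
    for layer_fn in layer_fns:                      # the first L-1 standard blocks
        x = layer_fn(x)
        layer_outputs.append(x)
    # Final block: omnidirectional layer over all previous layers' tokens.
    return omni_fn(layer_outputs)


# Toy usage with placeholder layers and a max-pool stand-in for Eqs. 4-6.
key = jax.random.PRNGKey(0)
image = jax.random.normal(key, (224, 224, 3))
w_embed = jax.random.normal(key, (32 * 32 * 3, 64)) * 0.01
layer_fns = [lambda x: x + 0.01 * jnp.tanh(x) for _ in range(11)]
omni_fn = lambda outs: jnp.stack(outs, axis=0).max(axis=0)
print(vit_with_omni_final_layer(image, w_embed, layer_fns, omni_fn).shape)  # (49, 64)
```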
In this task, we limit our experiments to using Performers (Choromanski et al., 2020) in the omnidirectional attention, given their strong results among the efficient transformer variants. During pre-training, we use a batch size of 4096 with Adam (β₁ = 0.9, β₂ = 0.999) and a weight decay of 0.05 for OmniNet. We use a learning rate of 8e-4 with a linear decay and a linear warmup of 10K steps.

⁴ Note that SOTA results on the downstream tasks we use here are from ViT-H/14 (Dosovitskiy et al., 2020), which has more than seven times as many parameters as the base models we use as baselines. Here, we aim merely at showcasing the gain of using omnidirectional representations in the image recognition task.

#### 4.4.3. Results on Image Recognition

We first present the results of OmniNet and the corresponding ViT models as baselines in the fine-tuning setup. For fine-tuning, we use SGD with momentum and a batch size of 512 on all downstream tasks. Table 6 presents the results of the fine-tuning experiments. We also report the total pre-training compute, in terms of the number of FLOPs, for each model. As we can see, introducing a module that learns omnidirectional representations to Vision Transformers leads to improvements on different downstream tasks. Given these improvements and comparing the number of FLOPs for OmniNets and ViTs, we can see that the additional compute, thanks to efficient attention techniques, is fairly reasonable.

*Table 6. Transfer performance of pre-trained OmniNet and equivalent ViT models in the fine-tuning setup on popular image classification benchmarks. All models are pre-trained on the JFT-300M dataset and fine-tuned on the target dataset.*

| | ViT-B/32 | OmniNet-B/32 | ViT-B/16 | OmniNet-B/16 |
| --- | --- | --- | --- | --- |
| ImageNet | 0.8073 | 0.8374 | 0.8415 | 0.8626 |
| CIFAR-10 | 0.9861 | 0.9900 | 0.9900 | 0.9940 |
| CIFAR-100 | 0.9049 | 0.9153 | 0.9186 | 0.9224 |
| Oxford-IIIT Pets | 0.9340 | 0.9441 | 0.9580 | 0.9674 |
| Oxford Flowers-102 | 0.9927 | 0.9954 | 0.9956 | 0.9961 |
| exaFLOPs | 165 | 193 | 743 | 891 |

For evaluating OmniNet in the few-shot learning setup, similar to ViT, we train a linear classifier on top of the representations from the frozen pre-trained model, given only a subset of training examples. The plots in Figure 2 illustrate the accuracy of OmniNet and ViT using different numbers of shots. The results indicate that adding omnidirectional representations to ViT leads to better transfer across all downstream datasets.

*Figure 2. Performance of pre-trained OmniNet and equivalent ViT models in the few-shot learning setup on downstream tasks, when transferred using only a few images (1, 5, 10, and 25) per class.*

### 4.5. Effect of Partition Size and Compute/Performance Trade-offs

OmniNet offers the flexibility to apply the omnidirectional layers on different partition sizes. With smaller partition sizes, we attend to tokens from fewer layers; with bigger partition sizes, we widen the vertical receptive field of the omnidirectional attention, which might be effective for learning better representations by capturing information from more levels. In terms of computational costs, however, there is a trade-off when choosing the partition size.
Smaller partition sizes mean running attention on shorter sequences but repeating it more frequently, while bigger partition sizes mean dealing with longer sequences but having fewer omnidirectional layers in the network. We train OmniNet-B/32 and OmniNet-B/16 with different partition sizes: P ∈ {1, 2, 4, 6, 12}. Partition size P = 1 simply means having no vertical attention; it just replaces the normal attention in ViT with a Performer. We compare these models in terms of their linear 5-shot accuracy on the ImageNet dataset (similar to the ablation studies in (Dosovitskiy et al., 2020)). Figure 3 presents the performance of each model as well as its computational cost during pre-training.

*Figure 3. Performance of ViT and OmniNet (with different partition sizes) in terms of top-1 accuracy on ImageNet 5-shot linear, versus their computational costs in terms of the number of FLOPs.*

A few patterns can be observed. For both OmniNet-B/32 and OmniNet-B/16, the power of omnidirectional representations kicks in when we work with partition sizes greater than 2. The input resolution during pre-training is 224×224, so for the /32 and /16 models the input sequence length is 49 and 196, respectively. With such sequence lengths, when setting P = 1 or P = 2 and using an efficient attention mechanism like Performer, which provides an approximation of dot-product attention, we do not gain much speed and we lose a bit of performance. However, when using a larger partition size, the additional compute becomes reasonable with respect to the performance gain. For both /32 and /16, the computational cost is almost the same for P = 4 and P = 6. With P = 4, we have three omnidirectional attention layers, each applied over 4 layers, while with P = 6 we have two omnidirectional attention layers, each applied over 6 layers. However, P = 6 gives slightly better results in terms of accuracy and sits at a sweeter spot in this trade-off. With P = 12, the computational costs of OmniNet increase, but the gain in performance keeps the model on the frontier of the compute-performance trade-off when compared to OmniNet-B/32 and OmniNet-B/16.

### 4.6. Visualization

OmniNet combines information from different layers via two sequential mechanisms (Section 3.2.1): (1) omnidirectional attention, where the representations of all tokens in all layers get updated with respect to each other using an efficient attention mechanism; and (2) a pooling operation, where for each token we collect the best values from all layers. In order to understand how these two mechanisms combine information across different layers, we visualize attention maps (Abnar & Zuidema, 2020) and pooling statistics for a set of examples in the image recognition task.

*Figure 4. Contribution of different layers to the omnidirectional representations for a given set of examples. On top, we plot the omnidirectional attention maps (using OmniNet-B/16 with P = 12) of one of the heads, over all layers, when the CLS token in the last layer is used as the query. On the bottom, we show the contribution of each layer to the pooling operation of the omnidirectional module.*

Figure 4 depicts three example inputs, where we show how OmniNet attends to different layers, as well as each layer's contribution during the pooling operation. We can see that in some layers, attention seems to detect the objects in the image by attending to the edges or to specific parts of the object, while in other layers the attention mechanism uses mostly background information.
It is clear that omnidirectional attention does indeed use such information by actively attending to layers of varying depth. Additionally, when performing the element-wise pooling operation over all layers for each token, only a fraction of the values from each layer's representation make it to the final representation. The bottom rows in Figure 4 illustrate this fraction for each token (image patch) across different layers. In most examples, we observe that a majority of the representation after the pooling operation comes from the first few layers. This is further evidence of how OmniNet can provide an explicit path for directing fine-grained information captured by the early layers to the final output, leading to much richer representations. For the sake of brevity, we refer readers to the Appendix for more detailed plots for these examples, as well as other examples, which illustrate the same trends.

## 5. Conclusion

In this paper, we proposed OmniNet, which uses omnidirectional attention to connect all tokens across the entire network via self-attention. In order to manage the computational costs of the full receptive field, the meta-learner in OmniNet is parameterized by fast and efficient self-attention models. The proposed method achieves stellar performance on a myriad of language and vision tasks. Concretely, OmniNet achieves state-of-the-art performance on WMT'14 En-De and En-Fr, outperforming deep 60-layer transformers. OmniNet also demonstrates substantial improvements over ViT on image recognition tasks.

## References

Abnar, S. and Zuidema, W. Quantifying attention flow in transformers. In Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics, 2020.

Baevski, A. and Auli, M. Adaptive input representations for neural language modeling. arXiv preprint arXiv:1809.10853, 2018.

Bahdanau, D., Cho, K., and Bengio, Y. Neural machine translation by jointly learning to align and translate. arXiv preprint arXiv:1409.0473, 2014.

Bapna, A., Chen, M. X., Firat, O., Cao, Y., and Wu, Y. Training deeper neural machine translation models with transparent attention. arXiv preprint arXiv:1808.07561, 2018.

Bradbury, J., Frostig, R., Hawkins, P., Johnson, M. J., Leary, C., Maclaurin, D., Necula, G., Paszke, A., VanderPlas, J., Wanderman-Milne, S., and Zhang, Q. JAX: composable transformations of Python+NumPy programs, 2018. URL http://github.com/google/jax.

Brown, T. B., Mann, B., Ryder, N., Subbiah, M., Kaplan, J., Dhariwal, P., Neelakantan, A., Shyam, P., Sastry, G., Askell, A., et al. Language models are few-shot learners. arXiv preprint arXiv:2005.14165, 2020.

Carion, N., Massa, F., Synnaeve, G., Usunier, N., Kirillov, A., and Zagoruyko, S. End-to-end object detection with transformers. arXiv preprint arXiv:2005.12872, 2020.

Chelba, C., Mikolov, T., Schuster, M., Ge, Q., Brants, T., Koehn, P., and Robinson, T. One billion word benchmark for measuring progress in statistical language modeling. arXiv preprint arXiv:1312.3005, 2013.

Chen, M., Radford, A., Child, R., Wu, J., Jun, H., Luan, D., and Sutskever, I. Generative pretraining from pixels. In International Conference on Machine Learning, pp. 1691-1703. PMLR, 2020.

Choromanski, K., Likhosherstov, V., Dohan, D., Song, X., Gane, A., Sarlos, T., Hawkins, P., Davis, J., Mohiuddin, A., Kaiser, L., et al. Rethinking attention with performers. arXiv preprint arXiv:2009.14794, 2020.
Dai, Z., Yang, Z., Yang, Y., Carbonell, J., Le, Q. V., and Salakhutdinov, R. Transformer-XL: Attentive language models beyond a fixed-length context. arXiv preprint arXiv:1901.02860, 2019.

Dehghani, M., Gouws, S., Vinyals, O., Uszkoreit, J., and Kaiser, Ł. Universal transformers. arXiv preprint arXiv:1807.03819, 2018.

Devlin, J., Chang, M.-W., Lee, K., and Toutanova, K. BERT: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805, 2018.

Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S., et al. An image is worth 16x16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929, 2020.

He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770-778, 2016.

He, R., Ravula, A., Kanagal, B., and Ainslie, J. Realformer: Transformer likes residual attention. arXiv e-prints, 2020.

He, T., Tan, X., Xia, Y., He, D., Qin, T., Chen, Z., and Liu, T.-Y. Layer-wise coordination between encoder and decoder for neural machine translation. In Proceedings of the 32nd International Conference on Neural Information Processing Systems, pp. 7955-7965, 2018.

Heek, J., Levskaya, A., Oliver, A., Ritter, M., Rondepierre, B., Steiner, A., and van Zee, M. Flax: A neural network library and ecosystem for JAX, 2020. URL http://github.com/google/flax.

Huang, G., Liu, Z., Van Der Maaten, L., and Weinberger, K. Q. Densely connected convolutional networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 4700-4708, 2017.

Kolesnikov, A., Beyer, L., Zhai, X., Puigcerver, J., Yung, J., Gelly, S., and Houlsby, N. Big Transfer (BiT): General visual representation learning. arXiv preprint arXiv:1912.11370, 2019.

Krizhevsky, A., Hinton, G., et al. Learning multiple layers of features from tiny images. 2009.

Kudo, T. and Richardson, J. SentencePiece: A simple and language independent subword tokenizer and detokenizer for neural text processing. arXiv preprint arXiv:1808.06226, 2018.

Kumar, M., Weissenborn, D., and Kalchbrenner, N. Colorization transformer. arXiv preprint arXiv:2102.04432, 2021.

Langley, P. Crafting papers on machine learning. In Langley, P. (ed.), Proceedings of the 17th International Conference on Machine Learning (ICML 2000), pp. 1207-1216, Stanford, CA, 2000. Morgan Kaufmann.

Liu, X., Duh, K., Liu, L., and Gao, J. Very deep transformers for neural machine translation. arXiv preprint arXiv:2008.07772, 2020.

Nilsback, M.-E. and Zisserman, A. Automated flower classification over a large number of classes. In 2008 Sixth Indian Conference on Computer Vision, Graphics & Image Processing, pp. 722-729. IEEE, 2008.

Ott, M., Edunov, S., Grangier, D., and Auli, M. Scaling neural machine translation. In Proceedings of the Third Conference on Machine Translation: Research Papers, pp. 1-9, Brussels, Belgium, October 2018. Association for Computational Linguistics. doi: 10.18653/v1/W18-6301. URL https://www.aclweb.org/anthology/W18-6301.

Parikh, A. P., Täckström, O., Das, D., and Uszkoreit, J. A decomposable attention model for natural language inference. arXiv preprint arXiv:1606.01933, 2016.

Parkhi, O. M., Vedaldi, A., Zisserman, A., and Jawahar, C. Cats and dogs. In 2012 IEEE Conference on Computer Vision and Pattern Recognition, pp. 3498-3505. IEEE, 2012.
Post, M. A call for clarity in reporting BLEU scores. arXiv preprint arXiv:1804.08771, 2018.

Raffel, C., Shazeer, N., Roberts, A., Lee, K., Narang, S., Matena, M., Zhou, Y., Li, W., and Liu, P. J. Exploring the limits of transfer learning with a unified text-to-text transformer. arXiv preprint arXiv:1910.10683, 2019.

So, D. R., Liang, C., and Le, Q. V. The evolved transformer. arXiv preprint arXiv:1901.11117, 2019.

Srivastava, R. K., Greff, K., and Schmidhuber, J. Highway networks. arXiv preprint arXiv:1505.00387, 2015.

Sun, C., Shrivastava, A., Singh, S., and Gupta, A. Revisiting unreasonable effectiveness of data in deep learning era. In Proceedings of the IEEE International Conference on Computer Vision, 2017.

Tay, Y., Tuan, L. A., Hui, S. C., and Su, J. Densely connected attention propagation for reading comprehension. arXiv preprint arXiv:1811.04210, 2018.

Tay, Y., Dehghani, M., Abnar, S., Shen, Y., Bahri, D., Pham, P., Rao, J., Yang, L., Ruder, S., and Metzler, D. Long Range Arena: A benchmark for efficient transformers. arXiv preprint arXiv:2011.04006, 2020a.

Tay, Y., Dehghani, M., Bahri, D., and Metzler, D. Efficient transformers: A survey. arXiv preprint arXiv:2009.06732, 2020b.

Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., and Polosukhin, I. Attention is all you need. In Advances in Neural Information Processing Systems, pp. 5998-6008, 2017.

Wang, S., Li, B., Khabsa, M., Fang, H., and Ma, H. Linformer: Self-attention with linear complexity. arXiv preprint arXiv:2006.04768, 2020.

Zaheer, M., Guruganesh, G., Dubey, A., Ainslie, J., Alberti, C., Ontanon, S., Pham, P., Ravula, A., Wang, Q., Yang, L., et al. Big Bird: Transformers for longer sequences. arXiv preprint arXiv:2007.14062, 2020.