The Thirty-Fourth AAAI Conference on Artificial Intelligence (AAAI-20)

GRET: Global Representation Enhanced Transformer

Rongxiang Weng,1,2 Haoran Wei,2 Shujian Huang,1 Heng Yu,2 Lidong Bing,2 Weihua Luo,2 Jiajun Chen1
1 National Key Laboratory for Novel Software Technology, Nanjing University, Nanjing, China
2 Machine Intelligence Technology Lab, Alibaba Group, Hangzhou, China
{wengrx, funan.whr, yuheng.yh, l.bing, weihua.luowh}@alibaba-inc.com, {huangsj, chenjj}@nju.edu.cn

Abstract

Transformer, based on the encoder-decoder framework, has achieved state-of-the-art performance on several natural language generation tasks. The encoder maps the words of the input sentence into a sequence of hidden states, which are then fed into the decoder to generate the output sentence. These hidden states usually correspond to the input words and focus on capturing local information. However, the global (sentence-level) information is seldom explored, leaving room for improving generation quality. In this paper, we propose a novel global representation enhanced Transformer (GRET) to explicitly model global representation in the Transformer network. Specifically, in the proposed model, an external state is generated for the global representation from the encoder. The global representation is then fused into the decoder during the decoding process to improve generation quality. We conduct experiments on two text generation tasks: machine translation and text summarization. Experimental results on four WMT machine translation tasks and the LCSTS text summarization task demonstrate the effectiveness of the proposed approach for natural language generation.1

1 Source code is available at: https://github.com/wengrx/GRET

1 Introduction

Transformer (Vaswani et al. 2017) has outperformed other methods on several natural language generation (NLG) tasks, such as machine translation (Deng et al. 2018) and text summarization (Chang, Huang, and Hsu 2018). Generally, Transformer is based on the encoder-decoder framework, which consists of two modules: an encoder network and a decoder network. The encoder encodes the input sentence into a sequence of hidden states, each of which corresponds to a specific word in the sentence. The decoder generates the output sentence word by word. At each decoding time-step, the decoder performs an attentive read (Luong, Pham, and Manning 2015; Vaswani et al. 2017) of the input hidden states and decides which word to generate.

As mentioned above, the decoding process of Transformer relies only on the representations contained in these hidden states. However, there is evidence that the hidden states produced by the Transformer encoder contain only local representations, which focus on word-level information. For example, previous work (Vaswani et al. 2017; Devlin et al. 2018; Song et al. 2020) showed that these hidden states pay much attention to the word-to-word mapping, and that the attention weights, which determine which target word will be generated, resemble a word alignment. As Frazier (1987) pointed out, global information, which concerns the whole sentence rather than individual words, should be involved in the process of generating a sentence. Representation of such global information plays an important role in neural text generation tasks.
In recurrent neural network (RNN) based models (Bahdanau, Cho, and Bengio 2014), Chen (2018) showed on a text summarization task that introducing representations of global information can improve quality and reduce repetition, and Lin et al. (2018b) showed on machine translation that the structure of the translated sentence is more correct when global information is introduced. These previous works show that global information is useful in current neural network based models. However, different from RNNs (Sutskever, Vinyals, and Le 2014; Cho et al. 2014; Bahdanau, Cho, and Bengio 2014) or CNNs (Gehring et al. 2016; 2017), although the self-attention mechanism can capture long-distance dependencies, there is no explicit mechanism in the Transformer to model the global representation of the whole sentence. Therefore, it is an appealing challenge to provide the Transformer with such a global representation.

In this paper, we divide this challenge into two issues that need to be addressed: 1) how to model the global contextual information, and 2) how to use the global information in the generation process. We propose a novel global representation enhanced Transformer (GRET) to solve them.

For the first issue, we propose to generate the global representation from the local word-level representations by two complementary methods in the encoding stage. On one hand, we adopt a modified capsule network (Sabour, Frosst, and Hinton 2017) to generate the global representation based on the features extracted from the local word-level representations. The local representations are generally related to the word-to-word mapping, which may be redundant or noisy, so using them to generate the global representation directly, without any filtering, is inadvisable. The capsule network, which has a strong feature extraction ability (Zhao et al. 2018), can help extract more suitable features from the local states. Compared with other networks, such as CNNs (Krizhevsky, Sutskever, and Hinton 2012), it can see all local states at once and extracts feature vectors after several rounds of deliberation. On the other hand, we propose a layer-wise recurrent structure to further strengthen the global representation. Previous work shows that the representations from different layers carry different aspects of meaning (Peters et al. 2018; Dou et al. 2018), e.g., lower layers contain more syntactic information, while higher layers contain more semantic information. A complete global context should cover these different aspects of information. However, the global representation generated by the capsule network only captures intra-layer information. The proposed layer-wise recurrent structure is a helpful supplement that combines inter-layer information by aggregating the representations from all layers. Together, these two methods model the global representation by fully utilizing information of different granularities from the local representations.

For the second issue, we propose a context gating mechanism to dynamically control how much information from the global representation should be fused into the decoder at each step. In the generation process, every decoder state should obtain global contextual information before outputting a word, and the demand for global information varies from word to word in the output sentence. The proposed gating mechanism utilizes the global representation effectively to improve generation quality by providing a customized representation for each state.
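To make the overall data flow concrete before the formal description in Section 2, the following NumPy sketch wires the three ingredients together at a purely schematic level. The condense/aggregate/gate helpers below are deliberately naive placeholders (mean pooling, averaging, a fixed scalar gate) standing in for the capsule routing, layer-wise GRU, and learned context gate introduced later; none of this reflects the authors' actual implementation.

```python
# Schematic GRET data flow: per-layer local states -> per-layer sentence vector ->
# one global representation -> fused into every decoder state.
import numpy as np

def condense_layer(H):            # placeholder for capsule routing + attentive pooling
    return H.mean(axis=0)

def aggregate_layers(per_layer):  # placeholder for the layer-wise GRU recurrence
    return np.mean(per_layer, axis=0)

def gate_into_decoder(R, s):      # placeholder for the learned context gate
    return R + 0.5 * s            # every decoder state gets a fixed share of the global state

# Toy shapes: M = 3 encoder layers, I = 5 source words, J = 6 target words, width d = 8.
rng = np.random.default_rng(0)
encoder_layers = [rng.normal(size=(5, 8)) for _ in range(3)]   # H^1 ... H^M
decoder_states = rng.normal(size=(6, 8))                       # last-layer decoder states

s_global = aggregate_layers([condense_layer(H) for H in encoder_layers])
fused = gate_into_decoder(decoder_states, s_global)
print(s_global.shape, fused.shape)  # (8,) (6, 8)
```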
Experimental results on four WMT translation tasks and the LCSTS text summarization task show that our GRET model brings significant improvements over a strong baseline and several previous studies.

2 Approach

Our GRET model involves two steps: modeling the global representation in the encoding stage and incorporating it into the decoding process. We describe our approach in this section based on Transformer (Vaswani et al. 2017).

2.1 Modeling Global Representation

In the encoding stage, we propose two methods for modeling the global representation at different granularities. We first use a capsule network to extract features from the local word-level representations and generate the global representation based on these features. Then, a layer-wise recurrent structure is adopted to strengthen the global representation by aggregating the representations from all layers of the encoder. The first method focuses on utilizing word-level information to generate a sentence-level representation, while the second focuses on combining different aspects of sentence-level information to obtain a more complete global representation.

Intra-layer Representation Generation. We propose to use capsules with dynamic routing, an effective and strong feature extraction method (Sabour, Frosst, and Hinton 2017; Zhang, Liu, and Song 2018),2 to extract specific and suitable features from the local representations for stronger global representation modeling. Features from the hidden states of the encoder are summarized into several capsules, and the weights (routes) between hidden states and capsules are updated iteratively by the dynamic routing algorithm.

2 Other details of the capsule network are given in Sabour, Frosst, and Hinton (2017).

Formally, given a Transformer encoder with M layers and an input sentence x = {x_1, ..., x_i, ..., x_I} of I words, the sequence of hidden states H^m = {h^m_1, ..., h^m_i, ..., h^m_I} from the m-th layer of the encoder is computed by

    H^m = LN(SAN(Q^m_e, K^{m-1}_e, V^{m-1}_e)),    (1)

where Q^m_e, K^{m-1}_e and V^{m-1}_e are the query, key and value vectors, which are the same as H^{m-1}, the hidden states from the (m-1)-th layer. LN(·) and SAN(·) are the layer normalization function (Ba, Kiros, and Hinton 2016) and the self-attention network (Vaswani et al. 2017), respectively. We omit the residual connections here.

Then, the K capsules U^m are generated from H^m. Specifically, the k-th capsule u^m_k is computed by

    u^m_k = q(Σ_{i=1}^{I} c_ki ĥ^m_i),  c_ki ∈ c_k,    (2)

    ĥ^m_i = W_k h^m_i,    (3)

where q(·) is the non-linear squash function (Sabour, Frosst, and Hinton 2017):

    squash(t) = (||t||^2 / (1 + ||t||^2)) · (t / ||t||),    (4)

and c_k is computed by

    c_k = softmax(b_k),  b_k ∈ B,    (5)

where the matrix B, with K rows and I columns, is initialized to zero. This matrix is updated once all capsules are produced:

    B = B + U^m (H^m)^T.    (6)

The procedure is summarized in Algorithm 1.

Algorithm 1 Dynamic Routing Algorithm

    procedure ROUTING(H, r)                ▷ H = {h_1, ..., h_i, ..., h_I}
        for all i in the input layer and k in the output layer do
            b_ki ← 0
        end for
        for r iterations do
            for each k in the output layer do
                c_k ← softmax(b_k)
            end for
            for each k in the output layer do
                u_k ← q(Σ_i c_ki ĥ_i)
            end for
            for all i in the input layer and k in the output layer do
                b_ki ← b_ki + h_i · u_k
            end for
        end for
        return U                           ▷ U = {u_1, ..., u_k, ..., u_K}

Figure 1: The overview of generating the global representation with the capsule network (local states h^m_i of the m-th encoder layer are routed to capsules u^m_k and then pooled into the global state).
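The following is a minimal NumPy sketch of the dynamic routing step of Eqs. (2)-(6) and Algorithm 1. The per-capsule projection matrices W_k, the tensor shapes, and the choice of r = 3 routing iterations are illustrative assumptions, not the released GRET implementation.

```python
import numpy as np

def squash(t, axis=-1, eps=1e-9):
    """Non-linear squash function of Eq. (4)."""
    norm_sq = np.sum(t ** 2, axis=axis, keepdims=True)
    return (norm_sq / (1.0 + norm_sq)) * t / np.sqrt(norm_sq + eps)

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def dynamic_routing(H, W, r=3):
    """Route I local states (H: [I, d]) into K capsules, returned as [K, d].

    W: [K, d, d] per-capsule projection matrices (Eq. 3).
    r: number of routing iterations.
    """
    # Prediction vectors \hat{h}_i = W_k h_i, one per (capsule, word): [K, I, d]
    H_hat = np.einsum('kde,ie->kid', W, H)
    K, I = W.shape[0], H.shape[0]
    B = np.zeros((K, I))                        # routing logits, initialized to zero
    for _ in range(r):
        C = softmax(B, axis=-1)                 # c_k = softmax(b_k), Eq. (5)
        U = squash(np.einsum('ki,kid->kd', C, H_hat))   # u_k, Eq. (2)
        B = B + np.einsum('id,kd->ki', H, U)    # b_ki <- b_ki + h_i . u_k (Algorithm 1 / Eq. 6)
    return U

# Toy usage: I = 5 encoder states of width d = 8 routed into K = 4 capsules.
rng = np.random.default_rng(0)
H = rng.normal(size=(5, 8))
W = rng.normal(scale=0.1, size=(4, 8, 8))
U = dynamic_routing(H, W, r=3)
print(U.shape)  # (4, 8)
```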
The sequence of capsules U^m can then be used to generate the global representation. Different from the original capsule network, which uses a concatenation method to generate the final representation, we use an attentive pooling method to generate the global representation.3 Formally, in the m-th layer, the global representation is computed by

    s^m = FFN(Σ_{k=1}^{K} a_k u^m_k),    (7)

    a_k = exp(ŝ^m · u^m_k) / Σ_{t=1}^{K} exp(ŝ^m · u^m_t),    (8)

where FFN(·) is a feed-forward network and ŝ^m is computed by

    ŝ^m = FFN(1/K Σ_{k=1}^{K} u^m_k).    (9)

This attentive method can take the different roles of the capsules into account and better model the global representation. The overview of the process of generating the global representation is shown in Figure 1.

3 Concatenation and other pooling methods, e.g. mean pooling, could also be used here, but they decrease performance by 0.1-0.2 BLEU in our machine translation experiments.

Inter-layer Representation Aggregation. Traditionally, the Transformer model only feeds the last layer's hidden states H^M to the decoder, as the representation of the input sentence, to generate the output sentence. Following this, we could feed the last layer's global representation s^M into the decoder directly. However, this global representation only contains intra-layer information; the other layers' representations, which previous work has shown to carry different aspects of meaning (Wang et al. 2018b; Dou et al. 2018), are ignored. Based on this intuition, we propose a layer-wise recurrent structure that aggregates the representations generated by applying the capsule network to all layers of the encoder, modeling a complete global representation.

The layer-wise recurrent structure aggregates each layer's intra-layer global state with a gated recurrent unit (GRU; Cho et al. 2014), which obtains different aspects of information from the previous layer's global representation. Formally, we adjust the computation of s^m to

    s^m = GRU(ATP(U^m), s^{m-1}),    (10)

where ATP(·) is the attentive pooling function computed by Eqs. 7-9. The GRU can control the information flow by forgetting useless information and capturing suitable information, which aggregates the previous layers' representations effectively. The layer-wise recurrent structure yields a more refined and complete representation. Moreover, the proposed structure only needs one more step in the encoding stage, which is not time-consuming. The overview of the aggregation structure is shown in Figure 2.

Figure 2: The overview of the layer-wise recurrent structure (capsules u^{m-1}_k from the (m-1)-th layer and u^m_k from the m-th layer).

2.2 Incorporating into the Decoding Process

Before generating each output word, the corresponding decoder state should consider the global contextual information. We fuse the global representation into the decoding process with an additive operation on the last layer of the decoder, guiding the states to output correct words. However, the demand for global information differs for each target word. Thus, we propose a context gating mechanism which provides specific information according to each decoder hidden state. Specifically, given a decoder with N layers and a target sentence y of J words in the training stage, the hidden states R^N = {r^N_1, ..., r^N_j, ..., r^N_J} from the N-th layer of the decoder are computed by

    R^N = LN(SAN(Q^N_d, K^{N-1}_d, V^{N-1}_d) + SAN(Q^N_d, K^M_e, V^M_e)),    (11)

where Q^N_d, K^{N-1}_d and V^{N-1}_d are the hidden states R^{N-1} from the (N-1)-th layer.
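Below is a minimal NumPy sketch of the attentive pooling of Eqs. (7)-(9) and the layer-wise recurrence of Eq. (10). For brevity, FFN(·) is reduced to a single tanh layer shared between Eqs. (7) and (9), and the GRU cell is written out explicitly; these simplifications, as well as all parameter shapes, are assumptions rather than the paper's exact configuration.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def attentive_pooling(U, ffn_w, ffn_b):
    """ATP(U^m): pool K capsules (U: [K, d]) into one sentence-level vector [d]."""
    s_hat = np.tanh(U.mean(axis=0) @ ffn_w + ffn_b)       # \hat{s}^m, Eq. (9)
    scores = U @ s_hat                                    # \hat{s}^m . u^m_k
    a = np.exp(scores - scores.max()); a /= a.sum()       # a_k, Eq. (8)
    return np.tanh((a @ U) @ ffn_w + ffn_b)               # s^m candidate, Eq. (7)

def gru_cell(x, h, Wz, Uz, Wr, Ur, Wh, Uh):
    """A standard GRU cell used for the layer-wise recurrence of Eq. (10)."""
    z = sigmoid(x @ Wz + h @ Uz)
    r = sigmoid(x @ Wr + h @ Ur)
    h_tilde = np.tanh(x @ Wh + (r * h) @ Uh)
    return (1 - z) * h + z * h_tilde

# Toy usage: M = 3 encoder layers, K = 4 capsules of width d = 8 per layer.
rng = np.random.default_rng(0)
d = 8
ffn_w, ffn_b = rng.normal(scale=0.1, size=(d, d)), np.zeros(d)
gru_params = [rng.normal(scale=0.1, size=(d, d)) for _ in range(6)]
s = np.zeros(d)                                           # s^0
for m in range(3):
    U_m = rng.normal(size=(4, d))                         # capsules from layer m
    s = gru_cell(attentive_pooling(U_m, ffn_w, ffn_b), s, *gru_params)
print(s.shape)  # (8,) -- the global representation s^M fed to the decoder
```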
K^M_e and V^M_e are the same as H^M. We omit the residual connections here.

For each hidden state r^N_j from R^N, the context gate is calculated by

    g_j = sigmoid(r^N_j, s^M).    (12)

The new state, which contains the needed global information, is computed by

    r̃^N_j = r^N_j + s^M ⊙ g_j.    (13)

Figure 3: The context gating mechanism for fusing the global representation into the decoding stage.

Then, the output probability is calculated from the output layer's hidden state: P(y_j | y_{<j}, x)
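A minimal NumPy sketch of the context gating of Eqs. (12)-(13) follows. Since Eq. (12) only writes g_j = sigmoid(r^N_j, s^M), the gate here is computed from a learned projection W_g of the concatenation [r^N_j; s^M]; that projection, and all shapes, are illustrative assumptions rather than the paper's exact parameterization.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def context_gate_fuse(R, s, W_g):
    """Fuse the global state s ([d]) into decoder states R ([J, d]) via per-state gates."""
    J = R.shape[0]
    inputs = np.concatenate([R, np.tile(s, (J, 1))], axis=1)   # [r^N_j; s^M] per position
    gates = sigmoid(inputs @ W_g)                              # g_j, Eq. (12)
    return R + gates * s                                       # r~^N_j, Eq. (13)

# Toy usage: J = 6 decoder states of width d = 8.
rng = np.random.default_rng(0)
d = 8
R = rng.normal(size=(6, d))        # last-layer decoder states r^N_j
s = rng.normal(size=d)             # global representation s^M
W_g = rng.normal(scale=0.1, size=(2 * d, d))
R_tilde = context_gate_fuse(R, s, W_g)
print(R_tilde.shape)  # (6, 8)
```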