# Differentiable Top-k with Optimal Transport

Yujia Xie (College of Computing, Georgia Tech) Xie.Yujia000@gmail.com
Hanjun Dai (Google Brain) hadai@google.com
Minshuo Chen (College of Engineering, Georgia Tech) mchen393@gatech.edu
Bo Dai (Google Brain) bodai@google.com
Tuo Zhao (College of Engineering, Georgia Tech) tourzhao@gatech.edu
Hongyuan Zha (School of Data Science, Shenzhen Research Institute of Big Data, CUHK, Shenzhen) zhahy@cuhk.edu.cn
Wei Wei (Google Cloud AI) wewei@google.com
Tomas Pfister (Google Cloud AI) tpfister@google.com

Work done in a Google internship. Also affiliated with Shenzhen Institute of Artificial Intelligence and Robotics for Society. On leave from College of Computing, Georgia Tech.

34th Conference on Neural Information Processing Systems (NeurIPS 2020), Vancouver, Canada.

Abstract

The top-k operation, i.e., finding the k largest or smallest elements from a collection of scores, is an important model component that is widely used in information retrieval, machine learning, and data mining. However, if the top-k operation is implemented algorithmically, e.g., using the bubble sort algorithm, the resulting model cannot be trained in an end-to-end way using prevalent gradient descent algorithms. This is because these implementations typically involve swapping indices, whose gradient cannot be computed. Moreover, the corresponding mapping from the input scores to the indicator vector of whether each element belongs to the top-k set is essentially discontinuous. To address the issue, we propose a smoothed approximation, namely the SOFT (Scalable Optimal transport-based diFferenTiable) top-k operator. Specifically, our SOFT top-k operator approximates the output of the top-k operation as the solution of an Entropic Optimal Transport (EOT) problem. The gradient of the SOFT operator can then be efficiently approximated based on the optimality conditions of the EOT problem. We apply the proposed operator to the k-nearest neighbors and beam search algorithms, and demonstrate improved performance.

1 Introduction

The top-k operation, i.e., finding the k largest or smallest elements from a set, is widely used for predictive modeling in information retrieval, machine learning, and data mining. For example, in image retrieval (Babenko et al., 2014; Radenović et al., 2016; Gordo et al., 2016), one needs to query the k nearest neighbors of an input image under certain metrics; in the beam search algorithm (Reddy et al., 1977; Wiseman and Rush, 2016) for neural machine translation, one needs to find the k sequences with the largest likelihoods at each decoding step.

Although the ubiquity of the top-k operation continues to grow, the operation itself is difficult to integrate into the training procedure of a predictive model. For example, consider a neural network-based k-nearest neighbor classifier. Given an input, we use the neural network to extract features from the input. Next, the extracted features are fed into the top-k operation to identify the k nearest neighbors under some distance metric. We then obtain a prediction based on the k nearest neighbors of the input. In order to train such a model, we choose a proper loss function and minimize the average loss across training samples using (stochastic) first-order methods. This naturally requires the loss function to be differentiable with respect to the input at each update step.
Nonetheless, the top-k operation does not exhibit an explicit mathematical formulation: most implementations of the top-k operation, e.g., the bubble sort algorithm and QUICKSELECT (Hoare, 1961), involve operations on indices such as index swapping. Consequently, the training objective is difficult to formulate explicitly.

Figure 1: Illustration of the top-k operators. (a) Original top-k; (b) SOFT top-k.

Taking an alternative perspective and viewing the top-k operation as an operator still does not resolve the differentiability issue. Specifically, the top-k operator (throughout the rest of the paper, we refer to the top-k operator as the top-k operation) maps a set of inputs $x_1, \dots, x_n$ to an index vector in $\{0, 1\}^n$, yet the Jacobian matrix of such a mapping is not well defined. As a simple example, consider two scalars $x_1, x_2$. The top-1 operation as in Figure 1 returns a vector $[A_1, A_2]^\top$, with each entry denoting whether the scalar is the larger one (1 for true, 0 for false). Denote $A_1 = f(x_1, x_2)$. For a fixed $x_2$, $A_1$ jumps from 0 to 1 at $x_1 = x_2$. It is clear that $f$ is not differentiable at $x_1 = x_2$, and the derivative is identically zero elsewhere.

Due to the aforementioned difficulty, existing works resort to two-stage training for models with the top-k operation. Consider the neural network-based k-nearest neighbor classifier again. As proposed in Papernot and McDaniel (2018), one first trains the neural network using some surrogate loss on the extracted features, e.g., using a softmax activation in the output layer and the cross-entropy loss. Next, one uses the k-nearest neighbor rule for prediction based on the features extracted by the well-trained neural network. This training procedure, although circumventing the top-k operation, makes training and prediction misaligned, and the actual performance suffers.

In this work, we propose the SOFT (Scalable Optimal transport-based diFferenTiable) top-k operation as a differentiable approximation of the standard top-k operation in Section 2. Specifically, motivated by implicit differentiation techniques (Duchi et al., 2008; Griewank and Walther, 2008; Amos and Kolter, 2017; Luise et al., 2018), we first parameterize the top-k operation in terms of the optimal solution of an Optimal Transport (OT) problem. Such a re-parameterization is still not differentiable with respect to the input. To rule out the discontinuity, we impose entropy regularization on the optimal transport problem, and show that the optimal solution of the Entropic OT (EOT) problem yields a differentiable approximation to the top-k operation. Moreover, we prove that under mild assumptions, the approximation error can be properly controlled.

We then develop an efficient implementation of the SOFT top-k operation in Section 3. Specifically, we solve the EOT problem via the Sinkhorn algorithm (Cuturi, 2013). Given the optimal solution, we can explicitly formulate the gradient of the SOFT top-k operation using the KKT (Karush-Kuhn-Tucker) conditions. As a result, the gradient at each update step can be efficiently computed with complexity O(n), where n is the number of elements in the input set of the top-k operation.

Our proposed SOFT top-k operation allows end-to-end training, and we apply the SOFT top-k operation to kNN for classification in Section 4 and to beam search in Section 5. The experimental results demonstrate significant performance gains over competing methods, as an end-to-end training procedure resolves the misalignment between training and prediction.
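To make the differentiability issue concrete, the following is a minimal PyTorch sketch (our own illustration, not part of the paper) showing that a hard top-k indicator is piecewise constant in its inputs, so gradient-based training receives no signal through it:

```python
import torch

# Hard top-1 indicator for two scalars, as in the example above:
# A1 = 1 if x1 is the larger element, else 0.
x1 = torch.tensor(0.3, requires_grad=True)
x2 = torch.tensor(0.5, requires_grad=True)
A1 = (x1 > x2).float()
print(A1)                # tensor(0.)
print(A1.requires_grad)  # False: the comparison cuts the computation graph,
                         # and dA1/dx1 is zero wherever it exists.

# torch.topk is differentiable w.r.t. the values it returns, but the 0/1
# membership vector is still a step function of the input scores.
x = torch.tensor([0.4, 0.7, 2.3, 1.9, 0.2, 1.4, 0.1], requires_grad=True)
_, idx = torch.topk(-x, k=5)                    # 5 smallest entries via negation
A = torch.zeros_like(x).scatter(0, idx, 1.0)    # hard indicator, no gradient to x
print(A)                 # tensor([1., 1., 0., 0., 1., 1., 1.])
```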
Notations. We denote $\|\cdot\|_2$ as the $\ell_2$ norm of vectors and $\|\cdot\|_{\mathrm{F}}$ as the Frobenius norm of matrices. Given two matrices $B, D \in \mathbb{R}^{n \times m}$, we denote $\langle B, D\rangle$ as the inner product, i.e., $\langle B, D\rangle = \sum_{i=1}^{n}\sum_{j=1}^{m} B_{ij} D_{ij}$. We denote $B \odot D$ as the element-wise multiplication of $B$ and $D$. We denote $\mathbb{1}(\cdot)$ as the indicator function, i.e., the output of $\mathbb{1}(\cdot)$ is 1 if the input condition is satisfied, and 0 otherwise. For a matrix $B \in \mathbb{R}^{n \times m}$, we denote $B_{i,:}$ as the $i$-th row of the matrix. The softmax function for matrix $B$ is defined as $\mathrm{softmax}_i(B_{ij}) = e^{B_{ij}} / \sum_{l=1}^{n} e^{B_{lj}}$. For a vector $b \in \mathbb{R}^n$, we denote $\mathrm{diag}(b)$ as the matrix whose $i$-th diagonal entry is $b_i$.

2 SOFT Top-k Operator

We adopt the following definition of the (augmented) top-k operator. Given a set of scalars $X = \{x_i\}_{i=1}^n$, the standard top-k operator returns a vector $A = [A_1, \dots, A_n]^\top$ such that

$A_i = \begin{cases} 1, & \text{if } x_i \text{ is a top-}k \text{ element in } X, \\ 0, & \text{otherwise.} \end{cases}$

Note that this definition is essentially an "arg-top-k" operation, since it marks the top-k indices with 1 instead of returning the top-k values. This allows more flexibility, since we can obtain the top-k values by multiplying $A$ with $X$. The goal is to design a smooth relaxation of the standard top-k operator. Without loss of generality, we refer to the top-k elements as the smallest k elements.

2.1 Parameterizing Top-k Operator as OT Problem

We first show that the standard top-k operator can be parameterized in terms of the solution of an Optimal Transport (OT) problem (Monge, 1781; Kantorovich, 1960). We briefly introduce OT problems for self-containedness. An OT problem finds a transport plan between two distributions such that the expected cost of the transportation is minimized. We consider two discrete distributions defined on supports $\mathcal{A} = \{a_i\}_{i=1}^n$ and $\mathcal{B} = \{b_j\}_{j=1}^m$, respectively. Denote $\mathbb{P}(\{a_i\}) = \mu_i$ and $\mathbb{P}(\{b_j\}) = \nu_j$, and let $\mu = [\mu_1, \dots, \mu_n]^\top$ and $\nu = [\nu_1, \dots, \nu_m]^\top$. We further denote $C \in \mathbb{R}^{n \times m}$ as the cost matrix, with $C_{ij}$ being the cost of transporting mass from $a_i$ to $b_j$. An OT problem can be formulated as

$\Gamma^* = \operatorname*{argmin}_{\Gamma \ge 0} \ \langle C, \Gamma\rangle, \quad \text{s.t.}\ \ \Gamma 1_m = \mu, \ \Gamma^\top 1_n = \nu,$   (1)

where $1$ denotes a vector of ones. The optimal $\Gamma^*$ is referred to as the optimal transport plan.

In order to parameterize the top-k operator using the optimal transport plan $\Gamma^*$, we set the supports $\mathcal{A} = X$ and $\mathcal{B} = \{0, 1\}$ in (1), with $\mu, \nu$ defined as $\mu = 1_n/n$ and $\nu = [k/n, (n-k)/n]^\top$. We take the cost to be the squared Euclidean distance, i.e., $C_{i1} = x_i^2$ and $C_{i2} = (x_i - 1)^2$ for $i = 1, \dots, n$. We then establish the relationship between the output $A$ of the top-k operator and $\Gamma^*$.

Proposition 1. Consider the setup in the previous paragraph. Without loss of generality, we assume $X$ has no duplicates. Then the optimal transport plan $\Gamma^*$ of (1) is

$\Gamma^*_{\sigma_i, 1} = \begin{cases} 1/n, & \text{if } i \le k, \\ 0, & \text{if } k+1 \le i \le n, \end{cases} \qquad \Gamma^*_{\sigma_i, 2} = \begin{cases} 0, & \text{if } i \le k, \\ 1/n, & \text{if } k+1 \le i \le n, \end{cases}$   (2)

with $\sigma$ being the sorting permutation, i.e., $x_{\sigma_1} < x_{\sigma_2} < \cdots < x_{\sigma_n}$. Moreover, we have

$A = n\Gamma^* \cdot [1, 0]^\top.$   (3)

The proof can be found in Appendix A. Figure 3(a) illustrates the corresponding optimal transport plan for parameterizing the top-5 operator applied to a set of 7 elements. As can be seen, the mass from the 5 smallest scores is transported to 0, while the mass from the 2 remaining scores is transported to 1. Therefore, the optimal transport plan exactly indicates the top-5 elements.
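As a sanity check, the following short PyTorch snippet (our own, with variable names of our choosing) builds the OT problem above for the example of Figure 3(a) and constructs the optimal plan from the sorting permutation as in Proposition 1; Eq. (3) then recovers the hard top-k indicator:

```python
import torch

# Example of Figure 3(a): n = 7 scores, k = 5 (top-k = smallest k).
x = torch.tensor([0.4, 0.7, 2.3, 1.9, 0.2, 1.4, 0.1])
n, k = x.numel(), 5

# OT setup of Section 2.1: supports A = X and B = {0, 1}, marginals
# mu = 1_n / n and nu = [k/n, (n-k)/n], squared Euclidean cost.
y = torch.tensor([0.0, 1.0])
C = (x[:, None] - y[None, :]) ** 2      # C[i, 0] = x_i^2, C[i, 1] = (x_i - 1)^2
mu = torch.full((n,), 1.0 / n)
nu = torch.tensor([k / n, (n - k) / n])

# Proposition 1: the k smallest scores each send mass 1/n to 0,
# the remaining n - k scores each send mass 1/n to 1.
sigma = torch.argsort(x)                # sorting permutation
Gamma = torch.zeros(n, 2)
Gamma[sigma[:k], 0] = 1.0 / n
Gamma[sigma[k:], 1] = 1.0 / n

assert torch.allclose(Gamma.sum(dim=1), mu)   # row marginals
assert torch.allclose(Gamma.sum(dim=0), nu)   # column marginals

# Eq. (3): A = n * Gamma @ [1, 0]^T recovers the hard top-k indicator.
A = n * Gamma @ torch.tensor([1.0, 0.0])
print(A)   # tensor([1., 1., 0., 0., 1., 1., 1.])
```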
2.2 Smoothing by Entropy Regularization

We next rule out the discontinuity of (1) to obtain a smoothed approximation to the top-k operator. Specifically, we add entropy regularization to the OT problem (1):

$\Gamma^{\epsilon,*} = \operatorname*{argmin}_{\Gamma \ge 0} \ \langle C, \Gamma\rangle + \epsilon H(\Gamma), \quad \text{s.t.}\ \ \Gamma 1_m = \mu, \ \Gamma^\top 1_n = \nu,$   (4)

where $H(\Gamma) = \sum_{i,j} \Gamma_{ij} \log \Gamma_{ij}$ is the entropy regularizer. We define $A^\epsilon = n\Gamma^{\epsilon,*} \cdot [1, 0]^\top$ as a smoothed counterpart of the output $A$ of the standard top-k operator. Accordingly, the SOFT top-k operator is defined as the mapping from $X$ to $A^\epsilon$. We show that the Jacobian matrix of the SOFT top-k operator exists and is nonzero in the following theorem.

Theorem 1. For any $\epsilon > 0$, the SOFT top-k operator $X \mapsto A^\epsilon$ is differentiable, as long as the cost $C_{ij}$ is differentiable with respect to $x_i$ for any $i, j$. Moreover, the Jacobian matrix of the SOFT top-k operator always has a nonzero entry for any $X \in \mathbb{R}^n$.

Figure 2: Color maps of $\Gamma^{\epsilon,*}$ (upper) and the corresponding scatter plots of the values in $A^\epsilon$ (lower) for different values of $\epsilon$ (e.g., $\epsilon = 5 \times 10^{-3}$ and $\epsilon = 5 \times 10^{-2}$), where $X$ contains 50 standard Gaussian samples and $k = 5$. The scatter plots show the correspondence between the input $X$ and the output $A^\epsilon$.

Figure 3: (a) Illustration of the OT plan with input $X = [0.4, 0.7, 2.3, 1.9, 0.2, 1.4, 0.1]^\top$ and $k = 5$. We set $\nu = [5/7, 2/7]^\top$. In this way, 5 of the 7 scores align with 0, while $\{2.3, 1.9\}$ align with 1. (b) Illustration of sorted top-k with a similar input and $k = 2$. We set $\nu = [1/7, 1/7, 5/7]^\top$ and $\mathcal{B} = [0, 1, 2]^\top$. Then, the smallest score 0.1 aligns with 0, the second smallest score 0.2 aligns with 1, and the rest of the scores align with 2.

The proof can be found in Appendix A. We remark that the entropic OT problem (4) is computationally more friendly, since it allows the usage of first-order algorithms (Cuturi, 2013).

The entropic OT introduces bias to the SOFT top-k operator. The following theorem shows that such a bias can be effectively controlled.

Theorem 2. Given a distinct sequence $X$ and its sorting permutation $\sigma$, with the squared Euclidean cost function, for the proposed top-k solver we have

$\|\Gamma^{\epsilon,*} - \Gamma^*\|_{\mathrm{F}} \le \frac{\epsilon\,(\ln n + \ln 2)}{n\,(x_{\sigma_{k+1}} - x_{\sigma_k})}.$

Therefore, with a small enough $\epsilon$, the output vector $A^\epsilon$ can well approximate $A$, especially when there is a large gap between $x_{\sigma_k}$ and $x_{\sigma_{k+1}}$. Besides, Theorem 2 suggests a trade-off between the bias and the regularization of the SOFT top-k operator; see Section 7 for a detailed discussion.

2.3 Sorted SOFT Top-k Operator

In some applications, we not only need to distinguish the top-k elements, but also to sort the top-k elements. For example, in image retrieval (Gordo et al., 2016), the retrieved k images are expected to be sorted. Our SOFT top-k operator can be extended to a sorted SOFT top-k operator.

Analogous to the derivation of the SOFT top-k operator, we first parameterize the sorted top-k operator in terms of an OT problem. Specifically, we keep $\mathcal{A} = X$ and $\mu = 1_n/n$, and set $\mathcal{B} = [0, 1, 2, \dots, k]^\top$ and $\nu = [1/n, \dots, 1/n, (n-k)/n]^\top$. One can check that the optimal transport plan of this OT problem transports the smallest element in $\mathcal{A}$ to 0 in $\mathcal{B}$, the second smallest element to 1, and so on, which in turn yields the sorted top-k elements. Figure 3(b) illustrates the sorted top-2 operator and its optimal transport plan. The sorted SOFT top-k operator is obtained, similarly to the SOFT top-k operator, by solving the entropy-regularized OT problem. We can show that the sorted SOFT top-k operator is differentiable and that its bias can be properly controlled.
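To make the sorted variant concrete, here is a short sketch (ours; the names and the closed-form plan follow the same argument as Proposition 1, not code from the paper) of the OT setup of Section 2.3 for the example of Figure 3(b):

```python
import torch

# Example of Figure 3(b): sorted top-k with k = 2.
x = torch.tensor([0.4, 0.7, 2.3, 1.9, 0.2, 1.4, 0.1])
n, k = x.numel(), 2

# Section 2.3 setup: B = [0, 1, ..., k], nu = [1/n, ..., 1/n, (n-k)/n].
y = torch.arange(k + 1, dtype=torch.float32)            # [0., 1., 2.]
C = (x[:, None] - y[None, :]) ** 2                      # squared Euclidean cost
mu = torch.full((n,), 1.0 / n)
nu = torch.cat([torch.full((k,), 1.0 / n), torch.tensor([(n - k) / n])])

# Unregularized optimum: the smallest score aligns with 0, the second
# smallest with 1, and the remaining n - k scores all align with k.
sigma = torch.argsort(x)
Gamma = torch.zeros(n, k + 1)
Gamma[sigma[:k], torch.arange(k)] = 1.0 / n             # rank r -> column r
Gamma[sigma[k:], k] = 1.0 / n                           # the rest -> last column

assert torch.allclose(Gamma.sum(dim=1), mu)
assert torch.allclose(Gamma.sum(dim=0), nu)

# Column r of n * Gamma[:, :k] indicates the element of rank r + 1;
# entropy regularization as in (4) smooths these indicators as before.
print(n * Gamma[:, :k])
```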
3 Efficient Implementation

We now present our implementation of the SOFT top-k operator, which consists of 1) computing $A^\epsilon$ from $X$, and 2) computing the Jacobian matrix of $A^\epsilon$ with respect to $X$. We refer to 1) as the forward pass and to 2) as the backward pass.

Algorithm 1 SOFT Top-k
Require: $X = [x_i]_{i=1}^n$, $k$, $\epsilon$, $L$
  $Y = [y_1, y_2]^\top = [0, 1]^\top$
  $\mu = 1_n/n$, $\nu = [k/n, (n-k)/n]^\top$
  $C_{ij} = |x_i - y_j|^2$, $G_{ij} = e^{-C_{ij}/\epsilon}$, $q = 1_2/2$
  for $l = 1, \dots, L$ do
    $p = \mu/(Gq)$, $q = \nu/(G^\top p)$
  end for
  $\Gamma^{\epsilon,*} = \mathrm{diag}(p)\, G\, \mathrm{diag}(q)$
  $A^\epsilon = n\Gamma^{\epsilon,*} \cdot [1, 0]^\top$

Forward Pass. The forward pass from $X$ to $A^\epsilon$ can be efficiently computed using the Sinkhorn algorithm. Specifically, we run iterative Bregman projections (Benamou et al., 2015), where at the $\ell$-th iteration we update

$p^{(\ell+1)} = \frac{\mu}{G q^{(\ell)}}, \qquad q^{(\ell+1)} = \frac{\nu}{G^\top p^{(\ell+1)}}.$

Here the division is entrywise, $q^{(0)} = 1_2/2$, and $G \in \mathbb{R}^{n \times m}$ with $G_{ij} = e^{-C_{ij}/\epsilon}$. Denote $p^*$ and $q^*$ as the stationary point of the Bregman projections. The optimal transport plan can then be obtained as $\Gamma^{\epsilon,*}_{ij} = p^*_i G_{ij} q^*_j$. The procedure is summarized in Algorithm 1.

Backward Pass. Given $A^\epsilon$, we compute the Jacobian matrix $\mathrm{d}A^\epsilon/\mathrm{d}X$ using implicit differentiation and differentiable programming techniques. Specifically, the Lagrangian function of Problem (4) is

$L = \langle C, \Gamma\rangle - \xi^\top(\Gamma 1_m - \mu) - \zeta^\top(\Gamma^\top 1_n - \nu) + \epsilon H(\Gamma),$

where $\xi$ and $\zeta$ are dual variables. The KKT condition implies that $\Gamma^{\epsilon,*}$ can be formulated using the optimal dual variables $\xi^*$ and $\zeta^*$ as (Sinkhorn's scaling theorem, Sinkhorn and Knopp (1967))

$\Gamma^{\epsilon,*} = \mathrm{diag}(e^{\xi^*/\epsilon})\, e^{-C/\epsilon}\, \mathrm{diag}(e^{\zeta^*/\epsilon}),$   (5)

where the exponentials are entrywise. Substituting (5) into the Lagrangian function, we obtain the dual objective $L(\xi^*, \zeta^*; C) = (\xi^*)^\top \mu + (\zeta^*)^\top \nu - \epsilon \sum_{i,j} e^{(\xi^*_i + \zeta^*_j - C_{ij})/\epsilon}$. We now compute the gradient of $\xi^*$ and $\zeta^*$ with respect to $C$, so that we can obtain $\mathrm{d}\Gamma^{\epsilon,*}/\mathrm{d}C$ by the chain rule applied to (5). Denote $\omega^* = [(\xi^*)^\top, (\zeta^*)^\top]^\top$ and $\phi(\omega^*; C) = \partial L(\omega^*; C)/\partial \omega^*$. At the optimal dual variable $\omega^*$, the KKT condition immediately yields $\phi(\omega^*; C) \equiv 0$. By the chain rule, we have

$\frac{\mathrm{d}\phi(\omega^*; C)}{\mathrm{d}C} = \frac{\partial \phi(\omega^*; C)}{\partial C} + \frac{\partial \phi(\omega^*; C)}{\partial \omega^*}\,\frac{\mathrm{d}\omega^*}{\mathrm{d}C} = 0.$

Rearranging terms, we obtain

$\frac{\mathrm{d}\omega^*}{\mathrm{d}C} = -\left(\frac{\partial \phi(\omega^*; C)}{\partial \omega^*}\right)^{-1}\frac{\partial \phi(\omega^*; C)}{\partial C}.$   (6)

Combining (5), (6), $C_{ij} = (x_i - y_j)^2$, and $A^\epsilon = n\Gamma^{\epsilon,*} \cdot [1, 0]^\top$, the Jacobian matrix $\mathrm{d}A^\epsilon/\mathrm{d}X$ can then be derived using the chain rule again. The detailed derivation and the corresponding algorithm for computing the Jacobian matrix can be found in Appendix B. The time and space complexity of the derived algorithm is O(n) for the top-k operator and O(kn) for the sorted top-k operator. We also include a PyTorch (Paszke et al., 2017) implementation of the forward and backward passes in Appendix B, obtained by extending the autograd automatic differentiation package.
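The following is a minimal PyTorch sketch of the forward pass in Algorithm 1 (our own code, not the implementation from Appendix B; `soft_top_k` and the default values are our choices). Backpropagating through the Sinkhorn iterations with autograd works, but, as discussed in Section 7, it stores every iterate, whereas the KKT-based backward pass above avoids that memory cost:

```python
import torch

def soft_top_k(x, k, eps=1e-1, n_iters=200):
    """Forward pass of Algorithm 1 (sketch): entropic OT between the scores
    and {0, 1}, solved by Sinkhorn / Bregman iterations. Returns the smoothed
    indicator A_eps of the k smallest entries of x."""
    n = x.numel()
    y = torch.tensor([0.0, 1.0], dtype=x.dtype, device=x.device)
    C = (x[:, None] - y[None, :]) ** 2                  # n x 2 cost matrix
    mu = torch.full((n,), 1.0 / n, dtype=x.dtype, device=x.device)
    nu = torch.tensor([k / n, (n - k) / n], dtype=x.dtype, device=x.device)

    G = torch.exp(-C / eps)                             # Gibbs kernel e^{-C/eps}
    q = torch.full((2,), 0.5, dtype=x.dtype, device=x.device)
    for _ in range(n_iters):                            # Bregman projections
        p = mu / (G @ q)
        q = nu / (G.t() @ p)
    Gamma = p[:, None] * G * q[None, :]                 # diag(p) G diag(q)
    # A_eps = n * Gamma @ [1, 0]^T, i.e. n times the first column of Gamma.
    return n * Gamma[:, 0]

# Toy usage: the smoothed indicator is differentiable w.r.t. the scores.
x = torch.tensor([0.4, 0.7, 2.3, 1.9, 0.2, 1.4, 0.1], requires_grad=True)
A_eps = soft_top_k(x, k=5)
print(A_eps)                 # entries close to 1 for the 5 smallest scores
loss = (A_eps * x).sum()     # smoothed sum of the k smallest scores
loss.backward()
print(x.grad)                # gradients flow back to the scores (unlike hard top-k)
# Note: for very small eps, a log-domain (stabilized) Sinkhorn is preferable
# to avoid underflow in exp(-C / eps).
```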
4 k-NN for Image Classification

The proposed SOFT top-k operator enables us to train a neural network-based kNN classifier end-to-end. Specifically, we receive training samples $\{Z_i, y_i\}_{i=1}^N$, with $Z_i$ being the input data and $y_i \in \{1, \dots, M\}$ the label from $M$ classes. During training, for an input $Z_j$ (also known as the query sample), we associate a loss as follows. Denote $Z_{\setminus j}$ as all the input data excluding $Z_j$ (also known as the template samples). We use a neural network $f_\theta$ parameterized by $\theta$ to extract features from all the input data, and measure the pairwise Euclidean distances between the extracted features of $Z_{\setminus j}$ and that of $Z_j$. Denote $X_{\setminus j, \theta}$ as the collection of these pairwise distances, i.e.,

$X_{\setminus j, \theta} = \{\|f_\theta(Z_1) - f_\theta(Z_j)\|_2, \dots, \|f_\theta(Z_{j-1}) - f_\theta(Z_j)\|_2, \|f_\theta(Z_{j+1}) - f_\theta(Z_j)\|_2, \dots, \|f_\theta(Z_N) - f_\theta(Z_j)\|_2\},$

where the subscript of $X$ emphasizes its dependence on $\theta$. Next, we apply the SOFT top-k operator to $X_{\setminus j, \theta}$, and denote the returned vector by $A^\epsilon_{\setminus j, \theta}$. Let $Y_{\setminus j} \in \mathbb{R}^{M \times (N-1)}$ be the matrix obtained by concatenating the one-hot encodings of the labels $y_i$ for $i \ne j$ as columns, and $Y_j \in \mathbb{R}^M$ the one-hot encoding of the label $y_j$. The loss of $Z_j$ is defined as $\ell(Z_j, y_j) = -Y_j^\top Y_{\setminus j} A^\epsilon_{\setminus j, \theta}$, i.e., the negative smoothed count of selected neighbors that share the label $y_j$. Consequently, the training loss is $L(\{Z_j, y_j\}_{j=1}^N) = \sum_{j=1}^N \ell(Z_j, y_j)$.

Recall that the Jacobian matrix of $A^\epsilon_{\setminus j, \theta}$ exists and is nonzero. This allows us to use stochastic gradient descent algorithms to update $\theta$ in the neural network. Moreover, since $N$ is often large, to ease the computation, we randomly sample a batch of samples to compute the stochastic gradient at each iteration.

In the prediction stage, we use all the training samples to obtain the predicted label of a query sample. Specifically, we feed the query sample into the neural network to extract its features, and compute the pairwise Euclidean distances to all the training samples. We then run the standard kNN algorithm (Hastie et al., 2009) to obtain the predicted label.

Figure 4: Illustration of the entire forward pass of kNN.

Table 1: Classification accuracy of kNN.

| Algorithm | MNIST | CIFAR10 |
| --- | --- | --- |
| kNN | 97.2% | 35.4% |
| kNN + PCA | 97.6% | 40.9% |
| kNN + AE | 97.6% | 44.2% |
| kNN + pretrained CNN | 98.4% | 91.1% |
| RelaxSubSample | 99.3% | 90.1% |
| kNN + NeuralSort | 99.5% | 90.7% |
| kNN + Cuturi et al. (2019) | 99.0% | 84.8% |
| kNN + Softmax k times | 99.3% | 92.2% |
| CE + CNN (He et al., 2016) | 99.0% | 91.3% |
| kNN + SOFT Top-k | 99.4% | 92.6% |

4.1 Experiment

We evaluate the performance of the proposed neural network-based kNN classifier on two benchmark datasets: the MNIST dataset of handwritten digits (LeCun et al., 1998) and the CIFAR-10 dataset of natural images (Krizhevsky et al., 2009), with the canonical splits for training and testing and without data augmentation. We set the entropy regularization coefficient to $\epsilon = 10^{-3}$ for MNIST and $\epsilon = 10^{-5}$ for CIFAR-10. Further implementation details can be found in Appendix C.

Baselines. We consider several baselines:
1. The standard kNN method.
2. Two-stage training methods: we first extract features of the images and then perform kNN on the features. The features are extracted using Principal Component Analysis (PCA; the top 50 principal components are used), an autoencoder (AE), or a Convolutional Neural Network (CNN) pretrained with the Cross-Entropy (CE) loss.
3. Differentiable ranking + kNN: this includes NeuralSort (Grover et al., 2019) and Cuturi et al. (2019). Cuturi et al. (2019) is not directly applicable and requires adaptations (see Appendix C).
4. Stochastic kNN with the Gumbel top-k relaxation (Xie and Ermon, 2019): the model is referred to as RelaxSubSample.
5. Softmax augmentation for a smoothed top-k operation: a combination of k softmax operations is used to replace the top-k operator. Specifically, we recursively perform softmax on X for k times (a similar idea appears in Plötz and Roth (2018)); at the k-th iteration, we mask the top-(k-1) entries with negative infinity.
6. CNNs trained with CE without any top-k component (our implementation is based on github.com/pytorch/vision.git).

For the pretrained CNN and the CNN trained with CE, we adopt neural networks identical to the one used in our method.

Results. We report the classification accuracies on the standard test sets in Table 1. On both datasets, the SOFT kNN classifier achieves comparable or better accuracies.
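For illustration, here is a hedged sketch of the per-query training loss described above, written against the `soft_top_k` sketch from Section 3; the function name, the batch handling, and the exact normalization are our own choices rather than the paper's released code:

```python
import torch
import torch.nn.functional as F

def soft_knn_loss(features, labels, j, k, num_classes, eps=1e-1):
    """Per-query loss sketch for the kNN classifier of Section 4.
    features: N x d tensor of f_theta(Z_i); labels: N long tensor; j: query index."""
    n = features.shape[0]
    mask = torch.arange(n, device=features.device) != j     # template samples Z_{\j}
    # Pairwise Euclidean distances between the query and the templates (X_{\j,theta}).
    dists = torch.norm(features[mask] - features[j], dim=1)
    # Smoothed indicator of the k nearest templates (soft_top_k from the sketch
    # above; the smallest-k convention matches distance scores).
    A = soft_top_k(dists, k=k, eps=eps)
    Y_templates = F.one_hot(labels[mask], num_classes).float()   # (N-1) x M
    Y_query = F.one_hot(labels[j], num_classes).float()          # M
    # Negative smoothed count of the selected neighbors sharing the query label.
    return -(Y_query * (Y_templates.t() @ A)).sum()
```

In practice, as noted above, one would evaluate this loss on a randomly sampled batch of queries and templates rather than on all N samples at once.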
5 Beam Search for Machine Translation

Beam search is a popular method for the inference of Neural Language Generation (NLG) models, e.g., machine translation models. Here, we propose to incorporate beam search into the training procedure based on the SOFT top-k operator.

5.1 Misalignment between Training and Inference

Denote the predicted sequence as $y = [y^{(1)}, \dots, y^{(T)}]$ and the vocabulary as $\{z_1, \dots, z_V\}$. Consider a recurrent network-based NLG model. The output of the model at the $t$-th decoding step is a probability simplex $[P(y^{(t)} = z_i \mid h^{(t)})]_{i=1}^{V}$, where $h^{(t)}$ is the hidden state associated with the sequence $y^{(1:t)} = [y^{(1)}, \dots, y^{(t)}]$.

Beam search recursively keeps the $k$ sequences with the largest likelihoods and discards the rest. Specifically, at the $(t+1)$-th decoding step, we have $k$ sequences $\tilde y^{(1:t),i}$ obtained at the $t$-th step, where $i = 1, \dots, k$ indexes the sequences. The likelihood of $\tilde y^{(1:t),i}$ is denoted by $L_s(\tilde y^{(1:t),i})$. We then select the next $k$ sequences by varying $i = 1, \dots, k$ and $j = 1, \dots, V$:

$\{\tilde y^{(1:t+1),\lambda}\}_{\lambda=1}^{k} = \arg\text{top-}k_{[\tilde y^{(1:t),i}, z_j]}\; L_s([\tilde y^{(1:t),i}, z_j]),$

where $L_s([\tilde y^{(1:t),i}, z_j])$ is the likelihood of the sequence obtained by appending $z_j$ to $\tilde y^{(1:t),i}$, defined as

$L_s([\tilde y^{(1:t),i}, z_j]) = P(y^{(t+1)} = z_j \mid h^{(t+1),i}) \, L_s(\tilde y^{(1:t),i}),$   (7)

and $h^{(t+1),i}$ is the hidden state generated from $\tilde y^{(1:t),i}$. Note that the $z_j$'s and $\tilde y^{(1:t),i}$'s together yield $Vk$ choices. Here we abuse the notation: $\tilde y^{(1:t+1),\lambda}$ denotes the $\lambda$-th selected sequence at the $(t+1)$-th decoding step, and is not necessarily related to $\tilde y^{(1:t),i}$ at the $t$-th decoding step, even if $i = \lambda$. For $t = 1$, we set $\tilde y^{(1)} = z_s$ as the start token, $L_s(\tilde y^{(1)}) = 1$, and $h^{(1)} = h_e$ as the output of the encoder. We repeat the above procedure until the end token is selected or the pre-specified maximum length is reached. At last, we select the sequence $\tilde y^{(1:T),\lambda}$ with the largest likelihood as the prediction.

Moreover, the most popular training procedure for NLG models directly uses the so-called teacher forcing framework. As the ground truth of the target sequence (i.e., the gold sequence) $\bar y = [\bar y^{(1)}, \dots, \bar y^{(T)}]$ is provided at the training stage, we can directly maximize the likelihood

$L_{\mathrm{tf}} = \prod_{t=1}^{T} P\big(y^{(t)} = \bar y^{(t)} \mid h^{(t)}(\bar y^{(1:t-1)})\big).$   (8)

As can be seen, such a training framework only involves the gold sequence, and cannot take into consideration the uncertainty of the recursive exploration of beam search. Therefore, it yields a misalignment between model training and inference (Bengio et al., 2015), which is also referred to as exposure bias (Wiseman and Rush, 2016).

5.2 Differentiable Beam Search with Sorted SOFT Top-k

To mitigate the aforementioned misalignment, we propose to integrate beam search into the training procedure, where the top-k operator in the beam search algorithm is replaced with the sorted SOFT top-k operator proposed in Section 2.3. Specifically, at the $(t+1)$-th decoding step, we have $k$ sequences denoted by $E^{(1:t),i}$, where $i = 1, \dots, k$ indexes the sequences. Here $E^{(1:t),i}$ consists of a sequence of $D$-dimensional vectors, where $D$ is the embedding dimension. We are not using the tokens, and the reason will be explained later. Let $\tilde h^{(t),i}$ denote the hidden state generated from $E^{(1:t),i}$. We then consider

$X^{(t)} = \{-\log L_s([E^{(1:t),i}, w_j]) : j = 1, \dots, V, \; i = 1, \dots, k\},$

where $L_s(\cdot)$ is defined analogously to (7), and $w_j \in \mathbb{R}^D$ is the embedding of token $z_j$. Recall that $\epsilon$ is the smoothing parameter. We then apply the sorted SOFT top-k operator to $X^{(t)}$ to obtain $\{E^{(1:t+1),\lambda}\}_{\lambda=1}^{k}$, which correspond to the $k$ sequences with the largest likelihoods. More precisely, the sorted SOFT top-k operator yields an output tensor $A^{(t),\epsilon} \in \mathbb{R}^{Vk \times k}$, where $A^{(t),\epsilon}_{ji,\lambda}$ denotes the smoothed indicator of whether $[E^{(1:t),i}, w_j]$ has rank $\lambda$. We then obtain

$E^{(t+1),\lambda} = \sum_{i=1}^{k} \sum_{j=1}^{V} A^{(t),\epsilon}_{ji,\lambda} \, w_j, \qquad E^{(1:t+1),\lambda} = [E^{(1:t),r}, E^{(t+1),\lambda}],$   (9)

where $r$ denotes the index $i$ (for the $E^{(1:t),i}$'s) associated with the index $\lambda$ (for the $E^{(1:t+1),\lambda}$'s). This is why we use vector representations instead of tokens: it allows us to compute $E^{(t+1),\lambda}$ as a weighted sum of all the word embeddings $[w_j]_{j=1}^{V}$, instead of discarding the unselected words.
Accordingly, we generate the $k$ hidden states for the $(t+1)$-th decoding step:

$\tilde h^{(t+1),\lambda} = \sum_{i=1}^{k} \sum_{j=1}^{V} A^{(t),\epsilon}_{ji,\lambda} \, h^{(t),i},$   (10)

where $h^{(t),i}$ is the hidden state generated by the decoder based on $E^{(1:t),i}$. After decoding, we select the sequence $E^{(1:T),\lambda}$ with the largest likelihood and maximize the likelihood

$L_{\mathrm{SOFT}} = \prod_{t=1}^{T} P\big(y^{(t)} = \bar y^{(t)} \mid \tilde h^{(t-1),\lambda}(E^{(1:t-1),\lambda})\big).$

We provide a sketch of the training procedure in Algorithm 2, where we denote $\mathrm{logit}^{(t),i} = [\log P(y^{(t+1)} = z_j \mid \tilde h^{(t),i}(E^{(1:t),i}))]_{j=1}^{V}$, which is part of the output of the decoder. More technical details (e.g., the backtracking algorithm for finding the index $r$ in (9)) are provided in Appendix C.

Note that integrating beam search into training essentially yields a very large search space for the model, which is not always affordable. To alleviate this issue, we further propose a hybrid approach that combines teacher forcing training with beam search-type training. Specifically, we maximize the weighted likelihood

$L_{\mathrm{final}} = \gamma L_{\mathrm{tf}} + (1 - \gamma) L_{\mathrm{SOFT}},$

where $\gamma \in (0, 1)$ is referred to as the teacher forcing ratio. The teacher forcing loss $L_{\mathrm{tf}}$ helps reduce the search space and improve the overall performance.

Algorithm 2 Beam search training with SOFT Top-k
Require: input sequence $s$, target sequence $\bar y$; embedding matrix $W \in \mathbb{R}^{V \times D}$; max length $T$; beam size $k$; regularization coefficient $\epsilon$; number of Sinkhorn iterations $L$
  $\tilde h^{(1),i} = h_e = \mathrm{Encoder}(s)$, $E^{(1),i} = w_s$
  for $t = 1, \dots, T-1$ do
    for $i = 1, \dots, k$ do
      $\mathrm{logit}^{(t),i}, h^{(t),i} = \mathrm{Decoder}(E^{(t),i}, \tilde h^{(t),i})$
      $\log L_s([E^{(1:t),i}, w_j]) = \log L_s(E^{(1:t),i}) + \mathrm{logit}^{(t),i}_j$
      $X^{(t)} = \{-\log L_s([E^{(1:t),i}, w_j]) \mid j = 1, \dots, V\}$
    end for
    $A^{(t),\epsilon} = \text{Sorted-SOFT-Top-}k(X^{(t)}, k, \epsilon, L)$
    Compute $E^{(t+1),\lambda}$, $\tilde h^{(t+1),\lambda}$ as in (9) and (10)
  end for
  Compute $\nabla L_{\mathrm{SOFT}}$ and update the model

5.3 Experiment

We evaluate the proposed beam search + sorted SOFT top-k training procedure on the WMT2014 English-French dataset. We adopt beam size 5, teacher forcing ratio $\gamma = 0.8$, and $\epsilon = 10^{-1}$. For detailed settings of the training procedure, please refer to Appendix C. We reproduce the experiment in Bahdanau et al. (2014), and run our proposed training procedure with the identical data pre-processing procedure and the same LSTM-based sequence-to-sequence model. Different from Bahdanau et al. (2014), we also preprocess the data with byte pair encoding (Sennrich et al., 2015).

Results. As shown in Table 2, the proposed SOFT beam search training procedure achieves an improvement in BLEU score of approximately 0.9. We also include other LSTM-based models for baseline comparison.

Ablation study. We replace the SOFT top-k operator with a vanilla top-k operator, i.e., we ignore the gradient of the top-k operation. The obtained BLEU score is 35.84, which suggests that both a) our SOFT top-k operator and b) incorporating beam search into training contribute to the improved performance.
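To make one decoding step of this scheme concrete, here is an illustrative PyTorch sketch (ours, not the paper's code). It assumes a `sorted_soft_top_k` routine that maps a vector of m scores to an m x k smoothed rank indicator as in Section 2.3 (smallest-score-first, so negative log-likelihoods are passed in); how the running likelihoods and the backtracked index r are handled exactly is described in Appendix C of the paper, so the last line is only one simple possibility:

```python
import torch

def soft_beam_step(log_probs, log_Ls, hidden, W, k, sorted_soft_top_k):
    """One decoding step of the beam-search training scheme of Section 5.2 (sketch).

    log_probs: k x V   log P(y^(t+1) = z_j | h^(t),i) from the decoder
    log_Ls:    k       running log-likelihoods log L_s(E^(1:t),i)
    hidden:    k x H   decoder hidden states h^(t),i
    W:         V x D   token embedding matrix [w_1, ..., w_V]
    """
    V = log_probs.shape[1]
    # Eq. (7) in log space: scores of all k*V candidate continuations.
    cand = log_Ls[:, None] + log_probs                     # k x V
    # Smoothed rank indicator over the k*V candidates (assumed interface).
    A = sorted_soft_top_k(-cand.reshape(-1), k)            # (k*V) x k
    A = A.reshape(k, V, k)                                 # A[i, j, lam]
    # Eq. (9): next input embeddings as weighted sums of the word embeddings w_j.
    E_next = torch.einsum('ijl,jd->ld', A, W)              # k x D
    # Eq. (10): next hidden states as weighted sums of the beam states h^(t),i.
    h_next = torch.einsum('ijl,ih->lh', A, hidden)         # k x H
    # One simple way to carry running log-likelihoods in this sketch; the paper
    # instead backtracks the hard index r of Eq. (9) (see Appendix C).
    logL_next = torch.einsum('ijl,ij->l', A, cand)         # k
    return E_next, h_next, logL_next
```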
6 Related Work

We parameterize the top-k operator as an optimal transport problem, which shares the same spirit as Cuturi et al. (2019). Specifically, Cuturi et al. (2019) formulate the ranking and sorting problems as OT problems. Ranking is more complicated than identifying the top-k elements, since one needs to align the different ranks with the corresponding elements. Therefore, the algorithmic complexity per iteration for ranking all $n$ elements is $O(n^2)$. Cuturi et al. (2019) also propose an OT problem for finding a given quantile among $n$ elements, for which the complexity per iteration reduces to $O(n)$. The top-k operator essentially finds all the elements more extreme than the $(n-k)/n$-quantile, and our proposed algorithm achieves the same $O(n)$ complexity per iteration. The difference is that the top-k operator returns the top-k elements of a given input set, while finding a quantile only yields a certain threshold value.

Table 2: BLEU on WMT'14 with a single LSTM.

| Algorithm | BLEU |
| --- | --- |
| Luong et al. (2014) | 33.10 |
| Durrani et al. (2014) | 30.82 |
| Cho et al. (2014) | 34.54 |
| Sutskever et al. (2014) | 30.59 |
| Bahdanau et al. (2014) | 28.45 |
| Jean et al. (2014) | 34.60 |
| Bahdanau et al. (2014) (our implementation) | 35.38 |
| Beam Search + Sorted SOFT Top-k | 36.27 |

The Gumbel-Softmax trick (Jang et al., 2016) can also be utilized to derive a continuous relaxation of the top-k operator. Specifically, Kool et al. (2019) adapted this trick to sample k elements from n choices, and Xie and Ermon (2019) further applied it to stochastic kNN, where neural networks are used to approximate the sorting operator. However, as shown in our experiments (see Table 1), the performance of stochastic kNN is not as good as that of deterministic kNN.

Our SOFT beam search training procedure is inspired by several works that incorporate some of the characteristics of beam search into the training procedure (Wiseman and Rush, 2016; Goyal et al., 2018; Bengio et al., 2015). Specifically, Wiseman and Rush (2016) and Goyal et al. (2018) both address the exposure bias issue in beam search. Wiseman and Rush (2016) propose a new loss function in terms of the error made during beam search, which mitigates the misalignment between training and testing in beam search. Later, Goyal et al. (2018) approximate the top-k operator using k softmax operations (this method is described and compared with our proposed method in Section 4). Such an approximation allows end-to-end training of beam search. Besides, our proposed training loss $L_{\mathrm{final}}$ is inspired by Bengio et al. (2015), which combines the teacher forcing training procedure with greedy decoding, i.e., beam search with beam size 1.

7 Discussion

Relation to automatic differentiation. We compute the Jacobian matrix of the SOFT top-k operator directly in the backward pass. The OT plan can be obtained by the Sinkhorn algorithm (Algorithm 1), which is iterative, and each iteration only involves multiplication and addition. Therefore, we could also apply automatic differentiation (auto-diff) to compute the Jacobian matrix. Specifically, denote $\Gamma^{(t)}$ as the transport plan at the $t$-th iteration of the Sinkhorn algorithm. The update of $\Gamma^{(t)}$ can be written as $\Gamma^{(t+1)} = \mathcal{T}(\Gamma^{(t)})$, where $\mathcal{T}$ denotes one update of the Sinkhorn algorithm. In order to apply auto-diff, we need to store all the intermediate states, e.g., $p$, $q$, and $G$ as defined in Algorithm 1, at each iteration. This requires memory proportional to the number of iterations of the algorithm. In contrast, our backward pass allows us to save this memory.

Figure 5: Visualization of MNIST data based on features extracted by the neural network-based kNN classifier trained by our proposed method in Section 4.

Bias and regularization trade-off. Theorem 2 suggests a trade-off between the regularization and the bias of the SOFT top-k operator. Specifically, a large $\epsilon$ has a strong smoothing effect on the entropic OT problem, and the corresponding entries of the Jacobian matrix are neither too large nor too small. This eases the end-to-end training process. However, the bias of the SOFT top-k operator is then large, which can deteriorate the model performance. On the contrary, a smaller $\epsilon$ ensures a smaller bias.
Yet the SOFT top-k operator is then less smooth, which in turn makes the end-to-end training less efficient. On the other hand, the bias of the SOFT top-k operator also depends on the gap between $x_{\sigma_{k+1}}$ and $x_{\sigma_k}$. In fact, such a gap can be viewed as the signal strength of the problem. A large gap implies that the top-k elements are clearly distinguished from the rest of the elements. Therefore, the bias is expected to be small, since the problem is relatively easy. Moreover, in real applications such as neural network-based kNN classification, the end-to-end training process encourages the neural network to extract features that exhibit a large gap (as illustrated in Figure 5). Hence, the bias of the SOFT top-k operator can be well controlled in practice.

8 Broader Impact

This paper makes a significant contribution to extending the frontier of end-to-end training of compositional models. To the best of our knowledge, our method is the first work targeting efficient end-to-end training with the top-k operation. We remark that our proposed SOFT top-k operator can be integrated into many existing machine learning methods, and has great potential to become a standard routine in various applications such as computer vision, natural language processing, healthcare, and computational social science.

Acknowledgement

We thank Marco Cuturi and Jean-Philippe Vert, who provided insight and expertise that greatly assisted the research. We are also grateful to Kihyuk Sohn for comments that greatly improved an earlier version of the manuscript. We thank the anonymous reviewers for their careful reading of our manuscript and their many insightful comments and suggestions.

References

AMOS, B. and KOLTER, J. Z. (2017). OptNet: Differentiable optimization as a layer in neural networks. In Proceedings of the 34th International Conference on Machine Learning, Volume 70. JMLR.org.

BABENKO, A., SLESAREV, A., CHIGORIN, A. and LEMPITSKY, V. (2014). Neural codes for image retrieval. In European Conference on Computer Vision. Springer.

BAHDANAU, D., CHO, K. and BENGIO, Y. (2014). Neural machine translation by jointly learning to align and translate. arXiv preprint arXiv:1409.0473.

BENAMOU, J.-D., CARLIER, G., CUTURI, M., NENNA, L. and PEYRÉ, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37, A1111-A1138.

BENGIO, S., VINYALS, O., JAITLY, N. and SHAZEER, N. (2015). Scheduled sampling for sequence prediction with recurrent neural networks. In Advances in Neural Information Processing Systems.

CHO, K., VAN MERRIËNBOER, B., GULCEHRE, C., BAHDANAU, D., BOUGARES, F., SCHWENK, H. and BENGIO, Y. (2014). Learning phrase representations using RNN encoder-decoder for statistical machine translation. arXiv preprint arXiv:1406.1078.

CUTURI, M. (2013). Sinkhorn distances: Lightspeed computation of optimal transport. In Advances in Neural Information Processing Systems.

CUTURI, M., TEBOUL, O. and VERT, J.-P. (2019). Differentiable ranking and sorting using optimal transport. In Advances in Neural Information Processing Systems.

DUCHI, J., SHALEV-SHWARTZ, S., SINGER, Y. and CHANDRA, T. (2008). Efficient projections onto the l1-ball for learning in high dimensions. In Proceedings of the 25th International Conference on Machine Learning.

DURRANI, N., HADDOW, B., KOEHN, P. and HEAFIELD, K. (2014). Edinburgh's phrase-based machine translation systems for WMT-14. In Proceedings of the Ninth Workshop on Statistical Machine Translation.
GORDO, A., ALMAZÁN, J., REVAUD, J. and LARLUS, D. (2016). Deep image retrieval: Learning global representations for image search. In European Conference on Computer Vision. Springer.

GOYAL, K., NEUBIG, G., DYER, C. and BERG-KIRKPATRICK, T. (2018). A continuous relaxation of beam search for end-to-end training of neural sequence models. In Thirty-Second AAAI Conference on Artificial Intelligence.

GRIEWANK, A. and WALTHER, A. (2008). Evaluating Derivatives: Principles and Techniques of Algorithmic Differentiation, vol. 105. SIAM.

GROVER, A., WANG, E., ZWEIG, A. and ERMON, S. (2019). Stochastic optimization of sorting networks via continuous relaxations. arXiv preprint arXiv:1903.08850.

HASTIE, T., TIBSHIRANI, R. and FRIEDMAN, J. (2009). The Elements of Statistical Learning: Data Mining, Inference, and Prediction. Springer Science & Business Media.

HE, K., ZHANG, X., REN, S. and SUN, J. (2016). Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition.

HOARE, C. A. (1961). Algorithm 65: Find. Communications of the ACM, 4, 321-322.

JANG, E., GU, S. and POOLE, B. (2016). Categorical reparameterization with Gumbel-Softmax. arXiv preprint arXiv:1611.01144.

JEAN, S., CHO, K., MEMISEVIC, R. and BENGIO, Y. (2014). On using very large target vocabulary for neural machine translation. arXiv preprint arXiv:1412.2007.

KANTOROVICH, L. V. (1960). Mathematical methods of organizing and planning production. Management Science, 6, 366-422.

KLEIN, G., KIM, Y., DENG, Y., SENELLART, J. and RUSH, A. M. (2017). OpenNMT: Open-source toolkit for neural machine translation. In Proc. ACL. https://doi.org/10.18653/v1/P17-4012

KOOL, W., VAN HOOF, H. and WELLING, M. (2019). Stochastic beams and where to find them: The Gumbel-top-k trick for sampling sequences without replacement. arXiv preprint arXiv:1903.06059.

KRIZHEVSKY, A., HINTON, G. ET AL. (2009). Learning multiple layers of features from tiny images.

LECUN, Y., BOTTOU, L., BENGIO, Y. and HAFFNER, P. (1998). Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86, 2278-2324.

LUISE, G., RUDI, A., PONTIL, M. and CILIBERTO, C. (2018). Differential properties of Sinkhorn approximation for learning with Wasserstein distance. In Advances in Neural Information Processing Systems.

LUONG, M.-T., SUTSKEVER, I., LE, Q. V., VINYALS, O. and ZAREMBA, W. (2014). Addressing the rare word problem in neural machine translation. arXiv preprint arXiv:1410.8206.

MONGE, G. (1781). Mémoire sur la théorie des déblais et des remblais. Histoire de l'Académie Royale des Sciences de Paris.

PAPERNOT, N. and MCDANIEL, P. (2018). Deep k-nearest neighbors: Towards confident, interpretable and robust deep learning. arXiv preprint arXiv:1803.04765.

PASZKE, A., GROSS, S., CHINTALA, S., CHANAN, G., YANG, E., DEVITO, Z., LIN, Z., DESMAISON, A., ANTIGA, L. and LERER, A. (2017). Automatic differentiation in PyTorch.

PLÖTZ, T. and ROTH, S. (2018). Neural nearest neighbors networks. In Advances in Neural Information Processing Systems.

RADENOVIĆ, F., TOLIAS, G. and CHUM, O. (2016). CNN image retrieval learns from BoW: Unsupervised fine-tuning with hard examples. In European Conference on Computer Vision. Springer.

REDDY, D. R. ET AL. (1977). Speech understanding systems: A summary of results of the five-year research effort. Department of Computer Science.

SCHLEMPER, J., OKTAY, O., SCHAAP, M., HEINRICH, M., KAINZ, B., GLOCKER, B. and RUECKERT, D. (2019). Attention gated networks: Learning to leverage salient regions in medical images.
Medical Image Analysis, 53, 197-207.

SENNRICH, R., HADDOW, B. and BIRCH, A. (2015). Neural machine translation of rare words with subword units. arXiv preprint arXiv:1508.07909.

SHANKAR, S., GARG, S. and SARAWAGI, S. (2018). Surprisingly easy hard-attention for sequence to sequence learning. In Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing.

SINKHORN, R. and KNOPP, P. (1967). Concerning nonnegative matrices and doubly stochastic matrices. Pacific Journal of Mathematics, 21, 343-348.

SUTSKEVER, I., VINYALS, O. and LE, Q. V. (2014). Sequence to sequence learning with neural networks. In Advances in Neural Information Processing Systems.

WISEMAN, S. and RUSH, A. M. (2016). Sequence-to-sequence learning as beam-search optimization. arXiv preprint arXiv:1606.02960.

XIE, S. M. and ERMON, S. (2019). Reparameterizable subset sampling via continuous relaxations. In International Joint Conference on Artificial Intelligence.

ZHU, C., TAN, X., ZHOU, F., LIU, X., YUE, K., DING, E. and MA, Y. (2018). Fine-grained video categorization with redundancy reduction attention. In Proceedings of the European Conference on Computer Vision (ECCV).