Published as a conference paper at ICLR 2019

LEARNING REPRESENTATIONS OF SETS THROUGH OPTIMIZED PERMUTATIONS

Yan Zhang & Adam Prügel-Bennett & Jonathon Hare
Department of Electronics and Computer Science
University of Southampton
{yz5n12,apb,jsh2}@ecs.soton.ac.uk

ABSTRACT

Representations of sets are challenging to learn because operations on sets should be permutation-invariant. To this end, we propose a Permutation-Optimisation module that learns how to permute a set end-to-end. The permuted set can be further processed to learn a permutation-invariant representation of that set, avoiding a bottleneck in traditional set models. We demonstrate our model's ability to learn permutations and set representations with either explicit or implicit supervision on four datasets, on which we achieve state-of-the-art results: number sorting, image mosaics, classification from image mosaics, and visual question answering.

1 INTRODUCTION

Consider a task where each input sample is a set of feature vectors, with each feature vector describing an object in an image (for example: person, table, cat). Because there is no a priori ordering of these objects, it is important that the model is invariant to the order in which the elements appear in the set. However, this puts restrictions on what can be learned efficiently. The typical approach is to compose elementwise operations with permutation-invariant reduction operations, such as summing (Zaheer et al., 2017) or taking the maximum (Qi et al., 2017) over the whole set. Since the reduction operator compresses a set of any size down to a single descriptor, this can be a significant bottleneck in what information about the set can be represented efficiently (Qi et al., 2017; Le & Duan, 2018; Murphy et al., 2019).

We take an alternative approach based on an idea explored in Vinyals et al. (2015a), where they find that some permutations of sets allow for easier learning on a task than others. They do this by ordering the set elements in some predetermined way and feeding the resulting sequence into a recurrent neural network. For instance, if the task is to output the top-n numbers from a set of numbers, it is useful if the input is already sorted in descending order before being fed into an RNN. This approach leverages the representational capabilities of traditional sequential models such as LSTMs, but requires some prior knowledge of what order might be useful.

Our idea is to learn such a permutation purely from data without requiring a priori knowledge (section 2). The key aspect is to turn a set into a sequence in a way that is both permutation-invariant and differentiable, so that it is learnable. Our main contribution is a Permutation-Optimisation (PO) module that satisfies these requirements: it optimises a permutation in the forward pass of a neural network using pairwise comparisons. By feeding the resulting sequence into a traditional model such as an LSTM, we can learn a flexible, permutation-invariant representation of the set while avoiding the bottleneck that a simple reduction operator would introduce. Techniques used in our model may also be applicable to other set problems where permutation-invariance is desired, building on the literature of approaches to dealing with permutation-invariance (section 3).

In four different experiments, we show improvements over existing methods (section 4).
The former two tasks measure the ability to learn a particular permutation as target: number sorting and image mosaics. We achieve state-of-the-art performance with our model, which shows that our method is suitable for representing permutations in general. The latter two tasks test whether a model can learn to solve a task that requires it to come up with a suitable permutation implicitly: classification from image mosaics and visual question answering. We provide no supervision of what the permutation should be; the model has to learn by itself what permutation is most useful for the task at hand. Here, our model also beats the existing models, and we improve the performance of a state-of-the-art model in VQA with it. This shows that our PO module is able to learn good permutation-invariant representations of sets using our approach.

Figure 1: Overview of the Permutation-Optimisation module. In the ordering cost $C$, elements of $X$ are compared to each other (blue represents a negative value, red represents a positive value). Gradients are applied to unnormalised permutations $\widetilde{P}^{(t)}$, which are normalised to proper permutations $P^{(t)}$.

2 PERMUTATION-OPTIMISATION MODULE

We will now describe a differentiable, and thus learnable, model to turn an unordered set $\{x_i\}_{i=1}^N$ with feature vectors as elements into an ordered sequence of these feature vectors. An overview of the algorithm is shown in Figure 1 and pseudo-code is available in Appendix A.

The input set is represented as a matrix $X = [x_1, \ldots, x_N]^T$ with the feature vectors $x_i$ as rows in some arbitrary order. In the algorithm, it is important not to rely on this arbitrary order so that $X$ is correctly treated as a set. The goal is then to learn a permutation matrix $P$ such that, when permuting the rows of the input through $Y = PX$, the output is ordered correctly according to the task at hand. When an entry $P_{ik}$ takes the value 1, it can be understood as assigning the $i$th element to the $k$th position in the output.

Our main idea is to first relate pairs of elements through an ordering cost, parametrised with a neural network. This pairwise cost tells us whether an element $i$ should preferably be placed before or after element $j$ in the output sequence. Using this, we can define a total cost that measures how good a given permutation is (subsection 2.1). The second idea is to optimise this total cost in each forward pass of the module (subsection 2.2). By minimising the total cost of a permutation, we improve the quality of a permutation with respect to the current ordering costs. Crucially, the ordering cost function, and thus also the total cost function, is learned. In doing so, the module is able to learn how to generate a permutation as is desired.

In order for this to work, it is important that the optimisation process itself is differentiable so that the ordering cost is learnable. Because permutations are inherently discrete objects, a continuous relaxation of permutations is necessary. For optimisation, we perform gradient descent on the total cost for a fixed number of steps and unroll the iteration, similar to how recurrent neural networks are unrolled to perform backpropagation-through-time.
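The mechanism that keeps such an unrolled inner optimisation trainable is that each inner gradient is itself part of the computation graph. The toy sketch below is not the PO module itself: the quadratic inner cost and all names are illustrative assumptions. It only shows the pattern in PyTorch, where `create_graph=True` keeps the inner gradient differentiable so that an outer loss can backpropagate to learned parameters through all unrolled steps.

```python
# Toy illustration (assumed names, not the PO module itself) of unrolling inner
# gradient steps while keeping them differentiable with respect to learned
# parameters, analogous to backpropagation-through-time.
import torch

theta = torch.randn(5, requires_grad=True)   # learned parameters (stand-in for the ordering cost)
z = torch.zeros(5, requires_grad=True)       # quantity optimised in the forward pass

def inner_cost(z, theta):
    # Hypothetical differentiable inner cost; any differentiable cost works here.
    return ((z - theta) ** 2).sum()

lr, T = 0.1, 4
for _ in range(T):
    # create_graph=True keeps the inner gradient in the computation graph,
    # so the outer loss below can backpropagate through all T unrolled steps.
    (grad_z,) = torch.autograd.grad(inner_cost(z, theta), z, create_graph=True)
    z = z - lr * grad_z

outer_loss = (z ** 2).sum()      # stand-in for the supervised training loss
outer_loss.backward()            # gradients reach theta through the unrolled steps
print(theta.grad.norm())         # non-zero: theta is trainable through the inner loop
```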
Because the inner gradient (the total cost differentiated with respect to the permutation) is itself differentiable with respect to the ordering cost, the whole model is kept differentiable and we can train it with a standard supervised learning loss.

Note that as long as the ordering cost is computed appropriately (subsection 2.3), all operations used turn out to be permutation-invariant. Thus, we have a model that respects the symmetries of sets while producing an output without those symmetries: a sequence. This can be naturally extended to outputs where the target is not a sequence, but grids and lattices (subsection 2.4).

2.1 TOTAL COST FUNCTION

The total cost function measures the quality of a given permutation and should be lower for better permutations. Because this is the function that will be optimised, it is important to understand what it expresses precisely.

The main ingredient for the total cost of a permutation is the pairwise ordering cost (details in subsection 2.3). By computing it for all pairs, we obtain a cost matrix $C$ where the entry $C_{ij}$ represents the ordering cost between $i$ and $j$: the cost of placing element $i$ anywhere before $j$ in the output sequence. An important constraint that we put on $C$ is that $C_{ij} = -C_{ji}$. In other words, if one ordering of $i$ and $j$ is good (negative cost), then the opposite ordering obtained by swapping them is bad (positive cost). Additionally, this constraint means that $C_{ii} = 0$. This makes sure that two very similar feature vectors in the input will be similarly ordered in the output, because their pairwise cost goes to 0.

In this paper we use a straightforward definition of the total cost function: a sum of the ordering costs over all pairs of elements $i$ and $j$. When considering the pair $i$ and $j$, if the permutation maps $i$ to be before $j$ in the output sequence, this cost is simply $C_{ij}$. Vice versa, if the permutation maps $i$ to be after $j$ in the output sequence, the cost has to be flipped to $C_{ji}$. To express this idea, we define the total cost $c: \mathbb{R}^{N \times N} \mapsto \mathbb{R}$ of a permutation $P$ as:

$$c(P) = \sum_{i,j} C_{ij} \sum_{u} \sum_{v} P_{iu} P_{jv} \, \operatorname{sgn}(v - u) \qquad (1)$$

where $\operatorname{sgn}(v - u)$ is 1 when $v > u$ and $-1$ when $v < u$; permutation matrices are binary and only have one 1 in any row and column, so all other terms in the sums are 0. That means that the term for each $i$ and $j$ becomes $C_{ij}$ when $v > u$ and $-C_{ij} = C_{ji}$ when $v < u$, which matches what we described previously.

2.2 OPTIMISATION PROBLEM

Now that we can compute the total cost of a permutation, we want to optimise this cost with respect to a permutation. After including the constraints to enforce that $P$ is a valid permutation matrix, we obtain the following optimisation problem:

$$\underset{P}{\text{minimize}} \;\; c(P) \quad \text{subject to} \quad \forall i, k: \; P_{ik} \in \{0, 1\}, \quad \sum_k P_{ik} = 1, \quad \sum_k P_{ki} = 1 \qquad (2)$$

Optimisation over $P$ directly is difficult due to the discrete and combinatorial nature of permutations. To make optimisation feasible, a common relaxation is to replace the constraint that $P_{ik} \in \{0, 1\}$ with $P_{ik} \in [0, 1]$ (Fogel et al., 2013). With this change, the feasible set for $P$ expands to the set of doubly-stochastic matrices, known as the Birkhoff or assignment polytope. Rather than hard permutations, we now have soft assignments of elements to positions, analogous to the latent assignments when fitting a mixture of Gaussians model using Expectation-Maximisation. Note that we do not need to change our total cost function after this relaxation.
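To make this concrete, the sketch below evaluates Equation 1 for both hard and relaxed (doubly-stochastic) permutations, using the identity $\sum_{u,v} P_{iu} P_{jv} \operatorname{sgn}(v-u) = (P G P^T)_{ij}$ with $G_{uv} = \operatorname{sgn}(v-u)$. The small pairwise network and its antisymmetrisation are only illustrative assumptions; the paper's actual ordering cost is defined in subsection 2.3.

```python
# Sketch of the total cost in Equation 1, assuming an antisymmetric cost matrix C.
# The pairwise network below is an illustrative assumption, not the paper's model.
import torch
import torch.nn as nn

N, d = 6, 16
pairwise_net = nn.Sequential(nn.Linear(2 * d, 32), nn.ReLU(), nn.Linear(32, 1))

def ordering_cost_matrix(X):
    # X: (N, d) set elements as rows. Compare every ordered pair (i, j).
    xi = X.unsqueeze(1).expand(-1, X.size(0), -1)                 # (N, N, d), rows i
    xj = X.unsqueeze(0).expand(X.size(0), -1, -1)                 # (N, N, d), rows j
    raw = pairwise_net(torch.cat([xi, xj], dim=-1)).squeeze(-1)   # (N, N) raw comparisons
    return raw - raw.t()                                          # enforce C_ij = -C_ji, so C_ii = 0

def total_cost(C, P):
    # c(P) = sum_{i,j} C_ij * sum_{u,v} P_iu P_jv sgn(v - u)   (Equation 1)
    #      = <C, P G P^T>  with G[u, v] = sgn(v - u).
    n = P.size(0)
    pos = torch.arange(n, dtype=P.dtype)
    G = torch.sign(pos.unsqueeze(0) - pos.unsqueeze(1))           # G[u, v] = sgn(v - u)
    return (C * (P @ G @ P.t())).sum()

X = torch.randn(N, d)
C = ordering_cost_matrix(X)
P_uniform = torch.full((N, N), 1.0 / N)      # maximally uncertain doubly-stochastic assignment
print(total_cost(C, P_uniform))              # ~0: no element is placed before any other
print(total_cost(C, torch.eye(N)))           # cost of the identity permutation
```

Antisymmetrising by subtracting the transposed comparison is one simple way to guarantee $C_{ij} = -C_{ji}$ and $C_{ii} = 0$ for any pairwise network.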
Under the relaxation, instead of discretely flipping the sign of $C_{ij}$ depending on whether element $i$ comes before $j$ or not, the sums over $u$ and $v$ give us a weight for each $C_{ij}$ that is based on how strongly $i$ and $j$ are assigned to positions. This weight is positive when $i$ is on average assigned to earlier positions than $j$, and negative vice versa.

In order to perform optimisation of the cost under our constraints, we reparametrise $P$ with the Sinkhorn operator $S$ from Adams & Zemel (2011) (defined in Appendix B) so that the constraints are always satisfied. We found this to lead to better solutions than projected gradient descent in initial experiments. After first exponentiating all entries of a matrix, $S$ repeatedly normalises all rows, then all columns of the matrix to sum to 1, which converges to a doubly-stochastic matrix in the limit.

$$P = S(\widetilde{P}) \qquad (3)$$

This ensures that $P$ is always approximately a doubly-stochastic matrix. $\widetilde{P}$ can be thought of as the unnormalised permutation while $P$ is the normalised permutation. By changing our optimisation to be over $\widetilde{P}$ instead of $P$ directly, all constraints are always satisfied and we can simplify the optimisation problem to $\min_{\widetilde{P}} c(P)$ without any constraints. It is now straightforward to optimise $\widetilde{P}$ with standard gradient descent. First, we compute the gradient:
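In a sketch, this inner gradient can also be obtained directly by automatic differentiation. The code below continues from the previous sketch (reusing ordering_cost_matrix and total_cost); the number of Sinkhorn iterations, the step count and step size, and the zero initialisation of $\widetilde{P}$ are assumptions rather than the paper's settings.

```python
# Sketch of the reparametrised inner optimisation: Sinkhorn-normalise the
# unnormalised permutation P_tilde (Equation 3), take unrolled gradient steps on
# P_tilde to lower the total cost, and permute the set with the result.
import torch

def sinkhorn(P_tilde, n_iters=4):
    # Exponentiate, then repeatedly normalise rows and columns to sum to 1;
    # this converges towards a doubly-stochastic matrix.
    P = torch.exp(P_tilde)
    for _ in range(n_iters):
        P = P / P.sum(dim=1, keepdim=True)   # rows
        P = P / P.sum(dim=0, keepdim=True)   # columns
    return P

def optimise_permutation(X, T=6, lr=1.0):
    C = ordering_cost_matrix(X)              # learned pairwise ordering costs
    P_tilde = torch.zeros(X.size(0), X.size(0), requires_grad=True)
    for _ in range(T):
        cost = total_cost(C, sinkhorn(P_tilde))
        # create_graph=True keeps the unrolled steps differentiable, so the
        # ordering-cost network can be trained through this inner loop.
        (grad,) = torch.autograd.grad(cost, P_tilde, create_graph=True)
        P_tilde = P_tilde - lr * grad
    P = sinkhorn(P_tilde)
    return P, P @ X                          # soft permutation and ordered sequence Y = P X

P, Y = optimise_permutation(torch.randn(6, 16))   # N = 6 elements, d = 16 features
print(P.sum(dim=1))                               # each row sums to approximately 1
```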