# Loss Functions for Multiset Prediction

Sean Welleck¹˒², Zixin Yao¹, Yu Gai¹, Jialin Mao¹, Zheng Zhang¹, Kyunghyun Cho²˒³

¹New York University Shanghai  ²New York University  ³CIFAR Azrieli Global Scholar

{wellecks,zixin.yao,yg1246,jialin.mao,zz,kyunghyun.cho}@nyu.edu

32nd Conference on Neural Information Processing Systems (NeurIPS 2018), Montréal, Canada.

## Abstract

We study the problem of multiset prediction. The goal of multiset prediction is to train a predictor that maps an input to a multiset consisting of multiple items. Unlike existing problems in supervised learning, such as classification, ranking, and sequence generation, there is no known order among items in a target multiset, and each item in the multiset may appear more than once, making this problem extremely challenging. In this paper, we propose a novel multiset loss function by viewing this problem from the perspective of sequential decision making. The proposed multiset loss function is empirically evaluated on two families of datasets, one synthetic and the other real, with varying levels of difficulty, against various baseline loss functions including reinforcement learning, sequence, and aggregated distribution matching loss functions. The experiments reveal the effectiveness of the proposed loss function over the others.

## 1 Introduction

A relatively less studied problem in machine learning, particularly supervised learning, is the problem of multiset prediction. The goal of this problem is to learn a mapping from an arbitrary input to a multiset¹ of items. This problem appears in a variety of contexts. For instance, in the context of high-energy physics, one of the important problems in a particle physics data analysis is to count how many physics objects, such as electrons, muons, photons, taus, and jets, are in a collision event [5]. In computer vision, object counting and automatic alt-text generation can be framed as multiset prediction [25, 12].

¹A set that allows multiple instances, e.g. {x, y, x}. See Appendix A for a detailed definition.

In multiset prediction, a learner is presented with an arbitrary input and the associated multiset of items. It is assumed that there is no predefined order among the items, and that there are no further annotations containing information about the relationship between the input and each of the items in the multiset. These properties make the problem of multiset prediction distinct from other well-studied problems. It is different from sequence prediction, because there is no known order among the items. It is not a ranking problem, since each item may appear more than once. It cannot be transformed into classification, because the number of possible multisets grows exponentially with respect to the maximum multiset size.

In this paper, we view multiset prediction as a sequential decision making process. Under this view, the problem reduces to finding a policy that sequentially predicts one item at a time, while the outcome is still evaluated based on the aggregate multiset of the predicted items. We first propose an oracle policy that assigns non-zero probabilities only to prediction sequences that result exactly in the target, ground-truth multiset given an input. This oracle is optimal in the sense that its prediction never decreases the precision and recall regardless of previous predictions. That is, its decision is optimal in any state (i.e., prediction prefix). We then propose a novel multiset loss which minimizes the KL divergence between the oracle policy and a parametrized policy at every point in a decision trajectory of the parametrized policy.
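To make this per-step construction concrete, below is a minimal sketch, not the paper's reference implementation: the oracle is taken to be uniform over the distinct classes still remaining in the target (the "free" labels), each step contributes the KL divergence from that oracle distribution to the model's predictive distribution, and the trajectory follows the model's own greedy predictions. The names (`multiset_loss`, `oracle_distribution`, `step_logits`) and the greedy roll-in are illustrative assumptions.

```python
# Minimal sketch (PyTorch), assuming a uniform oracle over remaining free labels
# and a greedy roll-in by the model; names here are illustrative, not the paper's.
from collections import Counter

import torch
import torch.nn.functional as F


def oracle_distribution(free_labels: Counter, num_classes: int) -> torch.Tensor:
    """Uniform distribution over the distinct classes still in the free label multiset."""
    p = torch.zeros(num_classes)
    support = list(free_labels.keys())
    p[support] = 1.0 / len(support)
    return p


def multiset_loss(step_logits, target: Counter, num_classes: int) -> torch.Tensor:
    """Sum over steps of KL(oracle || model), up to the constant oracle entropy."""
    free = Counter(target)                     # free labels: target items not yet predicted
    loss = torch.zeros(())
    for logits in step_logits:                 # one logit vector per prediction step
        p_star = oracle_distribution(free, num_classes)
        log_q = F.log_softmax(logits, dim=-1)
        loss = loss - (p_star * log_q).sum()   # cross-entropy term of KL(p* || q)
        pred = int(log_q.argmax())             # greedy roll-in: follow the model's prediction
        if free[pred] > 0:
            free = free - Counter([pred])      # multiset difference drops exhausted classes
        if not free:
            break
    return loss


# Example: target multiset {3, 3, 7} over 10 classes, with random per-step logits.
target = Counter([3, 3, 7])
step_logits = [torch.randn(10) for _ in range(sum(target.values()))]
print(multiset_loss(step_logits, target, num_classes=10))
```

Only the cross-entropy part of the KL is accumulated, since the oracle entropy term does not depend on the model parameters.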
We compare the proposed multiset loss against an extensive set of baselines. They include a sequential loss with an arbitrary rank function, a sequential loss with an input-dependent rank function, and an aggregated distribution matching loss and its one-step variant. We also test policy gradient, as was recently done in [25] for multiset prediction. Our evaluation is conducted on two sets of datasets with varying difficulties and properties. According to the experiments, we find that the proposed multiset loss outperforms all the other loss functions.

## 2 Multiset Prediction

A multiset prediction problem is a generalization of classification, where a target is not a single class but a multiset of classes. The goal is to find a mapping from an input $x$ to a multiset $\mathcal{Y} = \{y_1, \ldots, y_{|\mathcal{Y}|}\}$, where $y_k \in \mathcal{C}$. Some of the core properties of multiset prediction are: (1) the input $x$ is an arbitrary vector, (2) there is no predefined order among the items $y_i$ in the target multiset $\mathcal{Y}$, (3) the size of $\mathcal{Y}$ may vary depending on the input $x$, and (4) each item in the class set $\mathcal{C}$ may appear more than once in $\mathcal{Y}$. Formally, $\mathcal{Y}$ is a multiset $\mathcal{Y} = (\mu, \mathcal{C})$, where $\mu : \mathcal{C} \to \mathbb{N}$ gives the number of occurrences of each class $c \in \mathcal{C}$ in the multiset. See Appendix A for a further review of multisets.

As is typical in supervised learning, in multiset prediction a model $f_\theta(x)$ is trained on a dataset $\{(x_i, \mathcal{Y}_i)\}_{i=1}^{N}$, then evaluated on a separate test set $\{(x_i, \mathcal{Y}_i)\}_{i=1}^{n}$ using evaluation metrics $m(\cdot, \cdot)$ that compare the predicted and target multisets, i.e. $\frac{1}{n} \sum_{i=1}^{n} m(\hat{\mathcal{Y}}_i, \mathcal{Y}_i)$, where $\hat{\mathcal{Y}}_i = f_\theta(x_i)$ denotes a predicted multiset. For evaluation metrics we use exact match $\text{EM}(\hat{\mathcal{Y}}, \mathcal{Y}) = \mathbb{I}[\hat{\mathcal{Y}} = \mathcal{Y}]$, and the F1 score. Refer to Appendix A for multiset definitions of exact match and F1.

## 3 Related Problems in Supervised Learning

Variants of multiset prediction have been studied earlier. We now discuss a taxonomy of approaches in order to differentiate our proposal from previous work and define strong baselines.

### 3.1 Set Prediction

**Ranking** A ranking problem can be considered as learning a mapping from a pair of an input $x$ and one of the items $c \in \mathcal{C}$ to its score $s(x, c)$. All the items in the class set are then sorted according to the score, and this sorted order determines the rank of each item. Taking the top-$K$ items from this sorted list results in a predicted set (e.g. [6]). Similarly to multiset prediction, the input $x$ is arbitrary, and the target is a set without any prespecified order. However, ranking differs from multiset prediction in that it is unable to handle multiple occurrences of a single item in the target set.

**Multi-label Classification via Binary Classification** Multi-label classification consists of learning a mapping from an input $x$ to a subset of classes identified as $y \in \{0, 1\}^{|\mathcal{C}|}$. This problem can be reduced to $|\mathcal{C}|$ binary classification problems by learning a binary classifier for each possible class. Representative approaches include binary relevance, which assumes classes are conditionally independent, and probabilistic classifier chains, which decompose the joint probability as $p(y \mid x) = \prod_{c=1}^{|\mathcal{C}|} p(y_c \mid y_{<c}, x)$ [3, 17].

The entropy of the oracle policy is non-increasing along a decision trajectory, i.e. $H(\pi_*^{(t)}) \geq H(\pi_*^{(t+1)})$, where $H(\pi_*^{(t)})$ denotes the Shannon entropy of the oracle policy at time $t$. This naturally follows from the fact that there is no pre-specified rank function, because the oracle policy cannot prefer any item over the others in a free label multiset.
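A small numeric illustration of this monotonicity, using hypothetical helper names rather than code from the paper: for a uniform oracle, the per-step entropy equals the log of the number of distinct classes remaining in the free label multiset, which can never increase as items are predicted and removed.

```python
# Small illustration (hypothetical helper), assuming a uniform oracle over the
# distinct classes remaining in the free label multiset at each step.
import math
from collections import Counter


def oracle_entropies(target_multiset):
    """Per-step Shannon entropy of the uniform oracle as free labels are predicted."""
    free = Counter(target_multiset)
    entropies = []
    while free:
        entropies.append(math.log(len(free)))  # log(# distinct remaining classes)
        label = next(iter(free))               # predict any remaining free label
        free = free - Counter([label])         # multiset difference drops exhausted classes
    return entropies


# For target {1, 1, 2, 3, 3, 3}: entropies are non-increasing, e.g. [log 3, log 3, log 2, 0, 0, 0].
print(oracle_entropies([1, 1, 2, 3, 3, 3]))
```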
Hence, we examine here how the policy learned with each loss function compares to the oracle policy in terms of per-step entropy. We consider the policies trained on MNIST Multi (10), where the differences among them were most clear.

As shown in Fig. 1, the policy trained on MNIST Multi (10) using the proposed multiset loss closely follows the oracle policy. Its entropy decreases as predictions are made, which can be interpreted as concentrating probability mass on progressively smaller free label multisets. The variance is quite small, indicating that this strategy is applied uniformly across inputs.

The policy trained with reinforcement learning retains a relatively low entropy across steps, with a decreasing trend in the second half. We suspect the low entropy in the earlier steps is due to the greedy nature of policy gradient: the policy receives a high reward more easily by choosing one of many possible choices in an earlier step than in a later step. This effectively discourages the policy from exploring all possible trajectories during training.

On the other hand, the policy found by aggregated distribution matching ($\mathcal{L}^{\text{KL}}_{\text{dm}}$) has the opposite behaviour: its entropy generally grows as more predictions are made. To see why this is sub-optimal, consider the final step. Assuming the first nine predictions were correct, there is only one correct class left for the final prediction. The high entropy, however, indicates that the model is placing a significant amount of probability on incorrect sequences. Such a policy may result because $\mathcal{L}^{\text{KL}}_{\text{dm}}$ cannot properly distinguish between policies with increasing and decreasing entropies. The increasing entropy also indicates that the policy has implicitly learned a rank function and is fully relying on it. We conjecture that this reliance on an inferred rank function, which is by definition sub-optimal, resulted in the lower performance of aggregated distribution matching.

## 6 Conclusion

We have extensively investigated the problem of multiset prediction in this paper. We rigorously defined the problem, and proposed to approach it from the perspective of sequential decision making. In doing so, an oracle policy was defined and shown to be optimal, and a new loss function, called multiset loss, was introduced as a means to train a parametrized policy for multiset prediction. The experiments on two families of datasets, MNIST Multi variants and MS COCO variants, have revealed the effectiveness of the proposed loss function over other loss functions, including reinforcement learning, sequence, and aggregated distribution matching loss functions. This success brings new opportunities for applying machine learning to various new domains, including high-energy physics.

## Acknowledgments

KC thanks support by eBay, Tencent, NVIDIA and CIFAR. This work was supported by Samsung Electronics (Improving Deep Learning using Latent Structure) and 17JC1404101 STCSM.

## References

[1] Kai-Wei Chang, Akshay Krishnamurthy, Alekh Agarwal, Hal Daumé III, and John Langford. Learning to search better than your teacher. In Proceedings of the 32nd International Conference on Machine Learning - Volume 37, ICML '15, pages 2058–2066. JMLR.org, 2015.

[2] Hal Daumé, John Langford, and Daniel Marcu. Search-based structured prediction. Machine Learning, 75(3):297–325, June 2009.

[3] Krzysztof Dembczyński, Weiwei Cheng, and Eyke Hüllermeier. Bayes optimal multilabel classification via probabilistic classifier chains. In Proceedings of the 27th International Conference on Machine Learning, ICML '10, pages 279–286, USA, 2010. Omnipress.
[4] Jia Deng, Wei Dong, Richard Socher, Li-Jia Li, Kai Li, and Li Fei-Fei. ImageNet: A large-scale hierarchical image database. In Computer Vision and Pattern Recognition, 2009. CVPR 2009. IEEE Conference on, pages 248–255. IEEE, 2009.

[5] W Ehrenfeld, R Buckingham, J Cranshaw, T Cuhadar Donszelmann, T Doherty, E Gallas, J Hrivnac, D Malon, M Nowak, M Slater, F Viegas, E Vinek, Q Zhang, and the ATLAS Collaboration. Using tags to speed up the ATLAS analysis process. Journal of Physics: Conference Series, 331(3):032007, 2011.

[6] Yunchao Gong, Yangqing Jia, Thomas Leung, Alexander Toshev, and Sergey Ioffe. Deep convolutional ranking for multilabel image annotation. arXiv preprint arXiv:1312.4894, 2013.

[7] S. Hamid Rezatofighi, Vijay Kumar B G, Anton Milan, Ehsan Abbasnejad, Anthony Dick, and Ian Reid. DeepSetNet: Predicting sets with deep neural networks. In The IEEE International Conference on Computer Vision (ICCV), October 2017.

[8] Kaiming He, Georgia Gkioxari, Piotr Dollár, and Ross Girshick. Mask R-CNN. arXiv preprint arXiv:1703.06870, 2017.

[9] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 770–778, 2016.

[10] Rémi Leblond, Jean-Baptiste Alayrac, Anton Osokin, and Simon Lacoste-Julien. SEARNN: Training RNNs with global-local losses, 2017.

[11] Yann LeCun, Léon Bottou, Yoshua Bengio, and Patrick Haffner. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.

[12] V. Lempitsky and A. Zisserman. Learning to count objects in images. In Advances in Neural Information Processing Systems, 2010.

[13] Tsung-Yi Lin, Michael Maire, Serge Belongie, James Hays, Pietro Perona, Deva Ramanan, Piotr Dollár, and C. Lawrence Zitnick. Microsoft COCO: Common objects in context. In European Conference on Computer Vision, pages 740–755. Springer, 2014.

[14] Jinseok Nam, Eneldo Loza Mencía, Hyunwoo J. Kim, and Johannes Fürnkranz. Maximizing subset accuracy with recurrent neural networks in multi-label classification. In I. Guyon, U. V. Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan, and R. Garnett, editors, Advances in Neural Information Processing Systems 30, pages 5419–5429. Curran Associates, Inc., 2017.

[15] Daniel Oñoro-Rubio and Roberto J. López-Sastre. Towards perspective-free object counting with deep learning. In Bastian Leibe, Jiri Matas, Nicu Sebe, and Max Welling, editors, Computer Vision – ECCV 2016, pages 615–629, Cham, 2016. Springer International Publishing.

[16] J. Peters and S. Schaal. Reinforcement learning of motor skills with policy gradients. Neural Networks, 21(4):682–697, May 2008.

[17] Jesse Read, Bernhard Pfahringer, Geoff Holmes, and Eibe Frank. Classifier chains for multilabel classification. Machine Learning, 85(3):333, June 2011.

[18] Mengye Ren and Richard S. Zemel. End-to-end instance segmentation with recurrent attention. In The IEEE Conference on Computer Vision and Pattern Recognition (CVPR), July 2017.

[19] Bernardino Romera-Paredes and Philip H. S. Torr. Recurrent instance segmentation. 2015.

[20] Stéphane Ross, Geoffrey J. Gordon, and Drew Bagnell. A reduction of imitation learning and structured prediction to no-regret online learning. In International Conference on Artificial Intelligence and Statistics, pages 627–635, 2011.
[21] Russell Stewart, Mykhaylo Andriluka, and Andrew Y. Ng. End-to-end people detection in crowded scenes. In The IEEE Conference on Computer Vision and Pattern Recognition (CVPR), June 2016.

[22] Grigorios Tsoumakas and Ioannis Katakis. Multi-label classification: An overview. International Journal of Data Warehousing and Mining, 2007:1–13, 2007.

[23] Oriol Vinyals, Samy Bengio, and Manjunath Kudlur. Order matters: Sequence to sequence for sets, 2015.

[24] Jiang Wang, Yi Yang, Junhua Mao, Zhiheng Huang, Chang Huang, and Wei Xu. CNN-RNN: A unified framework for multi-label image classification. In The IEEE Conference on Computer Vision and Pattern Recognition (CVPR), June 2016.

[25] Sean Welleck, Kyunghyun Cho, and Zheng Zhang. Saliency-based sequential image attention with multiset prediction. In Advances in Neural Information Processing Systems, 2017.

[26] Ronald J. Williams. Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine Learning, 8(3-4):229–256, 1992.

[27] Xingjian Shi, Zhourong Chen, Hao Wang, Dit-Yan Yeung, Wai-Kin Wong, and Wang-chun Woo. Convolutional LSTM network: A machine learning approach for precipitation nowcasting. In Advances in Neural Information Processing Systems, pages 802–810, 2015.

[28] Y. Zhang, D. Zhou, S. Chen, S. Gao, and Y. Ma. Single-image crowd counting via multi-column convolutional neural network. In 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pages 589–597, June 2016.