# Learning Optimal Tree Models under Beam Search

Jingwei Zhuo, Ziru Xu, Wei Dai, Han Zhu, Han Li, Jian Xu, Kun Gai (Alibaba Group. Correspondence to: Jingwei Zhuo.)

*Proceedings of the 37th International Conference on Machine Learning, Online, PMLR 119, 2020. Copyright 2020 by the author(s).*

## Abstract

Retrieving relevant targets from an extremely large target set under computational limits is a common challenge for information retrieval and recommendation systems. Tree models, which formulate targets as leaves of a tree with trainable node-wise scorers, have attracted a lot of interest in tackling this challenge due to their logarithmic computational complexity in both training and testing. Tree-based deep models (TDMs) and probabilistic label trees (PLTs) are two representative kinds of them. Despite their many practical successes, existing tree models suffer from a training-testing discrepancy: the retrieval performance deterioration caused by beam search in testing is not considered in training. This leads to an intrinsic gap between the most relevant targets and those retrieved by beam search, even with optimally trained node-wise scorers. We take a first step towards understanding and analyzing this problem theoretically, and develop the concepts of Bayes optimality under beam search and calibration under beam search as general analyzing tools for this purpose. Moreover, to eliminate the discrepancy, we propose a novel algorithm for learning optimal tree models under beam search. Experiments on both synthetic and real data verify the rationality of our theoretical analysis and demonstrate the superiority of our algorithm compared to state-of-the-art methods.

## 1. Introduction

Extremely large-scale retrieval problems prevail in modern industrial applications of information retrieval and recommendation systems. For example, in online advertising systems, several advertisements need to be retrieved from a target set containing tens of millions of advertisements and presented to a user within tens of milliseconds. The limits of computational resources and response time make models whose computational complexity scales linearly with the size of the target set unacceptable in practice.

Tree models are of special interest for these problems because of their ability to achieve logarithmic complexity in both training and testing. Tree-based deep models (TDMs) (Zhu et al., 2018; 2019; You et al., 2019) and probabilistic label trees (PLTs) (Jasinska et al., 2016; Prabhu et al., 2018; Wydmuch et al., 2018) are two representative kinds of tree models. These models introduce a tree hierarchy in which each leaf node corresponds to a target and each non-leaf node defines a pseudo target for measuring the existence of relevant targets on the subtree rooted at it. Each node is also associated with a node-wise scorer which is trained to estimate the probability that the corresponding (pseudo) target is relevant. To achieve logarithmic training complexity, a subsampling method is leveraged to select a logarithmic number of nodes on which the scorers are trained for each training instance. In testing, beam search is usually used to retrieve relevant targets in logarithmic complexity. As a greedy method, beam search only expands the nodes with larger scores while pruning the others.
This strategy achieves logarithmic computational complexity, but it may deteriorate retrieval performance if ancestor nodes of the most relevant targets are pruned. An ideal tree model should guarantee no performance deterioration when its node-wise scorers are leveraged for beam search. However, existing tree models ignore this and treat training as a task separate from testing: (1) node-wise scorers are trained as probability estimators of pseudo targets which are not designed for optimal retrieval; (2) they are also trained on subsampled nodes which differ from those queried by beam search in testing. Because of this discrepancy, even node-wise scorers that are optimal with respect to the training loss can lead to suboptimal retrieval results when they are used in testing to retrieve relevant targets via beam search. To the best of our knowledge, there is little work discussing this problem either theoretically or experimentally.

We take a first step towards understanding and resolving the training-testing discrepancy on tree models. To analyze it formally, we develop the concepts of Bayes optimality under beam search and calibration under beam search as the optimality measure of tree models and of the corresponding training loss, respectively. Both serve as general analyzing tools for tree models. Based on these concepts, we show that neither TDMs nor PLTs are optimal, and derive a sufficient condition for the existence of optimal tree models as well. We also propose a novel algorithm for learning such an optimal tree model. Our algorithm consists of a beam search aware subsampling method and an optimal-retrieval based definition of pseudo targets, both of which resolve the training-testing discrepancy. Experiments on synthetic and real data not only verify the rationality of our newly proposed concepts in measuring the optimality of tree models, but also demonstrate the superiority of our algorithm compared to existing state-of-the-art methods.

## 2. Related Work

**Tree Models:** Research on tree models has mainly focused on formulating node-wise scorers and the tree structure. (There also exist models which build an ensemble of decision trees over instances instead of targets (Prabhu & Varma, 2014; Jain et al., 2016); they are less relevant to our main focus.) For node-wise scorers, linear models are widely adopted (Jasinska et al., 2016; Wydmuch et al., 2018; Prabhu et al., 2018), while deep models (Zhu et al., 2018; 2019; You et al., 2019) have become popular recently. For the tree structure, apart from the random tree (Jasinska et al., 2016), recent works propose to learn it either via hierarchical clustering over targets (Wydmuch et al., 2018; Prabhu et al., 2018; Khandagale et al., 2019) or under a joint optimization framework with node-wise scorers (Zhu et al., 2019). Without depending on specific formulations of node-wise scorers or the tree structure, our theoretical findings and proposed training algorithm are general and applicable to these advances.

**Bayes Optimality and Calibration:** Bayes optimality and calibration have been extensively investigated on flat models (Lapin et al., 2017; Menon et al., 2019; Yang & Koyejo, 2019), and they have also been used to measure the performance of tree models on hierarchical probability estimation (Wydmuch et al., 2018). However, there is a gap between the performance on hierarchical probability estimation and that on retrieving relevant targets, since the former ignores beam search and the corresponding performance deterioration. As a result, how to measure the retrieval performance of tree models formally remains an open question.
We fill this void by developing the concepts of Bayes optimality under beam search and calibration under beam search.

**Beam Search in Training:** Formulating beam search into training to resolve the training-testing discrepancy is not a new idea. It has been extensively investigated on structured prediction models for problems like machine translation and speech recognition (Daumé III & Marcu, 2005; Xu & Fern, 2007; Ross et al., 2011; Wiseman & Rush, 2016; Goyal et al., 2018; Negrinho et al., 2018). Though the performance deterioration caused by beam search has been analyzed empirically (Cohen & Beck, 2019), it still lacks a theoretical understanding. Besides, little effort has been made to understand and resolve the training-testing discrepancy on tree models. We take a first step towards studying these problems both theoretically and experimentally.

## 3. Preliminaries

### 3.1. Problem Definition

Suppose $\mathcal{I} = \{1, ..., M\}$ with $M \gg 1$ is a target set and $\mathcal{X}$ is an observation space. We denote an instance as $(\mathbf{x}, \mathcal{I}_\mathbf{x})$, implying that an observation $\mathbf{x} \in \mathcal{X}$ is associated with a subset of relevant targets $\mathcal{I}_\mathbf{x} \subseteq \mathcal{I}$, which usually satisfies $|\mathcal{I}_\mathbf{x}| \ll M$. (This setting covers many practical applications. For example, in recommendation systems an instance corresponds to an interaction between users and items, where $\mathbf{x}$ denotes the user information and $\mathcal{I}_\mathbf{x}$ denotes the items in which the user is interested.) For notational simplicity, we introduce a binary vector $\mathbf{y} \in \mathcal{Y} = \{0, 1\}^M$ as an alternative representation of $\mathcal{I}_\mathbf{x}$, where $y_j = 1$ implies $j \in \mathcal{I}_\mathbf{x}$ and vice versa. As a result, an instance can also be denoted as $(\mathbf{x}, \mathbf{y}) \in \mathcal{X} \times \mathcal{Y}$. Let $p: \mathcal{X} \times \mathcal{Y} \to \mathbb{R}_+$ be the probability density function of the data, which is unknown in practice; we slightly abuse notation by regarding an instance $(\mathbf{x}, \mathbf{y})$ as either the random variable pair w.r.t. $p(\mathbf{x}, \mathbf{y})$ or a sample of $p(\mathbf{x}, \mathbf{y})$. We also assume the training dataset $\mathcal{D}_{tr}$ and the testing dataset $\mathcal{D}_{te}$ contain i.i.d. samples of $p(\mathbf{x}, \mathbf{y})$. Since $\mathbf{y}$ is a binary vector, we use the simplified notation $\eta_j(\mathbf{x}) = p(y_j = 1|\mathbf{x})$ for any $j \in \mathcal{I}$ in the rest of this paper.

Given these notations, the extremely large-scale retrieval problem is defined as learning a model $\mathcal{M}$ such that its retrieved subset for any $\mathbf{x} \sim p(\mathbf{x})$, denoted by either $\hat{\mathcal{I}}_\mathbf{x}$ or $\hat{\mathbf{y}}$, is as close as possible to $\mathbf{y} \sim p(\mathbf{y}|\mathbf{x})$ according to some performance metrics. Since $p(\mathbf{x}, \mathbf{y})$ is unknown in practice, such a model is usually learnt as an estimator of $p(\mathbf{y}|\mathbf{x})$ on $\mathcal{D}_{tr}$ and its retrieval performance is evaluated on $\mathcal{D}_{te}$.

### 3.2. Tree Models

Suppose $\mathcal{T}$ is a $b$-ary tree with height $H$; we regard the node at the 0-th level as the root and nodes at the $H$-th level as leaves. Formally, we denote the node set at the $h$-th level as $\mathcal{N}_h$ and the node set of $\mathcal{T}$ as $\mathcal{N} = \bigcup_{h=0}^{H} \mathcal{N}_h$. For each node $n \in \mathcal{N}$, we denote its parent as $\rho(n) \in \mathcal{N}$, its children set as $\mathcal{C}(n) \subset \mathcal{N}$, the path from the root to it as $\mathrm{Path}(n)$, and the set of leaves on its subtree as $\mathcal{L}(n)$.

Tree models formulate the target set $\mathcal{I}$ as leaves of $\mathcal{T}$ through a bijective mapping $\pi: \mathcal{N}_H \to \mathcal{I}$, which implies $H = O(\log_b M)$. For any instance $(\mathbf{x}, \mathbf{y})$, each node $n \in \mathcal{N}$ is assigned a pseudo target $z_n \in \{0, 1\}$ measuring the existence of relevant targets on the subtree of $n$, i.e.,

$$z_n = \mathbb{I}\Big(\sum_{n' \in \mathcal{L}(n)} y_{\pi(n')} \geq 1\Big), \tag{1}$$

which satisfies $z_n = y_{\pi(n)}$ for $n \in \mathcal{N}_H$.
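To make Eq. (1) concrete, here is a minimal sketch that computes the pseudo targets bottom-up. It assumes a complete $b$-ary tree stored level by level in a flat array (root at index 0, children of node $n$ at indices $bn+1, \dots, bn+b$) and takes $\pi$ to be the identity map over leaves; both are illustrative simplifications, not the paper's general setup.

```python
import numpy as np

def pseudo_targets(y, b, H):
    """Compute z_n per Eq. (1) on a complete b-ary tree of height H.

    Nodes are indexed level by level: the root is 0 and the children of
    node n are b*n + 1, ..., b*n + b. Leaves (level H) map to targets left
    to right, so pi is the identity here. y is a {0,1} vector of length b**H.
    """
    num_nodes = (b ** (H + 1) - 1) // (b - 1)   # total nodes in the tree
    first_leaf = (b ** H - 1) // (b - 1)        # index of the leftmost leaf
    z = np.zeros(num_nodes, dtype=np.int64)
    z[first_leaf:] = y                          # z_n = y_{pi(n)} on leaves
    for n in range(first_leaf - 1, -1, -1):     # propagate towards the root
        children = range(b * n + 1, b * n + b + 1)
        z[n] = int(any(z[c] == 1 for c in children))  # I(subtree sum >= 1)
    return z

# Example: b = 2, H = 2, targets 2 and 3 relevant.
# pseudo_targets(np.array([0, 1, 1, 0]), b=2, H=2) -> [1, 1, 1, 0, 1, 1, 0]
```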
By doing so, tree models transform the original problem of estimating $p(y_j|\mathbf{x})$ into a series of hierarchical subproblems of estimating $p(z_n|\mathbf{x})$ for $n \in \mathrm{Path}(\pi^{-1}(j))$. They introduce the node-wise scorer $g: \mathcal{X} \times \mathcal{N} \to \mathbb{R}$ to build such a node-wise estimator for each $n \in \mathcal{N}$, which is denoted as $p_g(z_n|\mathbf{x})$ to distinguish it from the unknown distribution $p(z_n|\mathbf{x})$. In the rest of this paper, we denote a tree model as $\mathcal{M}(\mathcal{T}, g)$ to highlight its dependence on $\mathcal{T}$ and $g$.

#### 3.2.1. Training of Tree Models

The training objective of tree models can be written as $\arg\min_g \sum_{(\mathbf{x}, \mathbf{y}) \in \mathcal{D}_{tr}} L(\mathbf{y}, \mathbf{g}(\mathbf{x}))$, where

$$L(\mathbf{y}, \mathbf{g}(\mathbf{x})) = \sum_{h=1}^{H} \sum_{n \in \mathcal{S}_h(\mathbf{y})} \ell_{BCE}(z_n, g(\mathbf{x}, n)). \tag{2}$$

In Eq. (2), $\mathbf{g}(\mathbf{x})$ is a vectorized representation of $\{g(\mathbf{x}, n) : n \in \mathcal{N}\}$ (e.g., by level-order traversal), $\ell_{BCE}(z, g) = z \log(1 + \exp(-g)) + (1 - z) \log(1 + \exp(g))$ is the binary cross entropy loss, and $\mathcal{S}_h(\mathbf{y}) \subseteq \mathcal{N}_h$ is the set of subsampled nodes at the $h$-th level for an instance $(\mathbf{x}, \mathbf{y})$. Let $C = \max_h |\mathcal{S}_h(\mathbf{y})|$; the training complexity is $O(HbC)$ per instance, which is logarithmic in the target set size $M$.

As two representatives of tree models, PLTs and TDMs adopt different ways to build $p_g$ and $\mathcal{S}_h(\mathbf{y})$ (details can be found in the supplementary materials).

**PLTs:** Since $p(z_n|\mathbf{x})$ can be decomposed as $p(z_n = 1|\mathbf{x}) = \prod_{n' \in \mathrm{Path}(n)} p(z_{n'} = 1|z_{\rho(n')} = 1, \mathbf{x})$ according to Eq. (1), $p_g(z_n|\mathbf{x})$ is decomposed accordingly via $p_g(z_{n'}|z_{\rho(n')} = 1, \mathbf{x}) = 1/(1 + \exp(-(2z_{n'} - 1)g(\mathbf{x}, n')))$. As a result, only nodes with $z_{\rho(n)} = 1$ are trained, which produces $\mathcal{S}_h(\mathbf{y}) = \{n \in \mathcal{N}_h : z_{\rho(n)} = 1\}$.

**TDMs:** Unlike PLTs, $p(z_n|\mathbf{x})$ is estimated directly via $p_g(z_n|\mathbf{x}) = 1/(1 + \exp(-(2z_n - 1)g(\mathbf{x}, n)))$. Besides, the subsample set is chosen as $\mathcal{S}_h(\mathbf{y}) = \mathcal{S}_h^+(\mathbf{y}) \cup \mathcal{S}_h^-(\mathbf{y})$, where $\mathcal{S}_h^+(\mathbf{y}) = \{n \in \mathcal{N}_h : z_n = 1\}$ and $\mathcal{S}_h^-(\mathbf{y})$ contains several random samples over $\mathcal{N}_h \setminus \mathcal{S}_h^+(\mathbf{y})$. (Zhu et al. (2018) define TDM with the constraint $|\mathcal{I}_\mathbf{x}| = 1$; we extend their definition by removing this constraint and refer to this extended definition as TDM in the rest of this paper.)

#### 3.2.2. Testing of Tree Models

For any testing instance $(\mathbf{x}, \mathbf{y})$, let $\mathcal{B}_h(\mathbf{x})$ denote the node set at the $h$-th level retrieved by beam search and $k = |\mathcal{B}_h(\mathbf{x})|$ denote the beam size. The beam search process is defined as

$$\mathcal{B}_h(\mathbf{x}) \in \arg\mathrm{Top}k_{\,n \in \tilde{\mathcal{B}}_h(\mathbf{x})}\; p_g(z_n = 1|\mathbf{x}), \tag{3}$$

where $\tilde{\mathcal{B}}_h(\mathbf{x}) = \bigcup_{n' \in \mathcal{B}_{h-1}(\mathbf{x})} \mathcal{C}(n')$. By applying Eq. (3) recursively until $h = H$, beam search retrieves a set of $k$ leaf nodes, denoted by $\mathcal{B}_H(\mathbf{x})$. Let $m \leq k$ denote the number of targets to be retrieved; the retrieved target subset is

$$\hat{\mathcal{I}}_\mathbf{x} = \{\pi(n) : n \in \mathcal{B}_H^{(m)}(\mathbf{x})\}, \tag{4}$$

where $\mathcal{B}_H^{(m)}(\mathbf{x}) \in \arg\mathrm{Top}m_{\,n \in \mathcal{B}_H(\mathbf{x})}\; p_g(z_n = 1|\mathbf{x})$ denotes the subset of $\mathcal{B}_H(\mathbf{x})$ with the top-$m$ scores according to $p_g(z_n = 1|\mathbf{x})$. Since Eq. (3) traverses at most $bk$ nodes and generating $\mathcal{B}_H(\mathbf{x})$ requires computing Eq. (3) $H$ times, the testing complexity is $O(Hbk)$ per instance, which is also logarithmic in $M$.

To evaluate the retrieval performance of $\mathcal{M}(\mathcal{T}, g)$ on the testing dataset $\mathcal{D}_{te}$, Precision@$m$, Recall@$m$ and F-measure@$m$ are widely adopted. Following Zhu et al. (2018; 2019), we define them as the averages of Eq. (5), Eq. (6) and Eq. (7) over $\mathcal{D}_{te}$, respectively, where

$$\mathrm{P@}m(\mathcal{M}; \mathbf{x}, \mathbf{y}) = \frac{1}{m} \sum_{j \in \hat{\mathcal{I}}_\mathbf{x}} y_j, \tag{5}$$

$$\mathrm{R@}m(\mathcal{M}; \mathbf{x}, \mathbf{y}) = \frac{1}{|\mathcal{I}_\mathbf{x}|} \sum_{j \in \hat{\mathcal{I}}_\mathbf{x}} y_j, \tag{6}$$

$$\mathrm{F@}m(\mathcal{M}; \mathbf{x}, \mathbf{y}) = \frac{2 \cdot \mathrm{P@}m(\mathcal{M}; \mathbf{x}, \mathbf{y}) \cdot \mathrm{R@}m(\mathcal{M}; \mathbf{x}, \mathbf{y})}{\mathrm{P@}m(\mathcal{M}; \mathbf{x}, \mathbf{y}) + \mathrm{R@}m(\mathcal{M}; \mathbf{x}, \mathbf{y})}. \tag{7}$$

(Unlike the macro/micro F-measure, the average of Eq. (7) over $\mathcal{D}_{te}$ defines the instance-wise F-measure; Wu & Zhou (2017, Table 1) provide a thorough comparison.)
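The retrieval procedure of Eqs. (3)-(4) can be sketched as follows, reusing the flat-array layout from the earlier snippet; `score(x, n)` is an assumed callback standing in for $p_g(z_n = 1|\mathbf{x})$, not an interface from the paper.

```python
import heapq

def beam_search(x, score, b, H, k, m):
    """Retrieve m targets via beam search, per Eqs. (3)-(4).

    score(x, n) should return p_g(z_n = 1 | x); ties are broken arbitrarily.
    """
    beam = [0]                                        # B_0(x) = {root}
    for _ in range(H):
        # Candidate set B~_h(x): children of the current beam.
        candidates = [c for n in beam
                      for c in range(b * n + 1, b * n + b + 1)]
        # Eq. (3): keep the top-k scored candidates.
        beam = heapq.nlargest(k, candidates, key=lambda n: score(x, n))
    # Eq. (4): top-m leaves of B_H(x), mapped back to target indices.
    top_m = heapq.nlargest(m, beam, key=lambda n: score(x, n))
    first_leaf = (b ** H - 1) // (b - 1)
    return [n - first_leaf for n in top_m]
```

Each of the $H$ levels touches at most $bk$ nodes, matching the $O(Hbk)$ testing cost noted above.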
## 4. Main Contributions

Our main contributions can be divided into three parts: (1) we highlight the existence of the training-testing discrepancy on tree models and provide an intuitive explanation of its negative effects on retrieval performance; (2) we develop the concepts of Bayes optimality under beam search and calibration under beam search to formalize this intuitive explanation; (3) we propose a novel algorithm for learning tree models that are Bayes optimal under beam search.

### 4.1. Understanding the Training-Testing Discrepancy on Tree Models

According to Eq. (2), the training of $g(\mathbf{x}, n)$ depends on two factors: the subsample set $\mathcal{S}_h(\mathbf{y})$ and the pseudo target $z_n$.

*Figure 1. An overview of the training-testing discrepancy on a tree model $\mathcal{M}(\mathcal{T}, g)$. (a) The assignment of pseudo targets on existing tree models, where red nodes correspond to $z_n = 1$ defined in Eq. (1). (b) The beam search process, where targets mapped to blue nodes at the 3rd level (i.e., leaf nodes) are regarded as the retrieval results of $\mathcal{M}$. (c) The assignment of optimal pseudo targets based on the ground truth distribution $\eta_j(\mathbf{x}) = p(y_j = 1|\mathbf{x})$, where green nodes correspond to $z_n^* = 1$ defined in Eq. (13).*

*Table 1. Results for the toy experiment with $M = 1000$, $b = 2$. The reported number is $(\sum_{j \in \mathcal{I}^{(k)}} \eta_j - \sum_{j \in \hat{\mathcal{I}}} \eta_j)/k$, averaged over 100 runs with random initialization of $\mathcal{T}$ and $\eta_j$.*

|        | N = 100 | N = 1000 | N = 10000 | N = ∞ |
|--------|---------|----------|-----------|-------|
| k = 1  | 0.095   | 0.076    | 0.074     | 0.059 |
| k = 5  | 0.075   | 0.055    | 0.050     | 0.037 |
| k = 10 | 0.062   | 0.043    | 0.036     | 0.024 |
| k = 20 | 0.057   | 0.036    | 0.031     | 0.018 |
| k = 50 | 0.042   | 0.021    | 0.016     | 0.011 |

We can show that both factors relate to the training-testing discrepancy on existing tree models.

First, according to Eq. (3), the nodes at the $h$-th level on which $g(\mathbf{x}, n)$ is queried in testing are $\tilde{\mathcal{B}}_h(\mathbf{x})$, which implies a self-dependency of $g(\mathbf{x}, n)$: the nodes on which $g(\mathbf{x}, n)$ is queried at the $h$-th level depend on $g(\mathbf{x}, n)$ queried at the $(h-1)$-th level. However, $\mathcal{S}_h(\mathbf{y})$, the set of nodes at the $h$-th level on which $g(\mathbf{x}, n)$ is trained, is generated according to the ground truth targets $\mathbf{y}$ via Eq. (1). Figure 1(a) and Figure 1(b) demonstrate such a difference: nodes 7 and 8 (blue nodes) are traversed by beam search, but they are not in $\mathcal{S}_h(\mathbf{y})$ of PLTs and may not be in $\mathcal{S}_h(\mathbf{y})$ of TDMs according to $\mathcal{S}_h^+(\mathbf{y})$ (red nodes). As a result, $g(\mathbf{x}, n)$ is trained without accounting for this self-dependency when it is used for retrieving relevant targets via beam search. Because of this discrepancy, a scorer $g(\mathbf{x}, n)$ that is trained well need not perform well in testing. The sketch below makes the mismatch between the two node sets concrete.
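This is a minimal sketch contrasting $\mathcal{S}_h^+(\mathbf{y})$ (the positive nodes existing models train on) with $\tilde{\mathcal{B}}_h(\mathbf{x})$ (the nodes beam search actually queries); it reuses `pseudo_targets` and the hypothetical flat-array layout from the earlier snippets, with `score` again an assumed callback.

```python
import heapq

def positive_nodes(z, b, H):
    """S_h^+(y) = {n in N_h : z_n = 1} for h = 1..H, from Eq. (1) targets."""
    out, start = [], 1                      # level h starts after level h-1
    for h in range(1, H + 1):
        out.append([n for n in range(start, start + b ** h) if z[n] == 1])
        start += b ** h
    return out

def queried_nodes(x, score, b, H, k):
    """B~_h(x) for h = 1..H: the nodes beam search evaluates in testing."""
    beam, out = [0], []
    for _ in range(H):
        cands = [c for n in beam for c in range(b * n + 1, b * n + b + 1)]
        out.append(cands)
        beam = heapq.nlargest(k, cands, key=lambda n: score(x, n))
    return out

# The discrepancy of Figure 1 is simply that these two lists of node sets
# can differ: nodes in B~_h(x) \ S_h^+(y) are queried but never trained on.
```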
Second, $z_n$ defined in Eq. (1) does not guarantee that beam search w.r.t. $p_g(z_n = 1|\mathbf{x})$ suffers no performance deterioration, i.e., that it retrieves the most relevant targets. To see this, we design a toy example by ignoring $\mathbf{x}$ and defining the data distribution to be $p(\mathbf{y}) = \prod_{j=1}^{M} p(y_j)$, whose marginal probability $\eta_j = p(y_j = 1)$ is sampled from a uniform distribution on $[0, 1]$. We denote the training dataset as $\mathcal{D}_{tr} = \{\mathbf{y}^{(i)}\}_{i=1}^{N}$ and the pseudo target for instance $\mathbf{y}^{(i)}$ on node $n$ as $z_n^{(i)}$. For $\mathcal{M}(\mathcal{T}, g)$, we assume $\mathcal{T}$ is randomly built and estimate $p(z_n = 1)$ directly via $p_g(z_n = 1) = \sum_{i=1}^{N} z_n^{(i)}/N$ without the need to specify $g$, since there is no observation $\mathbf{x}$. (Estimating $p(z_n = 1)$ hierarchically via $p_g(z_n = 1) = \prod_{n' \in \mathrm{Path}(n)} p_g(z_{n'} = 1|z_{\rho(n')} = 1)$ with $p_g(z_n = 1|z_{\rho(n)} = 1) = \sum_{i=1}^{N} z_n^{(i)} z_{\rho(n)}^{(i)} / \sum_{i=1}^{N} z_{\rho(n)}^{(i)}$ provides similar results.) Beam search with beam size $k$ is applied on $\mathcal{M}$ to retrieve a target subset of size $m = k$, denoted by $\hat{\mathcal{I}} = \{\pi(n) : n \in \mathcal{B}_H\}$. Since $p(\mathbf{y})$ is known in this toy example, we need no testing set $\mathcal{D}_{te}$ and evaluate the retrieval performance directly via the regret $(\sum_{j \in \mathcal{I}^{(k)}} \eta_j - \sum_{j \in \hat{\mathcal{I}}} \eta_j)/k$, where $\mathcal{I}^{(k)} \in \arg\mathrm{Top}k_{\,j \in \mathcal{I}}\, \eta_j$ denotes the top-$k$ targets according to $\eta_j$. As a special case of Eq. (10), this metric quantifies the suboptimality of $\mathcal{M}$, and we will discuss it formally later.

As shown in Table 1, the regret is always non-zero with varying training data number $N$ and beam size $k$. Even in the ideal case when $N = \infty$ and thus $p_g(z_n = 1) = p(z_n = 1)$, it is still non-zero. This implies that $z_n$ defined in Eq. (1) cannot guarantee optimal retrieval performance in general. This phenomenon does not contradict the zero-regret property in Wydmuch et al. (2018, Theorem 2), since their theorem defines the regret using $\hat{\mathcal{I}} = \arg\mathrm{Top}k_{\,j \in \mathcal{I}}\, p_g(y_j = 1|\mathbf{x})$, which ignores the performance deterioration caused by beam search. A self-contained simulation of this toy experiment is sketched below.
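The toy experiment can be reproduced in a few lines by composing the earlier `pseudo_targets` and `beam_search` sketches, under the simplifying assumption $M = 2^H$ so that the complete binary tree applies (the paper uses $M = 1000$ on a random binary tree, so the numbers will differ from Table 1 in detail, but the regret stays visibly non-zero).

```python
import numpy as np

def toy_regret(H=10, k=10, N=1000, seed=0):
    """Toy experiment of Sec. 4.1: p(y) = prod_j Bernoulli(eta_j)."""
    rng = np.random.default_rng(seed)
    b, M = 2, 2 ** H
    eta = rng.uniform(size=M)                          # eta_j = p(y_j = 1)
    Y = (rng.uniform(size=(N, M)) < eta).astype(int)   # D_tr = {y^(i)}
    # p_g(z_n = 1) = (1/N) * sum_i z_n^(i); no scorer g is needed.
    Z = np.stack([pseudo_targets(y, b, H) for y in Y])
    p_hat = Z.mean(axis=0)
    retrieved = beam_search(None, lambda x, n: p_hat[n], b, H, k, k)
    top_k = np.argsort(-eta)[:k]                       # I^(k), the true top-k
    return (eta[top_k].sum() - eta[retrieved].sum()) / k

print(toy_regret())  # strictly positive in general, echoing Table 1
```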
### 4.2. Bayes Optimality and Calibration under Beam Search

In Sec. 4.1, we discussed the existence of the training-testing discrepancy on tree models and provided a toy example to explain its effect. Without loss of generality, we formalize this discussion with Precision@$m$ as the retrieval performance metric in this subsection.

The first question is what "optimal" means for tree models with respect to their retrieval performance. The answer has been partially revealed by the toy example in Sec. 4.1, and we give a formal definition as follows:

**Definition 1 (Bayes Optimality under Beam Search).** Given the beam size $k$ and the data distribution $p: \mathcal{X} \times \mathcal{Y} \to \mathbb{R}_+$, a tree model $\mathcal{M}(\mathcal{T}, g)$ is called top-$k$ Bayes optimal under beam search if

$$\{\pi(n) : n \in \mathcal{B}_H(\mathbf{x})\} \in \arg\mathrm{Top}k_{\,j \in \mathcal{I}}\; \eta_j(\mathbf{x}) \tag{8}$$

holds for any $\mathbf{x} \in \mathcal{X}$. $\mathcal{M}(\mathcal{T}, g)$ is called Bayes optimal under beam search if Eq. (8) holds for any $\mathbf{x} \in \mathcal{X}$ and any $1 \leq k \leq M$.

Given Definition 1, we can derive a sufficient condition for the existence of such an optimal tree model as follows. (A detailed proof is given in the supplementary materials. Without a formal proof, Zhu et al. (2018) propose a max-heap like formulation, which can be regarded as a special case of Proposition 1 with the $|\mathcal{I}_\mathbf{x}| = 1$ restriction.)

**Proposition 1 (Sufficient Condition for Bayes Optimality under Beam Search).** Given the beam size $k$, the data distribution $p: \mathcal{X} \times \mathcal{Y} \to \mathbb{R}_+$, the tree $\mathcal{T}$ and

$$p^*(z_n|\mathbf{x}) = \begin{cases} \max_{n' \in \mathcal{L}(n)} \eta_{\pi(n')}(\mathbf{x}), & z_n = 1, \\ 1 - \max_{n' \in \mathcal{L}(n)} \eta_{\pi(n')}(\mathbf{x}), & z_n = 0, \end{cases} \tag{9}$$

a tree model $\mathcal{M}(\mathcal{T}, g)$ is top-$m$ Bayes optimal under beam search for any $m \leq k$ if $p_g(z_n|\mathbf{x}) = p^*(z_n|\mathbf{x})$ holds for any $\mathbf{x} \in \mathcal{X}$ and $n \in \bigcup_{h=1}^{H} \tilde{\mathcal{B}}_h(\mathbf{x})$. $\mathcal{M}(\mathcal{T}, g)$ is Bayes optimal under beam search if $p_g(z_n|\mathbf{x}) = p^*(z_n|\mathbf{x})$ holds for any $\mathbf{x} \in \mathcal{X}$ and $n \in \mathcal{N}$.

Proposition 1 shows one case of what an optimal tree model should be, but it does not resolve all the problems, since both learning and evaluating a tree model require a quantitative measure of its suboptimality. Notice that Eq. (8) implies $\mathbb{E}_{p(\mathbf{x})}\big[\sum_{j \in \mathcal{I}_\mathbf{x}^{(k)}} \eta_j(\mathbf{x})\big] = \mathbb{E}_{p(\mathbf{x})}\big[\sum_{n \in \mathcal{B}_H(\mathbf{x})} \eta_{\pi(n)}(\mathbf{x})\big]$, where $\mathcal{I}_\mathbf{x}^{(k)} = \arg\mathrm{Top}k_{\,j \in \mathcal{I}}\, \eta_j(\mathbf{x})$ denotes the top-$k$ targets according to the ground truth $\eta_j(\mathbf{x})$. The deviation from this equality can be used as a suboptimality measure of $\mathcal{M}$. Formally, we define it to be the regret w.r.t. Precision@$k$ and denote it as $\mathrm{reg}_{p@k}(\mathcal{M})$. This is the special case $m = k$ of the more general definition

$$\mathrm{reg}_{p@m}(\mathcal{M}) = \frac{1}{m}\, \mathbb{E}_{p(\mathbf{x})}\Big[\sum_{j \in \mathcal{I}_\mathbf{x}^{(m)}} \eta_j(\mathbf{x}) - \sum_{n \in \mathcal{B}_H^{(m)}(\mathbf{x})} \eta_{\pi(n)}(\mathbf{x})\Big], \tag{10}$$

where $\mathcal{I}_\mathbf{x}^{(m)} = \arg\mathrm{Top}m_{\,j \in \mathcal{I}}\, \eta_j(\mathbf{x})$.

Though $\mathrm{reg}_{p@k}(\mathcal{M})$ seems an ideal suboptimality measure, finding its minimizer is hard due to a series of nested, non-differentiable $\arg\mathrm{Top}k$ operators. Therefore, finding a surrogate loss for $\mathrm{reg}_{p@k}(\mathcal{M})$ whose minimizer is still an optimal tree model becomes very important. To characterize such surrogate losses, we introduce the concept of calibration under beam search:

**Definition 2 (Calibration under Beam Search).** Given a tree model $\mathcal{M}(\mathcal{T}, g)$, a loss function $L: \{0, 1\}^M \times \mathbb{R}^{|\mathcal{N}|} \to \mathbb{R}$ is called top-$k$ calibrated under beam search if

$$\arg\min_{\mathbf{g}}\; \mathbb{E}_{p(\mathbf{x}, \mathbf{y})}[L(\mathbf{y}, \mathbf{g}(\mathbf{x}))] \subseteq \arg\min_{\mathbf{g}}\; \mathrm{reg}_{p@k}(\mathcal{M}) \tag{11}$$

holds for any distribution $p: \mathcal{X} \times \mathcal{Y} \to \mathbb{R}_+$. $L$ is called calibrated under beam search if Eq. (11) holds for any $1 \leq k \leq M$.

By Definition 2, a tree model $\mathcal{M}(\mathcal{T}, g)$ with $g$ minimizing a non-calibrated loss is not Bayes optimal under beam search in general. Recall that Proposition 1 shows that for any $p: \mathcal{X} \times \mathcal{Y} \to \mathbb{R}_+$ and any $\mathcal{T}$, the minimizer of $\mathrm{reg}_{p@k}(\mathcal{M})$ always exists, satisfies $p_g(z_n|\mathbf{x}) = p^*(z_n|\mathbf{x})$, and achieves $\mathrm{reg}_{p@k}(\mathcal{M}) = 0$. Therefore, the suboptimality of TDMs and PLTs can be proved by showing that the minimizer of their training loss does not guarantee $\mathrm{reg}_{p@k}(\mathcal{M}) = 0$ in general. A counterexample suffices, and the toy experiment in Table 1 provides one. As a result, we have:

**Proposition 2.** Eq. (2) with $z_n$ defined in Eq. (1) is not calibrated under beam search in general.

### 4.3. Learning Optimal Tree Models under Beam Search

Given the discussion in Sec. 4.2, we need a new surrogate loss function whose minimizer corresponds to a tree model that is Bayes optimal under beam search. According to Definition 1, when the retrieval performance is measured by Precision@$m$, requiring a model to be top-$m$ Bayes optimal under beam search is enough. Proposition 1 provides a natural surrogate loss to achieve this purpose with beam size $k \geq m$, i.e.,

$$g^* \in \arg\min_g\; \mathbb{E}_{p(\mathbf{x})}\Big[\sum_{h=1}^{H} \sum_{n \in \tilde{\mathcal{B}}_h(\mathbf{x})} \mathrm{KL}\big(p^*(z_n|\mathbf{x})\,\|\,p_g(z_n|\mathbf{x})\big)\Big], \tag{12}$$

where we follow the TDM style and assume $p_g(z_n|\mathbf{x}) = 1/(1 + \exp(-(2z_n - 1)g(\mathbf{x}, n)))$.

Unlike Eq. (2), Eq. (12) uses the nodes in $\tilde{\mathcal{B}}_h(\mathbf{x})$ instead of $\mathcal{S}_h(\mathbf{y})$ for training and introduces a different definition of pseudo targets compared to Eq. (1). Let $z_n^* \sim p^*(z_n|\mathbf{x})$ denote the corresponding pseudo target; we have

$$z_n^* = y_{\pi(n^*)}, \quad n^* \in \arg\max_{n' \in \mathcal{L}(n)} \eta_{\pi(n')}(\mathbf{x}). \tag{13}$$

Notice that for $n \in \mathcal{N}_H$, $z_n^* = y_{\pi(n)}$, the same as $z_n$ in Eq. (1). To distinguish $z_n^*$ from $z_n$, we call it the optimal pseudo target, since it corresponds to the optimal tree model. Given this definition, Eq. (12) can be rewritten as $\arg\min_g \mathbb{E}_{p(\mathbf{x}, \mathbf{y})}[L_p(\mathbf{y}, \mathbf{g}(\mathbf{x}))]$, where

$$L_p(\mathbf{y}, \mathbf{g}(\mathbf{x})) = \sum_{h=1}^{H} \sum_{n \in \tilde{\mathcal{B}}_h(\mathbf{x})} \ell_{BCE}(z_n^*, g(\mathbf{x}, n)). \tag{14}$$

In Eq. (14) we assign a subscript $p$ to highlight the dependence of $z_n^*$ on $\eta_j(\mathbf{x})$, which implies that Eq. (14) is calibrated under beam search in the sense that its formulation depends on $p: \mathcal{X} \times \mathcal{Y} \to \mathbb{R}_+$.

Figure 1 provides a concrete example of the difference between $z_n^*$ and $z_n$. Not all ancestor nodes of a relevant target $y_j = 1$ are regarded as relevant nodes according to $z_n^*$: nodes 1 and 6 (red nodes in Figure 1(a)) are assigned $z_n = 1$ but $z_n^* = 0$ (green nodes in Figure 1(c)). The reason is that among the targets on the subtrees rooted at these nodes, an irrelevant target has a higher $\eta_j(\mathbf{x})$ than the relevant target, i.e., $\eta_7(\mathbf{x}) = 0.5 > \eta_8(\mathbf{x}) = 0.4$ and $\eta_1(\mathbf{x}) = 0.8 > \eta_3(\mathbf{x}) = 0.7$, which makes $z_n^*$ equal to 0. The sketch below computes $z_n^*$ when $\eta_j(\mathbf{x})$ is known.
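When $\eta_j(\mathbf{x})$ is known (as in synthetic data), Eq. (13) can be evaluated exactly with one bottom-up pass, as in the sketch below; it uses the same hypothetical flat-array layout as the earlier snippets.

```python
import numpy as np

def optimal_pseudo_targets(y, eta, b, H):
    """z_n^* per Eq. (13): copy y from the leaf with the largest eta on
    the subtree of n. eta[j] = eta_j(x) for target j."""
    num_nodes = (b ** (H + 1) - 1) // (b - 1)
    first_leaf = (b ** H - 1) // (b - 1)
    best_leaf = np.arange(num_nodes)            # a leaf is its own argmax
    z_star = np.zeros(num_nodes, dtype=np.int64)
    z_star[first_leaf:] = y                     # z_n^* = y_{pi(n)} on leaves
    for n in range(first_leaf - 1, -1, -1):
        children = range(b * n + 1, b * n + b + 1)
        best = max(children, key=lambda c: eta[best_leaf[c] - first_leaf])
        best_leaf[n] = best_leaf[best]          # n^* of Eq. (13) for node n
        z_star[n] = y[best_leaf[n] - first_leaf]
    return z_star
```

Running this on the example of Figure 1 reproduces the green assignment: a node whose highest-$\eta$ descendant is irrelevant gets $z_n^* = 0$ even if its subtree contains a relevant target.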
However, it is impossible to minimize Eq. (14) directly, since $\eta_j(\mathbf{x})$ is unknown in practice. As a result, we need an approximation of $z_n^*$ that does not depend on $\eta_j(\mathbf{x})$. Suppose $g(\mathbf{x}, n)$ is parameterized by trainable parameters $\theta \in \Theta$; we write $g_\theta(\mathbf{x}, n)$, $p_{g_\theta}(z_n|\mathbf{x})$ and $\mathcal{B}_h(\mathbf{x}; \theta)$ to highlight the dependence on $\theta$. A natural choice is to replace $\eta_{\pi(n')}(\mathbf{x})$ in Eq. (13) with $p_{g_\theta}(z_{n'} = 1|\mathbf{x})$. However, this formulation is still impractical, since the computational complexity of traversing $\mathcal{L}(n)$ for each $n \in \tilde{\mathcal{B}}_h(\mathbf{x}; \theta)$ is unacceptable. Thanks to the tree structure, we can approximate $z_n^*$ with $\hat{z}_n(\mathbf{x}; \theta)$, which is constructed recursively for $n \in \mathcal{N} \setminus \mathcal{N}_H$ as

$$\hat{z}_n(\mathbf{x}; \theta) = \hat{z}_{n^*}(\mathbf{x}; \theta), \quad n^* \in \arg\max_{n' \in \mathcal{C}(n)} p_{g_\theta}(z_{n'} = 1|\mathbf{x}), \tag{15}$$

and is set directly to $\hat{z}_n(\mathbf{x}; \theta) = y_{\pi(n)}$ for $n \in \mathcal{N}_H$. By doing so, we remove the dependence on the unknown $\eta_j(\mathbf{x})$.

But minimizing Eq. (14) with $z_n^*$ replaced by $\hat{z}_n(\mathbf{x}; \theta)$ is still not easy, since the parameter $\theta$ affects $\tilde{\mathcal{B}}_h(\mathbf{x}; \theta)$, $\hat{z}_n(\mathbf{x}; \theta)$ and $g_\theta(\mathbf{x}, n)$: the gradient with respect to $\theta$ cannot be computed directly due to the non-differentiability of the $\arg\mathrm{Top}k$ operator in $\tilde{\mathcal{B}}_h(\mathbf{x}; \theta)$ and the $\arg\max$ operator in $\hat{z}_n(\mathbf{x}; \theta)$. To get a differentiable loss function, we propose to replace $L_p(\mathbf{y}, \mathbf{g}(\mathbf{x}))$ defined in Eq. (14) with

$$L_{\theta_t}(\mathbf{y}, \mathbf{g}(\mathbf{x}); \theta) = \sum_{h=1}^{H} \sum_{n \in \tilde{\mathcal{B}}_h(\mathbf{x}; \theta_t)} \ell_{BCE}(\hat{z}_n(\mathbf{x}; \theta_t), g_\theta(\mathbf{x}, n)), \tag{16}$$

where $\theta_t$ denotes a fixed parameter, e.g., the parameter from the last iteration of a gradient based algorithm. Given the discussion above, we propose Algorithm 1 for learning such a tree model.

**Algorithm 1: Learning Optimal Tree Models under Beam Search**

**Input:** Training dataset $\mathcal{D}_{tr}$, initial tree model $\mathcal{M}(\mathcal{T}, g_{\theta_0})$ with tree structure $\mathcal{T}$ and node-wise scorer $g_{\theta_0}(\mathbf{x}, n)$, beam size $k$, stepsize $\epsilon$.
**Output:** Trained tree model $\mathcal{M}(\mathcal{T}, g_{\theta_t})$.

1. Initialize $t = 0$;
2. **while** the convergence condition is not attained **do**
3. &nbsp;&nbsp;&nbsp;&nbsp;Draw a minibatch $MB$ from $\mathcal{D}_{tr}$;
4. &nbsp;&nbsp;&nbsp;&nbsp;Draw $\tilde{\mathcal{B}}_h(\mathbf{x}; \theta_t)$ according to Eq. (3);
5. &nbsp;&nbsp;&nbsp;&nbsp;Compute $\hat{z}_n(\mathbf{x}; \theta_t)$ for each $n \in \tilde{\mathcal{B}}_h(\mathbf{x}; \theta_t)$ according to Eq. (15);
6. &nbsp;&nbsp;&nbsp;&nbsp;Update $\theta_{t+1}$ using a gradient based method with stepsize $\epsilon$, e.g., Adam (Kingma & Ba, 2015), on the current $\theta_t$ and the gradient $\mathbf{g}_t = \nabla_\theta \sum_{(\mathbf{x}, \mathbf{y}) \in MB} L_{\theta_t}(\mathbf{y}, \mathbf{g}(\mathbf{x}); \theta)\big|_{\theta = \theta_t}$ according to Eq. (16);
7. &nbsp;&nbsp;&nbsp;&nbsp;$t \leftarrow t + 1$;
8. **end while**
9. Return $\mathcal{M}(\mathcal{T}, g_{\theta_t})$.

As analyzed in the supplementary materials, the training complexity of Algorithm 1 is $O(Hbk + Hb|\mathcal{I}_\mathbf{x}|)$ per instance, which is still logarithmic in $M$. Besides, for a tree model trained according to Algorithm 1, the testing complexity is $O(Hbk)$ per instance as in Sec. 3.2.2, since Algorithm 1 does not alter beam search in testing. A sketch of steps 4-6 is given below.
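This is a sketch of steps 4-6 of Algorithm 1 for a single instance, reusing the earlier flat-array helpers; `score(x, n)` stands for $p_{g_{\theta_t}}(z_n = 1|\mathbf{x})$ at the fixed parameter $\theta_t$ and is an assumed interface. Note that this naive version re-runs the greedy descent of Eq. (15) for every queried node, so it is quadratic in $H$; the paper's $O(Hbk + Hb|\mathcal{I}_\mathbf{x}|)$ bound relies on a more careful shared computation.

```python
import heapq
import math

def estimated_target(n, x, y, score, b, first_leaf):
    """z_hat_n(x; theta_t) per Eq. (15): greedy descent to a leaf, read y."""
    while n < first_leaf:                           # until n is a leaf
        n = max(range(b * n + 1, b * n + b + 1), key=lambda c: score(x, c))
    return y[n - first_leaf]

def otm_instance_loss(x, y, score, b, H, k):
    """L_{theta_t}(y, g(x); theta) of Eq. (16), evaluated at theta = theta_t.

    In step 6 of Algorithm 1, the gradient w.r.t. theta is taken with the
    beam and the targets z_hat held fixed at theta_t (e.g., via autodiff).
    """
    first_leaf = (b ** H - 1) // (b - 1)
    beam, loss = [0], 0.0
    for _ in range(H):
        cands = [c for n in beam for c in range(b * n + 1, b * n + b + 1)]
        for n in cands:                             # BCE on every queried node
            p = score(x, n)                         # p_g(z_n = 1 | x), in (0,1)
            z = estimated_target(n, x, y, score, b, first_leaf)
            loss -= z * math.log(p) + (1 - z) * math.log(1 - p)
        beam = heapq.nlargest(k, cands, key=lambda n: score(x, n))
    return loss
```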
Now, the remaining question is: since Eq. (16) introduces several approximations, does it still achieve Bayes optimality under beam search? We provide an answer as follows (the proof can be found in the supplementary materials):

**Proposition 3 (Practical Algorithm).** Suppose $\mathcal{G} = \{g_\theta : \theta \in \Theta\}$ has enough capacity and

$$\tilde{L}_{\theta_t}(\mathbf{y}, \mathbf{g}(\mathbf{x}); \theta) = \sum_{h=1}^{H} \sum_{n \in \mathcal{N}_h} w_n(\mathbf{x}, \mathbf{y}; \theta_t)\, \ell_{BCE}(\hat{z}_n(\mathbf{x}; \theta_t), g_\theta(\mathbf{x}, n)), \tag{17}$$

where $w_n(\mathbf{x}, \mathbf{y}; \theta_t) > 0$. For any probability $p: \mathcal{X} \times \mathcal{Y} \to \mathbb{R}_+$, if there exists $\theta_t \in \Theta$ such that

$$\theta_t \in \arg\min_{\theta \in \Theta}\; \mathbb{E}_{p(\mathbf{x}, \mathbf{y})}\big[\tilde{L}_{\theta_t}(\mathbf{y}, \mathbf{g}(\mathbf{x}); \theta)\big], \tag{18}$$

the corresponding tree model $\mathcal{M}(\mathcal{T}, g_{\theta_t})$ is Bayes optimal under beam search.

Proposition 3 shows that replacing $z_n^*$ with $\hat{z}_n(\mathbf{x}; \theta)$ and introducing the fixed parameter $\theta_t$ does not affect the optimality of $\mathcal{M}(\mathcal{T}, g_{\theta_t})$ under Eq. (17). However, Eq. (16) does not enjoy such a guarantee, since the summation over $\tilde{\mathcal{B}}_h(\mathbf{x}; \theta_t)$ corresponds to the summation over $\mathcal{N}_h$ with weight $w_n(\mathbf{x}, \mathbf{y}; \theta_t) = \mathbb{I}(n \in \tilde{\mathcal{B}}_h(\mathbf{x}; \theta_t))$, which violates the restriction $w_n(\mathbf{x}, \mathbf{y}; \theta_t) > 0$. This problem can be solved by introducing randomness into Eq. (16) such that each $n \in \mathcal{N}_h$ has a non-zero $w_n(\mathbf{x}, \mathbf{y}; \theta_t)$ in expectation. Examples include adding random samples of $\mathcal{N}_h$ to the summation in Eq. (16) or leveraging stochastic beam search (Kool et al., 2019) to generate $\tilde{\mathcal{B}}_h(\mathbf{x}; \theta_t)$. Nevertheless, in experiments we found that these strategies do not greatly affect the performance, and thus we still use Eq. (16).

## 5. Experiments

In this section, we experimentally verify our analysis and evaluate the performance of different tree models on both synthetic and real data. Throughout the experiments, we use OTM to denote the tree model trained according to Algorithm 1, since its goal is to learn optimal tree models under beam search. To perform an ablation study, we consider two variants of OTM: OTM (-BS) differs from OTM by replacing $\tilde{\mathcal{B}}_h(\mathbf{x}; \theta_t)$ with $\mathcal{S}_h(\mathbf{y}) = \mathcal{S}_h^+(\mathbf{y}) \cup \mathcal{S}_h^-(\mathbf{y})$, and OTM (-OptEst) differs from OTM by replacing $\hat{z}_n(\mathbf{x}; \theta_t)$ in Eq. (13) with $z_n$ in Eq. (1). More details of the experiments can be found in the supplementary materials.

### 5.1. Synthetic Data

**Datasets:** For each instance $(\mathbf{x}, \mathbf{y})$, $\mathbf{x} \in \mathbb{R}^d$ is sampled from a $d$-dimensional isotropic Gaussian distribution $N(\mathbf{0}_d, \mathbf{I}_d)$ with zero mean and identity covariance matrix, and $\mathbf{y} \in \{0, 1\}^M$ is sampled from $p(\mathbf{y}|\mathbf{x}) = \prod_{j=1}^{M} p(y_j|\mathbf{x}) = \prod_{j=1}^{M} 1/(1 + \exp(-(2y_j - 1)(\mathbf{w}_j^\top \mathbf{x} + b)))$, where the weight vector $\mathbf{w}_j \in \mathbb{R}^d$ is also sampled from $N(\mathbf{0}_d, \mathbf{I}_d)$. The bias $b$ is a predefined constant controlling the number of non-zero entries in $\mathbf{y}$; we set $b$ to a negative value such that the number of non-zero entries is less than $0.1M$, simulating the practical case where the number of relevant targets is much smaller than the target set size. The corresponding training and testing datasets are denoted $\mathcal{D}_{tr}$ and $\mathcal{D}_{te}$, respectively. A sketch of this generation process is given at the end of this subsection.

**Compared Models and Metric:** We compare OTM with PLT and TDM. All tree models $\mathcal{M}(\mathcal{T}, g)$ share the same tree structure $\mathcal{T}$ and the same parameterization of the node-wise scorer $g$. More specifically, $\mathcal{T}$ is a random binary tree over $\mathcal{I}$ and $g(\mathbf{x}, n) = \theta_n^\top \mathbf{x} + b_n$ is a linear scorer, where $\theta_n \in \mathbb{R}^d$ and $b_n \in \mathbb{R}$ are trainable parameters. All models are trained on $\mathcal{D}_{tr}$ and their performance is measured by $\widehat{\mathrm{reg}}_{p@m}$, an estimate of $\mathrm{reg}_{p@m}(\mathcal{M})$ defined in Eq. (10) obtained by replacing the expectation over $p(\mathbf{x})$ with the average over $(\mathbf{x}, \mathbf{y}) \in \mathcal{D}_{te}$.

*Table 2. A comparison of $\widehat{\mathrm{reg}}_{p@m}(\mathcal{M})$ averaged over 5 runs with random initialization, with hyperparameter settings $M = 1000$, $d = 10$, $b = -5$, $|\mathcal{D}_{tr}| = 10000$, $|\mathcal{D}_{te}| = 1000$ and $k = 50$.*

| m             | 1      | 10     | 20     | 50     |
|---------------|--------|--------|--------|--------|
| PLT           | 0.0444 | 0.0778 | 0.0955 | 0.1492 |
| TDM           | 0.0033 | 0.0205 | 0.0453 | 0.1363 |
| OTM           | 0.0024 | 0.0163 | 0.0349 | 0.1083 |
| OTM (-BS)     | 0.0048 | 0.0201 | 0.0421 | 0.1313 |
| OTM (-OptEst) | 0.0033 | 0.0198 | 0.0418 | 0.1218 |

**Results:** Table 2 shows that OTM performs best among the compared models, which indicates that eliminating the training-testing discrepancy improves the retrieval performance of tree models. Both OTM (-BS) and OTM (-OptEst) have smaller regret than PLT and TDM, which means that using beam search aware subsampling (i.e., $\tilde{\mathcal{B}}_h(\mathbf{x}; \theta_t)$) or estimated optimal pseudo targets (i.e., $\hat{z}_n(\mathbf{x}; \theta_t)$) alone already contributes to better performance. Besides, OTM (-OptEst) has smaller regret than OTM (-BS), which reveals that beam search aware subsampling contributes more than estimated optimal pseudo targets to the performance of OTM.
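A minimal sketch of the synthetic data generation described above; `b_bias` plays the role of the negative bias $b$ (renamed here to avoid clashing with the tree arity), and the default values mirror the Table 2 settings.

```python
import numpy as np

def synthetic_dataset(M=1000, d=10, b_bias=-5.0, n=10000, seed=0):
    """Sec. 5.1 data: x ~ N(0, I_d), w_j ~ N(0, I_d),
    y_j ~ Bernoulli(sigmoid(w_j^T x + b))."""
    rng = np.random.default_rng(seed)
    W = rng.standard_normal((M, d))                   # one w_j per target
    X = rng.standard_normal((n, d))
    eta = 1.0 / (1.0 + np.exp(-(X @ W.T + b_bias)))   # eta_j(x) = p(y_j=1|x)
    Y = (rng.uniform(size=eta.shape) < eta).astype(int)
    return X, Y, eta  # eta kept so reg-hat_{p@m} of Eq. (10) can be computed
```

Because $\eta_j(\mathbf{x})$ is returned explicitly, $\widehat{\mathrm{reg}}_{p@m}$ can be computed on $\mathcal{D}_{te}$ exactly as in the toy-regret sketch, with the learned node-wise scorer in place of `p_hat`.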
### 5.2. Real Data

**Datasets:** Our experiments are conducted on two large-scale real datasets for recommendation tasks: Amazon Books (McAuley et al., 2015; He & McAuley, 2016) and UserBehavior (Zhu et al., 2018). Each record of both datasets is a user-item interaction containing a user ID, an item ID and a timestamp. The original interaction records are formulated as a set of user-based data, where each user-based datum is the list of a user's items sorted by the timestamp at which the user-item interaction occurred. We discard user-based data with fewer than 10 items and split the rest into a training set $\mathcal{D}_{tr}$, a validation set $\mathcal{D}_{val}$ and a testing set $\mathcal{D}_{te}$ in the same way as Zhu et al. (2018; 2019). For the validation and testing sets, we take the first half of each user-based datum (in ascending order of timestamp) as the feature $\mathbf{x}$ and the latter half as the relevant targets $\mathbf{y}$. Training instances are generated from the raw user-based data according to the characteristics of the different approaches: if an approach restricts $|\mathcal{I}_\mathbf{x}| = 1$, we use a sliding window to produce several instances per user-based datum, while a single instance is produced for methods without a restriction on $|\mathcal{I}_\mathbf{x}|$.

**Compared Models and Metric:** We compare OTM with two families of methods: (1) methods widely used in recommendation tasks, such as Item-CF (Sarwar et al., 2001), a basic collaborative filtering method, and YouTube product-DNN (Covington et al., 2016), a representative vector kNN based method; (2) tree models like HSM (Morin & Bengio, 2005), PLT and JTM (Zhu et al., 2019). HSM is a hierarchical softmax model, which can be regarded as PLT with the $|\mathcal{I}_\mathbf{x}| = 1$ restriction.
JTM is a variant of TDM which trains the tree structure and node-wise scorers jointly and achieves state-of-the-art performance on these two datasets. All tree models share the same binary tree structure and adopt the same neural network for node-wise scorers: three fully connected layers with hidden sizes 128, 64 and 24 and parametric ReLU activations. The performance of different models is measured by Precision@$m$ (Eq. (5)), Recall@$m$ (Eq. (6)) and F-Measure@$m$ (Eq. (7)) averaged over the testing set $\mathcal{D}_{te}$.

**Results:** Table 3 and Table 4 show the results on Amazon Books and UserBehavior, respectively. (Since $k = 400$, OTM is trained on $|\tilde{\mathcal{B}}_h(\mathbf{x}; \theta_t)| = 800$ nodes per level; for fairness, JTM also subsamples $|\mathcal{S}_h(\mathbf{y})| = 800$ nodes per level for training $g(\mathbf{x}, n)$.) Our model performs best among all methods: compared to the previous state-of-the-art JTM, OTM achieves 29.8% and 6.3% relative recall lift ($m = 200$) on Amazon Books and UserBehavior, respectively. The results of OTM and its two variants are consistent with those on synthetic data: both beam search aware subsampling and estimated optimal pseudo targets contribute to better performance, with the former contributing more; the performance of OTM mainly depends on the former. Besides, the comparison between HSM and PLT demonstrates that removing the $|\mathcal{I}_\mathbf{x}| = 1$ restriction in tree models also contributes to performance improvement.

*Table 3. Precision@m, Recall@m and F-Measure@m (%) on Amazon Books with beam size k = 400 and various m.*

| Method | P@10 | P@50 | P@100 | P@200 | R@10 | R@50 | R@100 | R@200 | F@10 | F@50 | F@100 | F@200 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Item-CF | 2.02 | 1.04 | 0.74 | 0.52 | 2.14 | 4.71 | 6.29 | 8.18 | 1.92 | 1.55 | 1.23 | 0.92 |
| YouTube product-DNN | 1.26 | 0.84 | 0.67 | 0.53 | 1.12 | 3.52 | 5.41 | 8.26 | 1.05 | 1.21 | 1.09 | 0.93 |
| HSM | 1.50 | 0.93 | 0.73 | 0.54 | 1.25 | 3.59 | 5.59 | 8.04 | 1.21 | 1.30 | 1.18 | 0.95 |
| PLT | 1.85 | 1.26 | 0.99 | 0.75 | 1.57 | 4.87 | 7.35 | 10.59 | 1.48 | 1.74 | 1.57 | 1.29 |
| JTM | 1.84 | 1.34 | 1.07 | 0.80 | 1.75 | 5.79 | 8.70 | 12.60 | 1.60 | 1.94 | 1.73 | 1.40 |
| OTM | 3.12 | 1.97 | 1.49 | 1.06 | 2.76 | 8.16 | 11.86 | 16.36 | 2.58 | 2.80 | 2.39 | 1.86 |
| OTM (-BS) | 2.18 | 1.45 | 1.15 | 0.86 | 1.91 | 6.01 | 9.40 | 13.68 | 1.81 | 2.08 | 1.88 | 1.52 |
| OTM (-OptEst) | 3.07 | 1.92 | 1.45 | 1.05 | 2.70 | 8.00 | 11.63 | 16.17 | 2.54 | 2.74 | 2.33 | 1.83 |

*Table 4. Precision@m, Recall@m and F-Measure@m (%) on UserBehavior with beam size k = 400 and various m.*

| Method | P@10 | P@50 | P@100 | P@200 | R@10 | R@50 | R@100 | R@200 | F@10 | F@50 | F@100 | F@200 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Item-CF | 5.45 | 3.07 | 2.20 | 1.56 | 1.25 | 3.31 | 4.74 | 6.75 | 1.84 | 2.76 | 2.64 | 2.30 |
| YouTube product-DNN | 9.04 | 4.52 | 3.22 | 2.25 | 2.29 | 5.36 | 7.49 | 10.15 | 3.29 | 4.23 | 3.97 | 3.36 |
| HSM | 9.79 | 4.49 | 3.04 | 2.01 | 2.58 | 5.60 | 7.38 | 9.52 | 3.68 | 4.30 | 3.80 | 3.03 |
| PLT | 11.47 | 5.07 | 3.47 | 2.35 | 2.85 | 5.84 | 7.75 | 10.22 | 4.13 | 4.72 | 4.23 | 3.48 |
| JTM | 20.05 | 7.45 | 4.85 | 3.12 | 5.42 | 9.39 | 11.84 | 14.75 | 7.62 | 7.15 | 6.06 | 4.70 |
| OTM | 22.47 | 8.21 | 5.33 | 3.42 | 5.95 | 10.07 | 12.62 | 15.68 | 8.40 | 7.78 | 6.59 | 5.12 |
| OTM (-BS) | 19.81 | 7.74 | 5.08 | 3.31 | 5.36 | 9.57 | 12.14 | 15.29 | 7.54 | 7.36 | 6.30 | 4.95 |
| OTM (-OptEst) | 22.38 | 8.20 | 5.33 | 3.40 | 5.92 | 10.06 | 12.61 | 15.61 | 8.36 | 7.78 | 6.59 | 5.08 |

To understand why OTM achieves a more significant improvement on Amazon Books (29.8%) than on UserBehavior (6.3%), we analyze the statistics of these datasets and their corresponding tree structures. For each $n \in \mathcal{N}$, we define $S_n = \sum_{(\mathbf{x}, \mathbf{y}) \in \mathcal{D}_{tr}} z_n$ to count the number of training instances relevant to $n$ (i.e., with $z_n = 1$). For each level $1 \leq h \leq H$, we sort $\{S_n : n \in \mathcal{N}_h\}$ in descending order and normalize to $S_n / \sum_{n' \in \mathcal{N}_h} S_{n'}$ (see the sketch at the end of this subsection). This produces a level-wise distribution reflecting the data imbalance over relevant nodes, which results from intrinsic properties of both the datasets and the tree structure.

*Figure 2. $S_n / \sum_{n' \in \mathcal{N}_h} S_{n'}$ versus sorted node index on Amazon Books and UserBehavior at level $h = 8$ ($|\mathcal{N}_h| = 256$).*

As shown in Figure 2, the level-wise distribution of UserBehavior has a heavier tail than that of Amazon Books at the same level. This implies that the latter has a higher proportion of instances concentrated on only part of the nodes, which makes it easier for beam search to retrieve relevant nodes for training and thus leads to a more significant improvement.

To verify our analysis of the time complexity of tree models, we compare their empirical training time, since they share the same beam search process in testing. More specifically, we compute the wall-clock time per batch for training PLT, JTM and OTM with batch size 100 on the UserBehavior dataset, averaged over 5000 training iterations on a single Tesla P100-PCIE-16GB GPU. The results are 0.184s for PLT, 0.332s for JTM and 0.671s for OTM. Though OTM takes longer than PLT and JTM, they have the same order of magnitude. This is expected, since steps 4 and 5 in Algorithm 1 only increase the constant factor of the complexity. Besides, this is a reasonable trade-off for better performance, and distributed training can alleviate it in practical applications.
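The level-wise distribution plotted in Figure 2 is straightforward to compute once the Eq. (1) pseudo targets of all training instances are stacked; a sketch under the same hypothetical flat-array layout as before:

```python
import numpy as np

def level_distribution(Z, b, h):
    """Sorted, normalized S_n at level h, as plotted in Figure 2.

    Z stacks the Eq. (1) pseudo targets z^(i) of all training instances
    (one row per instance, one column per node).
    """
    start = (b ** h - 1) // (b - 1)               # first node index at level h
    S = Z[:, start:start + b ** h].sum(axis=0)    # S_n = sum over D_tr of z_n
    S = np.sort(S)[::-1]                          # descending order
    return S / S.sum()
```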
## 6. Conclusions and Future Work

Tree models have been widely adopted in large-scale information retrieval and recommendation tasks due to their logarithmic computational complexity. However, little attention has been paid to the training-testing discrepancy, where the retrieval performance deterioration caused by beam search in testing is ignored in training. To the best of our knowledge, we are the first to study this problem on tree models theoretically. We also propose a novel training algorithm for learning optimal tree models under beam search, which achieves improved experimental results compared to the state-of-the-art on both synthetic and real data.

For future work, we would like to explore other techniques for training $g(\mathbf{x}, n)$ according to Eq. (14), e.g., the REINFORCE algorithm (Williams, 1992; Ranzato et al., 2016) and actor-critic algorithms (Sutton et al., 2000; Bahdanau et al., 2017). We also want to extend our algorithm to learning the tree structure and node-wise scorers jointly. Besides, applying our algorithm to applications like extreme multi-label text classification is also an interesting direction.

## Acknowledgements

We deeply appreciate Xiang Li, Rihan Chen, Daqing Chang, Pengye Zhang, Jie He and Xiaoqiang Zhu for their insightful suggestions and discussions. We thank Huimin Yi, Yang Zheng, Siran Yang, Guowang Zhang, Shuai Li, Yue Song and Di Zhang for implementing the key components of the training platform. We thank Linhao Wang, Yin Yang, Liming Duan and Guan Wang for necessary support of online serving. We thank the anonymous reviewers for their constructive feedback and helpful comments.

## References

Bahdanau, D., Brakel, P., Xu, K., Goyal, A., Lowe, R., Pineau, J., Courville, A., and Bengio, Y. An actor-critic algorithm for sequence prediction. In International Conference on Learning Representations, 2017.

Cohen, E. and Beck, C. Empirical analysis of beam search performance degradation in neural sequence models. In International Conference on Machine Learning, pp. 1290–1299, 2019.

Covington, P., Adams, J., and Sargin, E. Deep neural networks for YouTube recommendations. In Proceedings of the 10th ACM Conference on Recommender Systems, pp. 191–198, 2016.

Daumé III, H. and Marcu, D. Learning as search optimization: Approximate large margin methods for structured prediction. In International Conference on Machine Learning, pp. 169–176. ACM, 2005.

Goyal, K., Neubig, G., Dyer, C., and Berg-Kirkpatrick, T. A continuous relaxation of beam search for end-to-end training of neural sequence models. In Thirty-Second AAAI Conference on Artificial Intelligence, 2018.

He, R. and McAuley, J. Ups and downs: Modeling the visual evolution of fashion trends with one-class collaborative filtering. In Proceedings of the 25th International Conference on World Wide Web, pp. 507–517, 2016.

Jain, H., Prabhu, Y., and Varma, M. Extreme multi-label loss functions for recommendation, tagging, ranking & other missing label applications. In Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pp. 935–944, 2016.

Jasinska, K., Dembczynski, K., Busa-Fekete, R., Pfannschmidt, K., Klerx, T., and Hüllermeier, E. Extreme F-measure maximization using sparse probability estimates. In International Conference on Machine Learning, pp. 1435–1444, 2016.

Khandagale, S., Xiao, H., and Babbar, R. Bonsai: Diverse and shallow trees for extreme multi-label classification. arXiv preprint arXiv:1904.08249, 2019.

Kingma, D. P. and Ba, J. Adam: A method for stochastic optimization. In International Conference on Learning Representations, 2015.
Kool, W., Van Hoof, H., and Welling, M. Stochastic beams and where to find them: The Gumbel-top-k trick for sampling sequences without replacement. In International Conference on Machine Learning, pp. 3499–3508, 2019.

Lapin, M., Hein, M., and Schiele, B. Analysis and optimization of loss functions for multiclass, top-k, and multilabel classification. IEEE Transactions on Pattern Analysis and Machine Intelligence, 40(7):1533–1554, 2017.

McAuley, J., Targett, C., Shi, Q., and Van Den Hengel, A. Image-based recommendations on styles and substitutes. In Proceedings of the 38th International ACM SIGIR Conference on Research and Development in Information Retrieval, pp. 43–52, 2015.

Menon, A. K., Rawat, A. S., Reddi, S., and Kumar, S. Multilabel reductions: what is my loss optimising? In Advances in Neural Information Processing Systems, pp. 10599–10610, 2019.

Morin, F. and Bengio, Y. Hierarchical probabilistic neural network language model. In Proceedings of the Eighth International Conference on Artificial Intelligence and Statistics, volume 5, pp. 246–252. Citeseer, 2005.

Negrinho, R., Gormley, M., and Gordon, G. J. Learning beam search policies via imitation learning. In Advances in Neural Information Processing Systems, pp. 10652–10661, 2018.

Prabhu, Y. and Varma, M. FastXML: A fast, accurate and stable tree-classifier for extreme multi-label learning. In Proceedings of the 20th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pp. 263–272, 2014.

Prabhu, Y., Kag, A., Harsola, S., Agrawal, R., and Varma, M. Parabel: Partitioned label trees for extreme classification with application to dynamic search advertising. In Proceedings of the 2018 World Wide Web Conference, pp. 993–1002. International World Wide Web Conferences Steering Committee, 2018.

Ranzato, M., Chopra, S., Auli, M., and Zaremba, W. Sequence level training with recurrent neural networks. In International Conference on Learning Representations, 2016.

Ross, S., Gordon, G., and Bagnell, D. A reduction of imitation learning and structured prediction to no-regret online learning. In Proceedings of the Fourteenth International Conference on Artificial Intelligence and Statistics, pp. 627–635, 2011.

Sarwar, B., Karypis, G., Konstan, J., and Riedl, J. Item-based collaborative filtering recommendation algorithms. In Proceedings of the 10th International Conference on World Wide Web, pp. 285–295, 2001.

Sutton, R. S., McAllester, D. A., Singh, S. P., and Mansour, Y. Policy gradient methods for reinforcement learning with function approximation. In Advances in Neural Information Processing Systems, pp. 1057–1063, 2000.

Williams, R. J. Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine Learning, 8(3-4):229–256, 1992.

Wiseman, S. and Rush, A. M. Sequence-to-sequence learning as beam-search optimization. In Proceedings of the 2016 Conference on Empirical Methods in Natural Language Processing, pp. 1296–1306, 2016.

Wu, X.-Z. and Zhou, Z.-H. A unified view of multi-label performance measures. In International Conference on Machine Learning, pp. 3780–3788. JMLR.org, 2017.

Wydmuch, M., Jasinska, K., Kuznetsov, M., Busa-Fekete, R., and Dembczynski, K. A no-regret generalization of hierarchical softmax to extreme multi-label classification. In Advances in Neural Information Processing Systems, pp. 6355–6366, 2018.

Xu, Y. and Fern, A. On learning linear ranking functions for beam search. In International Conference on Machine Learning, pp. 1047–1054, 2007.
Yang, F. and Koyejo, S. On the consistency of top-k surrogate losses. arXiv preprint arXiv:1901.11141, 2019.

You, R., Zhang, Z., Wang, Z., Dai, S., Mamitsuka, H., and Zhu, S. AttentionXML: Label tree-based attention-aware deep model for high-performance extreme multi-label text classification. In Advances in Neural Information Processing Systems, pp. 5812–5822, 2019.

Zhu, H., Li, X., Zhang, P., Li, G., He, J., Li, H., and Gai, K. Learning tree-based deep model for recommender systems. In Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining, pp. 1079–1088. ACM, 2018.

Zhu, H., Chang, D., Xu, Z., Zhang, P., Li, X., He, J., Li, H., Xu, J., and Gai, K. Joint optimization of tree-based index and deep model for recommender systems. In Advances in Neural Information Processing Systems, pp. 3973–3982, 2019.