# Learning to Approximate a Bregman Divergence

Ali Siahkamari¹, Xide Xia², Venkatesh Saligrama¹, David Castañón¹, Brian Kulis¹,²
¹ Department of Electrical and Computer Engineering, ² Department of Computer Science
Boston University, Boston, MA, 02215
{siaa, xidexia, srv, dac, bkulis}@bu.edu

34th Conference on Neural Information Processing Systems (NeurIPS 2020), Vancouver, Canada.

**Abstract.** Bregman divergences generalize measures such as the squared Euclidean distance and the KL divergence, and arise throughout many areas of machine learning. In this paper, we focus on the problem of approximating an arbitrary Bregman divergence from supervision, and we provide a well-principled approach to analyzing such approximations. We develop a formulation and algorithm for learning arbitrary Bregman divergences based on approximating their underlying convex generating function via a piecewise linear function. We provide theoretical approximation bounds using our parameterization and show that the generalization error $O_p(m^{-1/2})$ for metric learning using our framework matches the known generalization error in the strictly less general Mahalanobis metric learning setting. We further demonstrate empirically that our method performs well in comparison to existing metric learning methods, particularly for clustering and ranking problems.

## 1 Introduction

Bregman divergences arise frequently in machine learning. They play an important role in clustering [3] and optimization [7], and specific Bregman divergences such as the KL divergence and squared Euclidean distance are fundamental in many areas. Many learning problems require divergences other than Euclidean distances, for instance when a divergence between two distributions is needed, and Bregman divergences are natural in such settings. The goal of this paper is to provide a well-principled framework for learning an arbitrary Bregman divergence from supervision. Such Bregman divergences can then be utilized in downstream tasks such as clustering, similarity search, and ranking.

A Bregman divergence [7] $D_\phi : \mathcal{X} \times \mathcal{X} \to \mathbb{R}^+$ is parameterized by a strictly convex function $\phi : \mathcal{X} \to \mathbb{R}$ such that the divergence of $x_1$ from $x_2$ is defined as the error of the linear approximation of $\phi(x_1)$ taken at $x_2$, i.e., $D_\phi(x_1, x_2) = \phi(x_1) - \phi(x_2) - \nabla\phi(x_2)^T (x_1 - x_2)$. A significant challenge when attempting to learn an arbitrary Bregman divergence is how to appropriately parameterize the class of convex functions; in our work, we choose to parameterize $\phi$ via piecewise linear functions of the form $h(x) = \max_{k \in [K]} a_k^T x + b_k$, where $[K]$ denotes the set $\{1, \dots, K\}$ (see the left plot of Figure 1 for an example). As we discuss later, such max-affine functions can be shown to approximate arbitrary convex functions via precise bounds. Furthermore, we prove that the gradients of these functions can approximate the gradient of the convex function being approximated, making them a suitable choice for approximating arbitrary Bregman divergences.

Figure 1: (Left) Approximating a quadratic function via a max-affine function. (Middle-left) Bregman divergence approximation from every 2-d sample point $x$ to the specific point A in the data, as $x$ varies around the circle. (Middle-right) The same with the roles of $x$ and A switched (recall that a Bregman divergence is asymmetric). (Right) Distances from points $x$ to A using a Mahalanobis distance learned via linear metric learning (ITML). When the learned convex function is used to define a Bregman divergence, points within a given class have a small learned divergence, leading to clustering, k-nn, and ranking performance of 98%+ (see the experimental results for details).
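To make the definitions above concrete, the following minimal NumPy sketch (our own illustration, not the paper's code) builds a max-affine function from tangent planes of $\phi(x) = \|x\|_2^2$ at a few anchor points and evaluates the Bregman divergence it induces; all names (`anchors`, `h`, `grad_h`, `bregman_h`) are hypothetical.

```python
import numpy as np

# Illustration only: a max-affine function h(x) = max_k a_k^T x + b_k whose
# hyperplanes are tangents to phi(x) = ||x||_2^2 at a few anchor points,
# and the Bregman divergence D_h that it induces.
rng = np.random.default_rng(0)
anchors = rng.standard_normal((8, 2))   # K = 8 hypothetical anchor points in R^2
A = 2 * anchors                          # gradient of ||x||^2 at each anchor
b = -np.sum(anchors**2, axis=1)          # offsets making each plane tangent at its anchor

def h(x):
    """Max-affine value h(x) = max_k a_k^T x + b_k."""
    return np.max(A @ x + b)

def grad_h(x):
    """A sub-gradient of h at x: the slope of an active hyperplane."""
    return A[np.argmax(A @ x + b)]

def bregman_h(x, y):
    """D_h(x, y) = h(x) - h(y) - grad_h(y)^T (x - y)."""
    return h(x) - h(y) - grad_h(y) @ (x - y)

x, y = np.array([0.3, -0.1]), np.array([-0.2, 0.4])
print(bregman_h(x, y), np.sum((x - y) ** 2))  # D_h vs. the true squared distance
```

With enough well-placed hyperplanes, $D_h$ tracks the squared Euclidean distance closely; the bounds in Section 3.4 quantify this approximation for general smooth convex $\phi$.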
The key application of our results is a generalization of the Mahalanobis metric learning problem to non-linear metrics. Metric learning is the task of learning a distance metric from supervised data such that the learned metric is tailored to a given task. The training data for a metric learning algorithm is typically either relative comparisons (A is more similar to B than to C) [19, 24, 26] or similar/dissimilar pairs (B and A are similar, B and C are dissimilar) [10]. This supervision may be available when underlying training labels are not directly available, such as from ranking data [20], but can also be obtained directly from class labels in a classification task. In each of these settings, the learned similarity measure can be used downstream as the distance measure in a nearest neighbor algorithm, for similarity-based clustering [3, 19], to perform ranking [23], or for other tasks.

Existing metric learning approaches are often divided into two classes, namely linear and non-linear methods. Linear methods learn linear mappings and compute distances (usually Euclidean) in the mapped space [10, 26, 11]; this approach is typically referred to as Mahalanobis metric learning. These methods generally yield simple convex optimization problems, can be analyzed theoretically [4, 8], and are applicable in many general scenarios. Non-linear methods, most notably deep metric learning algorithms, can yield superior performance but require a significant amount of data to train and have little to no associated theoretical properties [28, 16]. As Mahalanobis distances themselves are within the class of Bregman divergences, this paper shows how one can generalize the class of linear methods to encompass a richer class of possible learned divergences, including non-linear divergences, while retaining the strong theoretical guarantees of the linear case. To highlight our main contributions, we:

- Provide an explicit approximation error bound showing that piecewise linear functions can be used to approximate an underlying Bregman divergence with error $O(K^{-1/d})$;
- Discuss a generalization error bound for metric learning in the Bregman setting of $O_p(m^{-1/2})$, where $m$ is the number of training points; this matches the bound known for the strictly less general Mahalanobis setting [4];
- Empirically validate our approach on problems of ranking and clustering, showing that our method tends to outperform a wide range of linear and non-linear metric learning baselines.

Due to space constraints, many additional details and results have been put into the supplementary material; these include proofs of all bounds, discussion of the regression setting, more details on algorithms, and additional experimental results.

## 2 Related work

To our knowledge, the only existing work on approximating a Bregman divergence is [27], but this work does not provide any statistical guarantees. They assume that the underlying convex function is of the form $\phi(x) = \sum_{i=1}^N \alpha_i h(x^T x_i)$, $\alpha_i \geq 0$, where $h(\cdot)$ is a pre-specified convex function such as $|z|^d$. Namely, it is a linear superposition of known convex functions $h(\cdot)$ evaluated on all of the training data. In our preliminary experiments, we have found this assumption to be quite restrictive, and it falls well short of state-of-the-art accuracy on benchmark datasets.
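For concreteness, here is a sketch of that generating-function family under our reading of the form above; the names (`base_h`, `phi_superposition`, `train_X`, `alpha`) are illustrative only and do not come from [27].

```python
import numpy as np

# A sketch of the family assumed in [27]: phi(x) = sum_i alpha_i * h(x^T x_i)
# with alpha_i >= 0 and a fixed convex h, e.g. h(z) = |z|^d.

def base_h(z, d=3):
    """A pre-specified convex base function, here |z|^d."""
    return np.abs(z) ** d

def phi_superposition(x, train_X, alpha, d=3):
    """phi(x) = sum_i alpha_i h(x^T x_i); convex in x, since each term is a
    convex function of the linear map x -> x^T x_i and alpha_i >= 0."""
    return float(np.sum(alpha * base_h(train_X @ x, d)))
```

Every learnable degree of freedom here is a single nonnegative weight per training point with $h$ fixed in advance, which is what makes the family restrictive compared to the piecewise linear class we consider next.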
In contrast to their work, we consider a piecewise linear family of convex functions capable of approximating any convex function. Other relevant non-linear methods include the kernelization of linear methods, as discussed in [19] and [10]; these methods require a particular kernel function and typically do not scale well to large data.

Linear metric learning methods find a linear mapping $G$ of the input data and compute the (squared) Euclidean distance in the mapped space. This is equivalent to learning a positive semi-definite matrix $M = G^T G$, where $d_M(x_1, x_2) = (x_1 - x_2)^T M (x_1 - x_2) = \|Gx_1 - Gx_2\|_2^2$. The literature on linear metric learning is quite large and cannot be fully summarized here; see the surveys [19, 5] for an overview of several approaches. One of the prominent approaches in this class is information-theoretic metric learning (ITML) [10], which places a LogDet regularizer on $M$ while enforcing similarity/dissimilarity supervision as hard constraints in the optimization problem. Large-margin nearest neighbor (LMNN) metric learning [26] is another popular Mahalanobis metric learning algorithm, tailored for k-nn by using a local neighborhood loss function which encourages similarly labeled data points to be close in each neighborhood while keeping dissimilarly labeled data points away from the local neighborhood. In Schultz and Joachims [24], the authors use pairwise similarity comparisons (B is more similar to A than to C) by minimizing a margin loss.

## 3 Problem Formulation and Approach

We now turn to the general problem formulation considered in this paper. Suppose we observe data points $X = [x_1, \dots, x_n]$, where each $x_i \in \mathbb{R}^d$. The goal is to learn an appropriate divergence measure for pairs of data points $x_i$ and $x_j$, given appropriate supervision. The class of divergences considered here is Bregman divergences; recall that Bregman divergences are parameterized by a continuously differentiable, strictly convex function $\phi : \Omega \to \mathbb{R}$, where $\Omega$ is a closed convex set. The Bregman divergence associated with $\phi$ is defined as $D_\phi(x_i, x_j) = \phi(x_i) - \phi(x_j) - \nabla\phi(x_j)^T (x_i - x_j)$. Examples include the squared Euclidean distance (when $\phi(x) = \|x\|_2^2$), the Mahalanobis distance, and the KL divergence. Learning a Bregman divergence can be equivalently described as learning the underlying convex function for the divergence. In order to fully specify the learning problem, we must determine both a supervised loss function as well as a method for appropriately parameterizing the convex function to be learned. Below, we describe both of these components.

### 3.1 Loss Functions

We can easily generalize the standard empirical risk minimization framework for metric learning, as discussed in [19], to our more general setting. In particular, suppose we have supervision in the form of $m$ loss functions $\ell_t$; these $\ell_t$ depend on the learned Bregman divergence parameterized by $\phi$ as well as the data points $X$ and some corresponding supervision $y$. We can express a general loss function as
$$L(\phi) = \sum_{t=1}^m \ell_t(D_\phi, X, y) + \lambda r(\phi),$$
where $r$ is a regularizer over the convex function $\phi$, $\lambda$ is a hyperparameter that controls the tradeoff between the loss and the regularizer, and the supervised losses $\ell_t$ are assumed to be a function of the Bregman divergence corresponding to $\phi$. The goal in an empirical risk minimization framework is to find $\phi$ to minimize this loss, i.e., $\min_{\phi \in \mathcal{F}} L(\phi)$, where $\mathcal{F}$ is the set of convex functions over which we are optimizing.

The above general loss can capture several learning problems. For instance, one can capture a regression setting, e.g., when the loss $\ell_t$ is the squared loss between the true Bregman divergence and the divergence given by the approximation. In the metric learning setting, one can utilize a loss function $\ell_t$ such as a triplet or contrastive loss, as is standard. In our experiments and generalization error analysis, we mainly consider a generalization of the triplet loss, where the loss is $\max(0, \alpha + D_\phi(x_{i_t}, x_{j_t}) - D_\phi(x_{k_t}, x_{l_t}))$ for a tuple $(x_{i_t}, x_{j_t}, x_{k_t}, x_{l_t})$; see Section 3.3 for details.
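As a small illustration of this relative-similarity loss, the sketch below (hypothetical names; `D` can be any divergence callable, such as the learned Bregman divergence) evaluates the hinge on quadruples and sums the empirical objective:

```python
# A sketch of the generalized triplet loss from Section 3.1.

def triplet_hinge(D, quad, alpha=1.0):
    """max(0, alpha + D(x_i, x_j) - D(x_k, x_l)) for a quadruple in which
    the pair (x_i, x_j) should be closer than the pair (x_k, x_l)."""
    x_i, x_j, x_k, x_l = quad
    return max(0.0, alpha + D(x_i, x_j) - D(x_k, x_l))

def empirical_loss(D, quads, alpha=1.0, reg_value=0.0):
    """L(phi) = sum_t l_t(D_phi, X, y) + lambda * r(phi); the regularizer is
    passed in as a precomputed value for simplicity."""
    return sum(triplet_hinge(D, q, alpha) for q in quads) + reg_value
```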
### 3.2 Convex piecewise linear fitting

Next we must appropriately parameterize $\phi$. We choose to parameterize our Bregman divergences using piecewise linear approximations. Piecewise linear functions are used in many different applications, such as global optimization [22], circuit modeling [17, 14], and convex regression [6, 2]. There are many methods for fitting piecewise linear functions, including using neural networks [12] and local linear fits on adaptively selected partitions of the data [15]; however, we are interested in formulating a convex optimization problem, as done in [21]. We use convex piecewise linear functions of the form
$$\mathcal{F}_{P,L} = \{h : \Omega \to \mathbb{R} \mid h(x) = \max_{k \in [K]} a_k^T x + b_k,\ \|a_k\|_1 \leq L\},$$
called max-affine functions. In our notation, $[K]$ denotes the set $\{1, \dots, K\}$. See the left plot of Figure 1 for a visualization of a max-affine function.

We stress that our goal is to approximate Bregman divergences, and as such strict convexity and differentiability are not required of the class of approximators when approximating an arbitrary Bregman divergence. Indeed, it is standard practice in learning theory to approximate a class of functions within a more tractable class. In particular, the use of piecewise linear functions has precedence in function approximation, and such functions have been used extensively for approximating convex functions (e.g., [1]). Conventional numerical schemes seek to approximate a function as a linear superposition of fixed basis functions (e.g., Bernstein polynomials). Our method could be directly extended to such basis functions and can be kernelized as well. Still, piecewise linear functions offer a benefit over linear superpositions: the max operator acts as a ridge function, resulting in significantly richer non-linear approximations.

In the next section we will discuss how to formulate optimization over $\mathcal{F}_{P,L}$ in order to minimize the loss function described earlier. In particular, the following lemma will allow us to express appropriate optimization problems using linear inequality constraints:

**Lemma 1** ([6]). There exists a convex function $\phi : \mathbb{R}^d \to \mathbb{R}$ that takes values
$$\phi(x_i) = z_i \tag{1}$$
if and only if there exist $a_1, \dots, a_n \in \mathbb{R}^d$ such that
$$z_i - z_j \geq a_j^T (x_i - x_j), \quad \forall i, j \in [n]. \tag{2}$$

*Proof.* Assuming such a $\phi$ exists, take $a_j$ to be any sub-gradient of $\phi$ at $x_j$; then (2) holds by convexity. Conversely, assuming (2) holds, define $\phi$ as
$$\phi(x) = \max_{i \in [n]} a_i^T (x - x_i) + z_i. \tag{3}$$
This $\phi$ is convex, being a maximum of linear functions, and $\phi(x_i) = z_i$ by (2).

As a direct consequence of Lemma 1, a Bregman divergence can take values
$$D_\phi(x_i, x_j) = z_i - z_j - a_j^T (x_i - x_j) \tag{4}$$
if and only if the conditions in (2) hold.

A key question is whether piecewise linear functions can approximate Bregman divergences well enough. An existing result in [1] says that for any $L$-Lipschitz convex function $\phi$ there exists a piecewise linear function $h \in \mathcal{F}_{P,L}$ such that $\|\phi - h\|_\infty \leq 36LRK^{-2/d}$, where $K$ is the number of hyperplanes and $R$ is the radius of the input space.
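Lemma 1 is easy to check numerically. The sketch below (our own illustration with hypothetical names) verifies the constraints (2) for given points, values, and sub-gradient candidates, and evaluates the convex interpolant of Eq. (3):

```python
import numpy as np

# A numerical check of Lemma 1: values z_i at points x_i extend to a convex
# function iff z_i - z_j >= a_j^T (x_i - x_j) for some vectors a_j; given
# such a_j, phi(x) = max_i a_i^T (x - x_i) + z_i interpolates the data.

def satisfies_lemma1(X, z, A, tol=1e-9):
    """Check z_i - z_j >= a_j^T (x_i - x_j) for all pairs (i, j)."""
    n = len(z)
    return all(z[i] - z[j] >= A[j] @ (X[i] - X[j]) - tol
               for i in range(n) for j in range(n))

def phi_from_certificate(x, X, z, A):
    """The convex interpolant of Eq. (3): max_i a_i^T (x - x_i) + z_i."""
    return np.max(A @ x - np.sum(A * X, axis=1) + z)

# Example: samples of phi(x) = ||x||^2 with gradients a_i = 2 x_i.
X = np.random.randn(5, 2)
z = np.sum(X**2, axis=1)
A = 2 * X
assert satisfies_lemma1(X, z, A)
print(phi_from_certificate(np.zeros(2), X, z, A))
```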
However, this existing result is not directly applicable to us, since a Bregman divergence utilizes the gradient $\nabla\phi$ of the convex function. As a result, in Section 3.4, we bound the gradient error $\|\nabla\phi - \nabla h\|$ of such approximators. This in turn allows us to prove a result demonstrating that we can approximate Bregman divergences with arbitrary accuracy under some regularity conditions.

### 3.3 Metric Learning Algorithm

We now briefly discuss algorithms for minimizing the loss functions described in the previous section. A standard metric learning scenario considers the case where the supervision is given as relative comparisons between objects. Suppose we observe $S_m = \{(x_{i_t}, x_{j_t}, x_{k_t}, x_{l_t}) \mid t \in [m]\}$, where $D(x_{i_t}, x_{j_t}) \leq D(x_{k_t}, x_{l_t})$ for some unknown similarity function $D$, and $(i_t, j_t, k_t, l_t)$ are indices of objects in a countable set $U$ (e.g., the set of people in a social network). To model the underlying similarity function $D$, we propose a Bregman divergence of the form
$$\hat{D}(x_i, x_j) \triangleq \hat\phi(x_i) - \hat\phi(x_j) - \nabla\hat\phi(x_j)^T (x_i - x_j), \tag{5}$$
where $\hat\phi(x) \triangleq \max_{i \in U_m} a_i^T (x - x_i) + z_i$, $\nabla\hat\phi$ denotes the largest sub-gradient of $\hat\phi$, $U_m \triangleq \bigcup_{t=1}^m \{i_t, j_t, k_t, l_t\}$ is the set of all observed object indices, and the $a_i$'s and $z_i$'s are the solution to the following linear program:
$$\begin{aligned}
\min_{z_i, a_i, L} \quad & \sum_{t=1}^m \max(\zeta_t, 0) + \lambda L \\
\text{s.t.} \quad & z_{i_t} - z_{j_t} - a_{j_t}^T (x_{i_t} - x_{j_t}) + z_{l_t} - z_{k_t} + a_{l_t}^T (x_{k_t} - x_{l_t}) \leq \zeta_t - 1, \quad \forall t \in [m], \\
& z_i - z_j - a_j^T (x_i - x_j) \geq 0, \quad \forall i, j \in U_m, \\
& \|a_i\|_1 \leq L, \quad \forall i \in U_m.
\end{aligned}$$
We refer to the solution of this optimization problem as PBDL (piecewise Bregman divergence learning). Note that one may consider other forms of supervision, such as pairwise similarity constraints, and these can be handled in an analogous manner. Also, for readability, the above algorithm is presented for the case where $K = n$; the case where $K < n$ is discussed in the supplementary material. In order to scale our method to large datasets, there are several possible approaches. One could apply ADMM to the above LP, which can be implemented in a distributed fashion or on GPUs.
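For concreteness, here is a sketch of the PBDL linear program in CVXPY; this is our own illustration under the formulation above (the paper does not prescribe a solver), and all names (`pbdl`, `quads`, `lam`) are hypothetical.

```python
import cvxpy as cp
import numpy as np

def pbdl(X, quads, lam=0.1):
    """Sketch of the PBDL LP. X: (n, d) array of the unique points indexed by
    U_m; quads: list of index tuples (i, j, k, l) with D(x_i,x_j) <= D(x_k,x_l)."""
    n, d = X.shape
    z = cp.Variable(n)
    A = cp.Variable((n, d))
    L = cp.Variable(nonneg=True)
    zeta = cp.Variable(len(quads))

    # Bounded sub-gradients: ||a_i||_1 <= L.
    cons = [cp.norm(A[i], 1) <= L for i in range(n)]
    # Convexity certificates from Lemma 1: z_i - z_j >= a_j^T (x_i - x_j).
    cons += [z[i] - z[j] >= A[j] @ (X[i] - X[j])
             for i in range(n) for j in range(n) if i != j]
    # Margin constraints: D_hat(x_i, x_j) - D_hat(x_k, x_l) <= zeta_t - 1.
    for t, (i, j, k, l) in enumerate(quads):
        Dij = z[i] - z[j] - A[j] @ (X[i] - X[j])
        Dkl = z[k] - z[l] - A[l] @ (X[k] - X[l])
        cons.append(Dij - Dkl <= zeta[t] - 1)

    obj = cp.Minimize(cp.sum(cp.pos(zeta)) + lam * L)
    cp.Problem(obj, cons).solve()
    return z.value, A.value
```

Note the $O(n^2)$ convexity constraints, which dominate the problem size as $n$ grows; this is precisely what motivates the distributed ADMM approach mentioned above.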
### 3.4 Analysis

Now we present an analysis of our approach. Due to space considerations, proofs appear in the supplementary material. Briefly, our results: i) show that a Bregman divergence parameterized by a piecewise linear convex function can approximate an arbitrary Bregman divergence with error $O(K^{-1/d})$, where $K$ is the number of affine functions; ii) bound the Rademacher complexity of the class of Bregman divergences parameterized by piecewise linear generating functions; iii) provide a generalization bound for Bregman metric learning showing that the generalization error gap shrinks as $O_p(m^{-1/2})$, where $m$ is the number of training points. In the supplementary material, we further provide additional generalization guarantees for learning Bregman divergences in the regression setting. In particular, it is worth noting that, in the regression setting, we provide a generalization bound of $O_p(m^{-1/(d+2)})$, which is comparable to the lower bound for convex regression, $O_p(m^{-4/(d+4)})$.

**Approximation Guarantees.** First we would like to bound how well one can approximate an arbitrary Bregman divergence when using a piecewise linear convex function. Besides providing a quantitative justification for using such generating functions, this result is also used for the later generalization bounds.

**Theorem 1.** Let $\phi : \Omega \to \mathbb{R}$ be any convex function which: 1) is defined on the $\infty$-norm ball, i.e., $B_\infty(R) = \{x \in \mathbb{R}^d : \|x\|_\infty \leq R\} \subseteq \Omega$; and 2) is $\beta$-smooth, i.e., $\|\nabla\phi(x) - \nabla\phi(y)\|_1 \leq \beta \|x - y\|_\infty$. Then there exists a max-affine function $h$ with $K$ hyperplanes such that:

1) it uniformly approximates $\phi$:
$$\sup_{x \in B(R)} |\phi(x) - h(x)| \leq 4\beta R^2 K^{-2/d}; \tag{6}$$

2) any of its sub-gradients $\nabla h(x) \in \partial h(x)$, away from the boundary of the norm ball, uniformly approximates $\nabla\phi(x)$:
$$\sup_{x \in B(R-\epsilon)} \|\nabla\phi(x) - \nabla h(x)\|_1 \leq 16\beta R K^{-1/d}; \tag{7}$$

3) the Bregman divergence parameterized by $h$, away from the boundary of the norm ball, uniformly approximates the Bregman divergence parameterized by $\phi$:
$$\sup_{x, x' \in B(R-\epsilon)} |D_\phi(x, x') - D_h(x, x')| \leq 36\beta R^2 K^{-1/d}. \tag{8}$$

**Rademacher Complexity.** Another result we require for proving the generalization error is the Rademacher complexity of the class of Bregman divergences using our choice of generating functions. We have the following result:

**Lemma 2.** The Rademacher complexity of Bregman divergences parameterized by max-affine functions, $R_m(\mathcal{D}_{P,L})$, is bounded by
$$R_m(\mathcal{D}_{P,L}) \leq 4KLR\sqrt{2\ln(2d+2)/m}.$$

**Generalization Error.** Finally, we consider the classification error when learning a Bregman divergence under relative similarity constraints. Our result bounds the loss on unseen data based on the loss on the training data. We require that the training data be drawn i.i.d. Note that while there are known methods to relax this assumption, as shown for Mahalanobis metric learning in [4], we assume here for simplicity that the data is drawn i.i.d. In particular, we assume that each instance is a quadruple, consisting of two pairs $(x_{i_t}, x_{j_t}, x_{k_t}, x_{l_t})$, drawn i.i.d. from some distribution $\mu$ over $\mathcal{X}^4$.

**Theorem 2.** Consider $S_m = \{(x_{i_t}, x_{j_t}, x_{k_t}, x_{l_t}),\ t \in [m]\} \sim \mu^m$, where $D(x_{i_t}, x_{j_t}) \leq D(x_{k_t}, x_{l_t})$. Set $R = \max_i \|x_i\|_\infty$. The generalization error of the learned divergence in (5) when using $K$ hyperplanes satisfies
$$\mathbb{E}\,\mathbb{1}\big[\hat{D}(x_i, x_j) \geq \hat{D}(x_k, x_l)\big] \leq \frac{1}{m} \sum_{t=1}^m \max\Big(0,\ 1 + \hat{D}(x_{i_t}, x_{j_t}) - \hat{D}(x_{k_t}, x_{l_t})\Big) + 8KLR\sqrt{2\ln(2d+2)/m} + \sqrt{\big(4\ln(4\log_2 L) + \ln(1/\delta)\big)/m}$$
with probability at least $1 - \delta$ over the draw of the data $S_m$. See the supplementary material for a proof.

**Discussion of Theorem 2.** Note that $n$ stands for the number of unique points appearing in the comparisons, whereas $m$ stands for the number of comparisons, i.e., $n = \#U_m$, so $n$ increases with $m$. Case 1: (K