Published as a conference paper at ICLR 2021

CONTINUOUS WASSERSTEIN-2 BARYCENTER ESTIMATION WITHOUT MINIMAX OPTIMIZATION

Alexander Korotin
Skolkovo Institute of Science and Technology, Advanced Data Analytics in Science and Engineering Group, Moscow, Russia
a.korotin@skoltech.ru

Lingxiao Li
Massachusetts Institute of Technology, Geometric Data Processing Group, Cambridge, Massachusetts, USA
lingxiao@mit.edu

Justin Solomon
Massachusetts Institute of Technology, Geometric Data Processing Group, Cambridge, Massachusetts, USA
jsolomon@mit.edu

Evgeny Burnaev
Skolkovo Institute of Science and Technology, Advanced Data Analytics in Science and Engineering Group, Moscow, Russia
e.burnaev@skoltech.ru

ABSTRACT

Wasserstein barycenters provide a geometric notion of the weighted average of probability measures based on optimal transport. In this paper, we present a scalable algorithm to compute Wasserstein-2 barycenters given sample access to the input measures, which are not restricted to being discrete. While past approaches rely on entropic or quadratic regularization, we employ input convex neural networks and cycle-consistency regularization to avoid introducing bias. As a result, our approach does not resort to minimax optimization. We provide theoretical analysis on error bounds as well as empirical evidence of the effectiveness of the proposed approach in low-dimensional qualitative scenarios and high-dimensional quantitative experiments.

1 INTRODUCTION

Wasserstein barycenters have become popular due to their ability to represent the average of probability measures in a geometrically meaningful way. Techniques for computing Wasserstein barycenters have been successfully applied to many computational problems. In image processing, Wasserstein barycenters are used for color and style transfer (Rabin et al., 2014; Mroueh, 2019) and texture synthesis (Rabin et al., 2011). In geometry processing, shape interpolation can be done by computing barycenters (Solomon et al., 2015). In online machine learning, barycenters are used for aggregating probabilistic predictions of experts (Korotin et al., 2019b). Within the context of Bayesian inference, the barycenter of subset posteriors converges to the full-data posterior, enabling efficient computational methods based on finding barycenters (Srivastava et al., 2015; 2018).

Fast and accurate barycenter algorithms exist for discrete distributions (see Peyré et al. (2019) for a survey), while for continuous distributions the situation is more difficult and remained largely unexplored until recently (Li et al., 2020; Fan et al., 2020; Cohen et al., 2020). The discrete methods scale poorly with the number of support points of the barycenter and thus cannot approximate continuous barycenters well, especially in high dimensions.

In this paper, we present a method to compute Wasserstein-2 barycenters of continuous distributions based on a novel regularized dual formulation in which the convex potentials are parameterized by input convex neural networks (Amos et al., 2017). Our algorithm is straightforward, introducing no bias (cf. Li et al. (2020)) and requiring no minimax optimization (cf. Fan et al. (2020)). This is made possible by combining a new congruence regularizing term with cycle-consistency regularization (Korotin et al., 2019a). As we show in the analysis, thanks to the properties of Wasserstein-2 distances, the gradients of the resulting convex potentials push the input distributions close to the true barycenter, allowing good approximation of the barycenter.
2 PRELIMINARIES

We denote the set of all Borel probability measures on $\mathbb{R}^D$ with finite second moment by $\mathcal{P}_2(\mathbb{R}^D)$. We use $\mathcal{P}_{2,ac}(\mathbb{R}^D)\subset\mathcal{P}_2(\mathbb{R}^D)$ to denote the subset of all absolutely continuous measures (w.r.t. the Lebesgue measure).

Wasserstein-2 distance. For $\mathbb{P},\mathbb{Q}\in\mathcal{P}_2(\mathbb{R}^D)$, the Wasserstein-2 distance is defined by
$$\mathcal{W}_2^2(\mathbb{P},\mathbb{Q})\stackrel{\text{def}}{=}\min_{\pi\in\Pi(\mathbb{P},\mathbb{Q})}\int_{\mathbb{R}^D\times\mathbb{R}^D}\frac{\|x-y\|^2}{2}\,d\pi(x,y),\qquad(1)$$
where $\Pi(\mathbb{P},\mathbb{Q})$ is the set of probability measures on $\mathbb{R}^D\times\mathbb{R}^D$ whose marginals are $\mathbb{P}$ and $\mathbb{Q}$, respectively. This definition is Kantorovich's primal form of the transport distance (Kantorovitch, 1958).

The Wasserstein-2 distance $\mathcal{W}_2$ is well studied in the theory of optimal transport (Brenier, 1991; McCann et al., 1995). In particular, it has a dual formulation (Villani, 2003):
$$\mathcal{W}_2^2(\mathbb{P},\mathbb{Q})=\int_{\mathbb{R}^D}\frac{\|x\|^2}{2}\,d\mathbb{P}(x)+\int_{\mathbb{R}^D}\frac{\|y\|^2}{2}\,d\mathbb{Q}(y)-\min_{\psi\in\mathrm{Conv}}\Big[\int_{\mathbb{R}^D}\psi(x)\,d\mathbb{P}(x)+\int_{\mathbb{R}^D}\overline{\psi}(y)\,d\mathbb{Q}(y)\Big],\qquad(2)$$
where the minimum is taken over all convex functions (potentials) $\psi:\mathbb{R}^D\to\mathbb{R}\cup\{\infty\}$, and $\overline{\psi}(y)=\max_{x\in\mathbb{R}^D}\big[\langle x,y\rangle-\psi(x)\big]:\mathbb{R}^D\to\mathbb{R}\cup\{\infty\}$ is the convex conjugate of $\psi$ (Fenchel, 1949), which is also a convex function. The optimal potential $\psi^*$ is defined up to an additive constant.

Brenier (1991) shows that if $\mathbb{P}$ does not give mass to sets of dimension at most $D-1$, then the optimal plan $\pi^*$ is uniquely determined by $\pi^*=[\mathrm{id}_{\mathbb{R}^D},T^*]\sharp\mathbb{P}$, where $T^*:\mathbb{R}^D\to\mathbb{R}^D$ is the unique solution to Monge's problem
$$T^*=\arg\min_{T\sharp\mathbb{P}=\mathbb{Q}}\int_{\mathbb{R}^D}\frac{\|x-T(x)\|^2}{2}\,d\mathbb{P}(x).\qquad(3)$$
The connection between $T^*$ and the dual formulation (2) is that $T^*=\nabla\psi^*$, where $\psi^*$ is the optimal solution of (2). Additionally, if $\mathbb{Q}$ does not give mass to sets of dimension at most $D-1$, then $T^*$ is invertible and
$$T^*(x)=\nabla\psi^*(x)=(\nabla\overline{\psi}{}^*)^{-1}(x),\qquad (T^*)^{-1}(y)=\nabla\overline{\psi}{}^*(y)=(\nabla\psi^*)^{-1}(y).$$
In particular, the above discussion applies to the case $\mathbb{P},\mathbb{Q}\in\mathcal{P}_{2,ac}(\mathbb{R}^D)$.

Wasserstein-2 barycenter. Let $\mathbb{P}_1,\dots,\mathbb{P}_N\in\mathcal{P}_{2,ac}(\mathbb{R}^D)$. Then their barycenter w.r.t. weights $\alpha_1,\dots,\alpha_N$ ($\alpha_n>0$ and $\sum_{n=1}^N\alpha_n=1$) is
$$\overline{\mathbb{P}}\stackrel{\text{def}}{=}\arg\min_{\mathbb{P}\in\mathcal{P}_2(\mathbb{R}^D)}\sum_{n=1}^N\alpha_n\mathcal{W}_2^2(\mathbb{P}_n,\mathbb{P}).\qquad(4)$$
Throughout this paper, we assume that at least one of $\mathbb{P}_1,\dots,\mathbb{P}_N\in\mathcal{P}_{2,ac}(\mathbb{R}^D)$ has bounded density. Under this assumption, $\overline{\mathbb{P}}$ is unique and absolutely continuous, i.e., $\overline{\mathbb{P}}\in\mathcal{P}_{2,ac}(\mathbb{R}^D)$, and it has bounded density (Agueh & Carlier, 2011, Definition 3.6 & Theorem 5.1).

For $n\in\{1,2,\dots,N\}$, let $(\psi_n^*,\overline{\psi}{}_n^*)$ be the optimal pair of (mutually) conjugate potentials that transport $\mathbb{P}_n$ to $\overline{\mathbb{P}}$, i.e., $\nabla\psi_n^*\sharp\mathbb{P}_n=\overline{\mathbb{P}}$ and $\nabla\overline{\psi}{}_n^*\sharp\overline{\mathbb{P}}=\mathbb{P}_n$. Then $\{\psi_n^*\}$ satisfy
$$\sum_{n=1}^N\alpha_n\nabla\overline{\psi}{}_n^*(x)=x\quad\text{and hence}\quad\sum_{n=1}^N\alpha_n\overline{\psi}{}_n^*(x)=\frac{\|x\|^2}{2}+c\qquad(5)$$
for all $x\in\mathbb{R}^D$ (Agueh & Carlier, 2011; Álvarez-Esteban et al., 2016). Since optimal potentials are defined up to an additive constant, for convenience we set $c=0$. The condition (5) serves as the basis for our algorithm for computing Wasserstein-2 barycenters. We say that potentials $\psi_1,\dots,\psi_N$ are congruent w.r.t. weights $\alpha_1,\dots,\alpha_N$ if their conjugate potentials satisfy (5), i.e., $\sum_{n=1}^N\alpha_n\overline{\psi}_n(x)=\frac{\|x\|^2}{2}$ for all $x\in\mathbb{R}^D$.
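As a concrete illustration of Brenier's theorem and the dual formulation (2), the optimal map between two non-degenerate Gaussians is available in closed form: $T^*(x)=\mu_{\mathbb{Q}}+A(x-\mu_{\mathbb{P}})$ with a symmetric positive definite matrix $A$, so that $T^*=\nabla\psi^*$ for the convex quadratic potential $\psi^*(x)=\tfrac12(x-\mu_{\mathbb{P}})^{T}A(x-\mu_{\mathbb{P}})+\langle\mu_{\mathbb{Q}},x\rangle$. The NumPy sketch below is an illustrative example (not part of the original implementation); function names are our own.

```python
import numpy as np
from scipy.linalg import sqrtm

def gaussian_ot_map(mean_p, cov_p, mean_q, cov_q):
    """Brenier map T* = grad psi* pushing N(mean_p, cov_p) onto N(mean_q, cov_q)."""
    cov_p_half = np.real(sqrtm(cov_p))
    cov_p_half_inv = np.linalg.inv(cov_p_half)
    # A = cov_p^{-1/2} (cov_p^{1/2} cov_q cov_p^{1/2})^{1/2} cov_p^{-1/2}, symmetric PSD
    A = cov_p_half_inv @ np.real(sqrtm(cov_p_half @ cov_q @ cov_p_half)) @ cov_p_half_inv
    return lambda x: mean_q + (x - mean_p) @ A  # A is symmetric

# Sanity check: pushing samples of P forward reproduces the moments of Q.
rng = np.random.default_rng(0)
mean_p, cov_p = np.zeros(2), np.array([[2.0, 0.3], [0.3, 0.5]])
mean_q, cov_q = np.ones(2), np.array([[1.0, -0.2], [-0.2, 1.5]])
T = gaussian_ot_map(mean_p, cov_p, mean_q, cov_q)
y = T(rng.multivariate_normal(mean_p, cov_p, size=50000))
print(np.allclose(y.mean(0), mean_q, atol=0.05),
      np.allclose(np.cov(y, rowvar=False), cov_q, atol=0.05))
```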
3 RELATED WORK

Most algorithms in the field of computational optimal transport are designed for the discrete setting, where the input distributions have finite support; see the recent survey by Peyré et al. (2019) for discussion. A particularly popular line of algorithms is based on entropic regularization, which gives rise to the well-known Sinkhorn iteration (Cuturi, 2013; Cuturi & Doucet, 2014). These methods are typically limited to a support of $10^5$-$10^6$ points before the problem becomes computationally infeasible. Similarly, discrete barycenter methods (Cuturi & Doucet, 2014), particularly the ones that rely on a fixed support for the barycenter (Dvurechenskii et al., 2018; Staib et al., 2017), cannot provide precise approximation of continuous barycenters in high dimensions, since a large number of samples is needed; see the experiments in Fan et al. (2020, §4.3) for an example. Thus we focus on the existing literature in the continuous setting.

Computation of Wasserstein-2 distances and maps. Genevay et al. (2016) demonstrate the possibility of computing Wasserstein distances given only sample access to the distributions by parameterizing the dual potentials as functions in reproducing kernel Hilbert spaces. Based on this realization, Seguy et al. (2017) propose a similar method but use neural networks to parameterize the potentials, using entropic or $L^2$ regularization w.r.t. $\mathbb{P}\times\mathbb{Q}$ to keep the potentials approximately conjugate. The transport map is recovered from the optimized potentials via barycentric projection. As we note in §2, $\mathcal{W}_2$ enjoys many useful theoretical properties. For example, the optimal potential $\psi^*$ is convex, and the corresponding optimal transport map is given by $\nabla\psi^*$. By exploiting these properties, Makkuva et al. (2019) propose a minimax optimization algorithm for recovering transport maps, using input convex neural networks (ICNNs) (Amos et al., 2017) to approximate the potentials. An alternative to entropic regularization is the cycle-consistency regularization proposed by Korotin et al. (2019a). It uses the property that the gradients of the optimal dual potentials are inverses of each other. The imposed regularizer requires integration only over the marginal measures $\mathbb{P}$ and $\mathbb{Q}$, instead of over $\mathbb{P}\times\mathbb{Q}$ as required by entropy-based alternatives. Their method converges faster than the minimax method since it does not have an inner optimization cycle. Xie et al. (2019) propose using two generative models with a shared latent space to implicitly compute the optimal transport correspondence between $\mathbb{P}$ and $\mathbb{Q}$. Based on the obtained correspondence, the authors are able to compute the optimal transport distance between the distributions.

Computation of Wasserstein-2 barycenters. A few recent techniques tackle the barycenter problem (4) using continuous rather than discrete approximations of the barycenter:

MEASURE-BASED (GENERATIVE) OPTIMIZATION: Problem (4) optimizes over probability measures. This can be done using the generic algorithm by Cohen et al. (2020), who employ generative networks to compute barycenters w.r.t. arbitrary discrepancies. They test their method with the maximum mean discrepancy (MMD) and the Sinkhorn divergence. This approach suffers from the usual limitations of generative models, such as mode collapse. Applying it to $\mathcal{W}_2$ barycenters requires estimation of $\mathcal{W}_2^2(\mathbb{P}_n,\mathbb{P})$. Fan et al. (2020) test this approach using the minimax method by Makkuva et al. (2019), but they end up with a challenging min-max-min problem.

POTENTIAL-BASED OPTIMIZATION: Li et al. (2020) recover the optimal potentials $\{\psi_n^*\}$ via a non-minimax regularized dual formulation. No generative model is needed: the barycenter is recovered by pushing forward measures using gradients of potentials or by barycentric projection. Inspired by Li et al. (2020), we use a potential-based approach and recover the barycenter by using gradients of the potentials as pushforward maps. The main differences are: (1) we restrict the potentials to be convex, (2) we enforce congruence via a regularizing term, and (3) our formulation does not introduce bias, meaning the optimal solution of our formulation gives the true barycenter.
4.1 DERIVING THE DUAL PROBLEM

Let $\overline{\mathbb{P}}$ be the true barycenter. Our goal is to recover the optimal potentials $\{\psi_n^*,\overline{\psi}{}_n^*\}$ mapping the input measures $\mathbb{P}_n$ into $\overline{\mathbb{P}}$. To start, we express the barycenter objective (4) after substituting the dual formulation (2):
$$\sum_{n=1}^N\alpha_n\mathcal{W}_2^2(\mathbb{P}_n,\overline{\mathbb{P}})=\sum_{n=1}^N\alpha_n\int_{\mathbb{R}^D}\frac{\|x\|^2}{2}\,d\mathbb{P}_n(x)+\int_{\mathbb{R}^D}\frac{\|y\|^2}{2}\,d\overline{\mathbb{P}}(y)-\min_{\{\psi_n\}\subset\mathrm{Conv}}\sum_{n=1}^N\alpha_n\Big[\int_{\mathbb{R}^D}\psi_n(x)\,d\mathbb{P}_n(x)+\int_{\mathbb{R}^D}\overline{\psi}_n(y)\,d\overline{\mathbb{P}}(y)\Big].\qquad(6)$$
The minimum is attained not just among convex potentials $\{\psi_n\}$, but among congruent potentials (see the discussion under (5)); thus, we can add the constraint that $\{\psi_n\}$ are congruent to (6). Hence,
$$\sum_{n=1}^N\alpha_n\mathcal{W}_2^2(\mathbb{P}_n,\overline{\mathbb{P}})=\sum_{n=1}^N\alpha_n\int_{\mathbb{R}^D}\frac{\|x\|^2}{2}\,d\mathbb{P}_n(x)-\min_{\{\psi_n\}\ \text{congruent}}\underbrace{\sum_{n=1}^N\alpha_n\int_{\mathbb{R}^D}\psi_n(x)\,d\mathbb{P}_n(x)}_{\mathrm{MultiCorr}(\{\alpha_n,\mathbb{P}_n\}\,|\,\{\psi_n\})}.\qquad(7)$$
To transition from (6) to (7), we used the fact that for congruent $\{\psi_n\}$ we have $\sum_{n=1}^N\alpha_n\overline{\psi}_n(y)=\frac{\|y\|^2}{2}$, so $\sum_{n=1}^N\alpha_n\int_{\mathbb{R}^D}\overline{\psi}_n(y)\,d\overline{\mathbb{P}}(y)=\int_{\mathbb{R}^D}\frac{\|y\|^2}{2}\,d\overline{\mathbb{P}}(y)$, which cancels the second term of (6). We call the value inside the minimum in (7) the multiple correlation of $\{\mathbb{P}_n\}$ with weights $\{\alpha_n\}$ w.r.t. potentials $\{\psi_n\}$. Notice that the true barycenter $\overline{\mathbb{P}}$ appears nowhere on the right side of (7). Thus the optimal potentials $\{\psi_n^*\}$ can be recovered by solving
$$\min_{\{\psi_n\}\ \text{congruent}}\mathrm{MultiCorr}(\{\alpha_n,\mathbb{P}_n\}\,|\,\{\psi_n\})=\min_{\{\psi_n\}\ \text{congruent}}\sum_{n=1}^N\alpha_n\int_{\mathbb{R}^D}\psi_n(x)\,d\mathbb{P}_n(x).\qquad(8)$$

4.2 IMPOSING THE CONGRUENCE CONDITION

It is challenging to impose the congruence condition on convex potentials. What if we relax the congruence condition? The following theorem bounds how close a set of convex potentials $\{\psi_n\}$ is to $\{\psi_n^*\}$ in terms of the difference of multiple correlations.

Theorem 4.1. Let $\overline{\mathbb{P}}\in\mathcal{P}_{2,ac}(\mathbb{R}^D)$ be the barycenter of $\mathbb{P}_1,\dots,\mathbb{P}_N\in\mathcal{P}_{2,ac}(\mathbb{R}^D)$ w.r.t. weights $\alpha_1,\dots,\alpha_N$. Let $\{\psi_n^*\}$ be the optimal congruent potentials of the barycenter problem. Suppose we have $B$-smooth¹ convex potentials $\{\psi_n\}$ for some $B\in(0,+\infty]$, and denote $\Delta=\mathrm{MultiCorr}(\{\alpha_n,\mathbb{P}_n\}\,|\,\{\psi_n\})-\mathrm{MultiCorr}(\{\alpha_n,\mathbb{P}_n\}\,|\,\{\psi_n^*\})$. Then
$$\Delta+\underbrace{\int_{\mathbb{R}^D}\Big[\sum_{n=1}^N\alpha_n\overline{\psi}_n(y)-\frac{\|y\|^2}{2}\Big]\,d\overline{\mathbb{P}}(y)}_{\text{Congruence mismatch}}\ \geq\ \sum_{n=1}^N\frac{\alpha_n}{2B}\,\|\nabla\psi_n^*-\nabla\psi_n\|^2_{\mathbb{P}_n}.\qquad(9)$$
Here $\|\cdot\|_{\mu}$ denotes the norm induced by the inner product of the Hilbert space $\mathcal{L}^2(\mathbb{R}^D\to\mathbb{R}^D,\mu)$. We call the second term on the left of (9) the congruence mismatch. We prove this in Appendix B. Note that if the congruence mismatch is non-positive, then
$$\Delta\ \geq\ \sum_{n=1}^N\frac{\alpha_n}{2B}\,\|\nabla\psi_n^*-\nabla\psi_n\|^2_{\mathbb{P}_n}\ \geq\ \frac{1}{B}\sum_{n=1}^N\alpha_n\mathcal{W}_2^2(\nabla\psi_n\sharp\mathbb{P}_n,\overline{\mathbb{P}}),\qquad(10)$$
where the last inequality of (10) follows from (Korotin et al., 2019a, Lemma A.2). From (10), we conclude that for all $n\in\{1,\dots,N\}$ we have $\mathcal{W}_2^2(\nabla\psi_n\sharp\mathbb{P}_n,\overline{\mathbb{P}})\leq\frac{B\Delta}{\alpha_n}$. This shows that if the congruence mismatch is non-positive, then $\Delta$, the difference in multiple correlations, provides an upper bound (up to the factor $B/\alpha_n$) for the Wasserstein-2 distance between the true barycenter and each pushforward $\nabla\psi_n\sharp\mathbb{P}_n$. This justifies the use of $\nabla\psi_n\sharp\mathbb{P}_n$ to recover the barycenter.

¹We say that a differentiable function $f:\mathbb{R}^D\to\mathbb{R}$ is $B$-smooth if its gradient $\nabla f$ is $B$-Lipschitz.

Notice that for optimal potentials the congruence mismatch is zero. Thus, to penalize positive congruence mismatch, we introduce a regularizing term
$$\mathcal{R}_1^{\overline{\mathbb{P}}}(\{\alpha_n\},\{\psi_n\})\stackrel{\text{def}}{=}\int_{\mathbb{R}^D}\Big[\sum_{n=1}^N\alpha_n\overline{\psi}_n(y)-\frac{\|y\|^2}{2}\Big]_+\,d\overline{\mathbb{P}}(y).\qquad(11)$$
Because we take the positive part of the integrand of (9) to get (11) and the right side of (9) is non-negative, we have
$$\mathrm{MultiCorr}(\{\alpha_n,\mathbb{P}_n\}\,|\,\{\psi_n\})+\mathcal{R}_1^{\overline{\mathbb{P}}}(\{\alpha_n\},\{\psi_n\})-\mathrm{MultiCorr}(\{\alpha_n,\mathbb{P}_n\}\,|\,\{\psi_n^*\})\geq 0$$
for all convex potentials $\{\psi_n\}$. On the other hand, for the optimal potentials $\{\psi_n\}=\{\psi_n^*\}$ the inequality turns into an equality, implying that adding the regularizing term $\mathcal{R}_1^{\overline{\mathbb{P}}}(\{\alpha_n\},\{\psi_n\})$ to (8) does not introduce bias: the optimal solution still yields $\{\psi_n^*\}$.

However, evaluating (11) exactly requires knowing the true barycenter $\overline{\mathbb{P}}$ a priori. To remedy this issue, one may replace $\overline{\mathbb{P}}$ with another absolutely continuous measure $\tau\widehat{\mathbb{P}}$ ($\tau\geq 1$ and $\widehat{\mathbb{P}}$ a probability measure) whose density bounds that of $\overline{\mathbb{P}}$ from above almost everywhere. In this case,
$$\tau\mathcal{R}_1^{\widehat{\mathbb{P}}}(\{\alpha_n\},\{\psi_n\})=\tau\int_{\mathbb{R}^D}\Big[\sum_{n=1}^N\alpha_n\overline{\psi}_n(y)-\frac{\|y\|^2}{2}\Big]_+\,d\widehat{\mathbb{P}}(y)\ \geq\ \mathcal{R}_1^{\overline{\mathbb{P}}}(\{\alpha_n\},\{\psi_n\}).\qquad(12)$$
Hence we obtain the following regularized version of (8), for which $\{\psi_n^*\}$ is still the optimal solution:
$$\min_{\{\psi_n\}\subset\mathrm{Conv}}\Big[\mathrm{MultiCorr}(\{\alpha_n,\mathbb{P}_n\}\,|\,\{\psi_n\})+\tau\mathcal{R}_1^{\widehat{\mathbb{P}}}(\{\alpha_n\},\{\psi_n\})\Big].\qquad(13)$$
Selecting a measure $\tau\widehat{\mathbb{P}}$ is not obvious. Consider the case when the $\{\mathbb{P}_n\}$ are supported on compact sets $X_1,\dots,X_N\subset\mathbb{R}^D$ and $\mathbb{P}_1$ has density upper bounded by $h<\infty$. In this scenario, the barycenter density is upper bounded by $h\,\alpha_1^{-D}$ (Álvarez-Esteban et al., 2016, Remark 3.2). Thus, the measure $\tau\widehat{\mathbb{P}}$ supported on $\mathrm{ConvexHull}(X_1,\dots,X_N)$ with this density is an upper bound for $\overline{\mathbb{P}}$. We address the question of how to choose $\tau,\widehat{\mathbb{P}}$ properly in practice in §4.4; a sketch of estimating (12) from samples is given at the end of this subsection.
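To make (11)-(12) concrete, the following PyTorch sketch shows a Monte Carlo estimator of the congruence regularizer. The interface is an illustrative assumption: `conj_potentials` stands for any differentiable models of the conjugate potentials $\overline{\psi}_n$ (e.g., the ICNNs introduced in §4.3).

```python
import torch

def congruence_regularizer(conj_potentials, alphas, y, tau=1.0):
    """Monte Carlo estimate of tau * R_1^{P_hat} from (12).

    conj_potentials : list of callables, conj_potentials[n](y) -> (K, 1) tensor,
                      modeling the conjugate potentials psi_bar_n.
    alphas          : list of floats, barycenter weights summing to one.
    y               : (K, D) tensor of samples drawn from the proposal measure P_hat.
    """
    # sum_n alpha_n * psi_bar_n(y), shape (K, 1)
    weighted_sum = sum(a * psi_bar(y) for a, psi_bar in zip(alphas, conj_potentials))
    # positive part of the congruence-mismatch integrand
    integrand = torch.relu(weighted_sum - 0.5 * (y ** 2).sum(dim=1, keepdim=True))
    return tau * integrand.mean()
```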
4.3 ENFORCING CONJUGACY OF POTENTIAL PAIRS

Throughout this subsection, we assume that the upper-bounding finite measure $\tau\widehat{\mathbb{P}}$ of the barycenter $\overline{\mathbb{P}}$ is known. The optimization problem (13) involves not only the potentials $\{\psi_n\}$ but also their conjugates $\{\overline{\psi}_n\}$. This brings practical difficulty, since evaluating conjugate potentials exactly is hard (Korotin et al., 2019a). Instead, we parameterize the potentials $\psi_n$ and $\overline{\psi}_n$ separately using input convex neural networks (ICNNs). We add a cycle-consistency regularizer to enforce the conjugacy of $\psi_n$ and $\overline{\psi}_n$, as in Korotin et al. (2019a). This regularizer is defined as
$$\mathcal{R}_2^{\mathbb{P}_n}(\psi_n,\overline{\psi}_n)\stackrel{\text{def}}{=}\int_{\mathbb{R}^D}\big\|\nabla\overline{\psi}_n(\nabla\psi_n(x))-x\big\|^2\,d\mathbb{P}_n(x)=\big\|\nabla\overline{\psi}_n\circ\nabla\psi_n-\mathrm{id}_{\mathbb{R}^D}\big\|^2_{\mathbb{P}_n}.$$
Note that $\mathcal{R}_2^{\mathbb{P}_n}(\psi_n,\overline{\psi}_n)=0$ is a necessary condition for $\psi_n$ and $\overline{\psi}_n$ to be conjugate to each other. It is also a sufficient condition for convex functions to be conjugates up to an additive constant. We use one-sided regularization: computing the regularizer in the other direction, $\|\nabla\psi_n\circ\nabla\overline{\psi}_n-\mathrm{id}_{\mathbb{R}^D}\|^2_{\overline{\mathbb{P}}}$, is infeasible since $\overline{\mathbb{P}}$ is unknown. In fact, Korotin et al. (2019a) demonstrate that such a one-sided condition is sufficient.

In this way, we use $2N$ input convex neural networks for $\{\psi_n,\overline{\psi}_n\}$ (a minimal sketch of such a network is given after (14) below). By adding the new cycle-consistency regularizer to (13), we obtain our final objective:
$$\min_{\{\psi_n,\overline{\psi}_n\}\subset\mathrm{Conv}}\ \underbrace{\sum_{n=1}^N\alpha_n\int_{\mathbb{R}^D}\big[\langle x,\nabla\psi_n(x)\rangle-\overline{\psi}_n(\nabla\psi_n(x))\big]\,d\mathbb{P}_n(x)}_{\text{Approximate multiple correlation}}+\underbrace{\tau\mathcal{R}_1^{\widehat{\mathbb{P}}}(\{\alpha_n\},\{\overline{\psi}_n\})}_{\text{Congruence regularizer}}+\underbrace{\lambda\sum_{n=1}^N\alpha_n\mathcal{R}_2^{\mathbb{P}_n}(\psi_n,\overline{\psi}_n)}_{\text{Cycle regularizer}}.\qquad(14)$$
Note that we express the approximate multiple correlation by using both potentials $\{\psi_n\}$ and $\{\overline{\psi}_n\}$: by Fenchel's equality, $\langle x,\nabla\psi_n(x)\rangle-\overline{\psi}_n(\nabla\psi_n(x))=\psi_n(x)$ whenever the pair is exactly conjugate. This removes the dependence on the additive constant of $\{\psi_n\}$, which is not controlled by the cycle regularization. We denote the entire objective as $\mathrm{MultiCorr}\big(\{\alpha_n,\mathbb{P}_n\}\,|\,\{\psi_n\},\{\overline{\psi}_n\};\tau,\widehat{\mathbb{P}},\lambda\big)$.
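The sketch below is a minimal PyTorch ICNN in the spirit of Amos et al. (2017), with a convex quadratic skip connection. It is an illustration only: the layer sizes, the weight-constraint scheme, and the `push` helper are our own choices, not the exact DenseICNN architecture used in the experiments (Appendix C.4).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ICNN(nn.Module):
    """Minimal input convex neural network: the output is convex in x because the
    activation is convex and non-decreasing and the weights applied to the hidden
    (already convex) units are kept non-negative."""
    def __init__(self, dim, hidden=(64, 64)):
        super().__init__()
        self.lin_x = nn.ModuleList([nn.Linear(dim, h) for h in hidden])  # unconstrained
        self.w_z = nn.ParameterList(
            [nn.Parameter(1e-2 * torch.rand(hidden[i + 1], hidden[i]))
             for i in range(len(hidden) - 1)])  # kept >= 0 via relu below
        self.w_out = nn.Parameter(1e-2 * torch.rand(1, hidden[-1]))
        self.lin_out = nn.Linear(dim, 1)
        self.quad = nn.Linear(dim, dim, bias=False)  # convex quadratic skip connection

    def forward(self, x):
        z = F.softplus(self.lin_x[0](x))
        for lin, w in zip(self.lin_x[1:], self.w_z):
            z = F.softplus(lin(x) + F.linear(z, torch.relu(w)))
        return (self.lin_out(x) + F.linear(z, torch.relu(self.w_out))
                + 0.5 * (self.quad(x) ** 2).sum(dim=1, keepdim=True))

    def push(self, x):
        """Pushforward map x -> grad psi(x), computed with autograd."""
        if not x.requires_grad:
            x = x.detach().requires_grad_(True)
        (grad,) = torch.autograd.grad(self.forward(x).sum(), x, create_graph=True)
        return grad
```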
Analogous to Theorem 4.1, we have the following result, showing that this new objective enjoys the same properties as the unregularized version from (8).

Theorem 4.2. Let $\overline{\mathbb{P}}\in\mathcal{P}_{2,ac}(\mathbb{R}^D)$ be the barycenter of $\mathbb{P}_1,\dots,\mathbb{P}_N\in\mathcal{P}_{2,ac}(\mathbb{R}^D)$ w.r.t. weights $\alpha_1,\dots,\alpha_N$. Let $\{\psi_n^*\}$ be the optimal congruent potentials of the barycenter problem. Suppose we have $\tau,\widehat{\mathbb{P}}$ such that $\tau\geq 1$ and the density of $\tau\widehat{\mathbb{P}}$ upper-bounds that of $\overline{\mathbb{P}}$. Suppose we have convex potentials $\{\psi_n\}$ and $\overline{\beta}$-strongly convex, $\overline{B}$-smooth convex potentials $\{\overline{\psi}_n\}$ with $0<\overline{\beta}\leq\overline{B}<\infty$ and $\lambda>\frac{\overline{B}}{2\overline{\beta}^2}$. Then
$$\mathrm{MultiCorr}\big(\{\alpha_n,\mathbb{P}_n\}\,|\,\{\psi_n\},\{\overline{\psi}_n\};\tau,\widehat{\mathbb{P}},\lambda\big)\ \geq\ \mathrm{MultiCorr}\big(\{\alpha_n,\mathbb{P}_n\}\,|\,\{\psi_n^*\}\big).\qquad(15)$$
Denote $\epsilon=\mathrm{MultiCorr}\big(\{\alpha_n,\mathbb{P}_n\}\,|\,\{\psi_n\},\{\overline{\psi}_n\};\tau,\widehat{\mathbb{P}},\lambda\big)-\mathrm{MultiCorr}\big(\{\alpha_n,\mathbb{P}_n\}\,|\,\{\psi_n^*\}\big)$. Then for all $n\in\{1,\dots,N\}$ we have
$$\mathcal{W}_2^2\big(\nabla\psi_n\sharp\mathbb{P}_n,\overline{\mathbb{P}}\big)=O(\epsilon).\qquad(16)$$

Informally, Theorem 4.2 states that the better we solve the regularized dual problem (14), the closer we expect each $\nabla\psi_n\sharp\mathbb{P}_n$ to be to the true barycenter $\overline{\mathbb{P}}$ in $\mathcal{W}_2$. It follows from (15) that our final objective (14) is unbiased: the optimum is attained at $\{\psi_n^*,\overline{\psi}{}_n^*\}$.

4.4 PRACTICAL ASPECTS AND OPTIMIZATION PROCEDURE

In practice, even if the choice of $\tau,\widehat{\mathbb{P}}$ does not satisfy $\tau\widehat{\mathbb{P}}\geq\overline{\mathbb{P}}$, we observe that the pushforward measures $\nabla\psi_n\sharp\mathbb{P}_n$ often converge to $\overline{\mathbb{P}}$. To partially bridge the gap between theory and practice, we dynamically update the measure $\widehat{\mathbb{P}}$: after each optimization step we set (for $\gamma\in[0,1]$)
$$\widehat{\mathbb{P}}:=\gamma\,\widehat{\mathbb{P}}+(1-\gamma)\sum_{n=1}^N\alpha_n\,\nabla\psi_n\sharp\mathbb{P}_n,$$
i.e., the probability measure $\widehat{\mathbb{P}}$ is a mixture of the given initial measure $\widehat{\mathbb{P}}$ and the current barycenter estimates $\{\nabla\psi_n\sharp\mathbb{P}_n\}$. For the initial $\widehat{\mathbb{P}}$ one may use the barycenter of $\{\mathcal{N}(\mu_{\mathbb{P}_n},\Sigma_{\mathbb{P}_n})\}$, which can be computed efficiently via an iterative fixed-point algorithm (Álvarez-Esteban et al., 2016; Chewi et al., 2020). During the optimization, these estimates become closer to the true barycenter and can thus improve the congruence regularizer (12).

We use mini-batch stochastic gradient descent to solve (14), where the integration is done by Monte Carlo sampling from the input measures $\{\mathbb{P}_n\}$ and the regularization measure $\widehat{\mathbb{P}}$, similar to Li et al. (2020). We provide the detailed optimization procedure (Algorithm 1) and discuss its computational complexity in Appendix A. In Appendix C.3, we demonstrate the impact of the considered regularization on our model: we show that cycle consistency and the congruence condition of the potentials are well satisfied.

5 EXPERIMENTS

The code is written using the PyTorch framework and is publicly available at https://github.com/iamalexkorotin/Wasserstein2Barycenters. We compare our method [CW2B] with the potential-based method [CRWB] by Li et al. (2020) (with Wasserstein-2 distance and L2 regularization) and with the measure-based generative method [SCW2B] by Fan et al. (2020). All considered methods recover $2N$ potentials $\{\psi_n,\overline{\psi}_n\}\approx\{\psi_n^*,\overline{\psi}{}_n^*\}$ and approximate the barycenter as pushforward measures $\{\nabla\psi_n\sharp\mathbb{P}_n\}$. Regularization in [CRWB] allows access to the joint density of the transport plan, a feature of their method that we do not consider here. The method [SCW2B] additionally outputs a generated barycenter $g\sharp\mathbb{S}\approx\overline{\mathbb{P}}$, where $g$ is the generative network and $\mathbb{S}$ is the input noise distribution.

To assess the quality of a computed barycenter $\widetilde{\mathbb{P}}$, we consider the unexplained variance percentage, defined as $\mathrm{UVP}(\widetilde{\mathbb{P}})=100\cdot\mathcal{W}_2^2(\widetilde{\mathbb{P}},\overline{\mathbb{P}})\,/\,\big(\tfrac12\mathrm{Var}(\overline{\mathbb{P}})\big)\,\%$. When $\mathrm{UVP}\approx 0\%$, $\widetilde{\mathbb{P}}$ is a good approximation of $\overline{\mathbb{P}}$. For values around $100\%$, the distribution $\widetilde{\mathbb{P}}$ is undesirable: the trivial baseline $\mathbb{P}_0=\delta_{\mathbb{E}_{\overline{\mathbb{P}}}[y]}$ achieves $\mathrm{UVP}(\mathbb{P}_0)=100\%$. Evaluating UVP in high dimensions is infeasible: empirical estimates of $\mathcal{W}_2^2$ are unreliable due to high sample complexity (Weed et al., 2019). To overcome this issue, for barycenters given by $\nabla\psi_n\sharp\mathbb{P}_n$ we use L2-UVP, defined by
$$\text{L2-UVP}(\nabla\psi_n,\mathbb{P}_n)\stackrel{\text{def}}{=}100\cdot\frac{\|\nabla\psi_n-\nabla\psi_n^*\|^2_{\mathbb{P}_n}}{\mathrm{Var}(\overline{\mathbb{P}})}\,\%\ \ \big[\geq\mathrm{UVP}(\nabla\psi_n\sharp\mathbb{P}_n)\big],\qquad(17)$$
where the inequality in brackets follows from (Korotin et al., 2019a, Lemma A.2). We report the weighted average of the L2-UVP of all pushforward measures w.r.t. the weights $\alpha_n$.
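A minimal sketch of how (17) can be estimated from samples is given below. It assumes access to the learned map (e.g., the `push` method of an ICNN as sketched in §4.3) and to the ground-truth optimal map, which is available in closed form for the location-scatter experiments of §5.1; the function name and interface are illustrative.

```python
def l2_uvp(learned_map, true_map, x_samples, var_bar):
    """Monte Carlo estimate of L2-UVP from (17), in percent.

    learned_map : callable, x -> grad psi_n(x), the learned pushforward map.
    true_map    : callable, x -> grad psi_n^*(x), ground-truth OT map to the barycenter.
    x_samples   : (K, D) tensor of samples from P_n.
    var_bar     : float, total variance Var(P_bar) of the true barycenter.
    """
    sq_err = ((learned_map(x_samples) - true_map(x_samples)) ** 2).sum(dim=1)
    return 100.0 * sq_err.mean().item() / var_bar
```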
For barycenters given in an implicit form $g\sharp\mathbb{S}$, we compute the Bures-Wasserstein UVP, defined by
$$\mathrm{BW}_2^2\text{-UVP}(g\sharp\mathbb{S})\stackrel{\text{def}}{=}100\cdot\frac{\mathrm{BW}_2^2(g\sharp\mathbb{S},\overline{\mathbb{P}})}{\tfrac12\mathrm{Var}(\overline{\mathbb{P}})}\,\%\ \ \big[\leq\mathrm{UVP}(g\sharp\mathbb{S})\big],\qquad(18)$$
where $\mathrm{BW}_2^2(\mathbb{P},\mathbb{Q})=\mathcal{W}_2^2\big(\mathcal{N}(\mu_{\mathbb{P}},\Sigma_{\mathbb{P}}),\mathcal{N}(\mu_{\mathbb{Q}},\Sigma_{\mathbb{Q}})\big)$ is the Bures-Wasserstein metric and we use $\mu_{\mathbb{P}},\Sigma_{\mathbb{P}}$ to denote the mean and the covariance of a distribution $\mathbb{P}$ (Chewi et al., 2020). It is known that $\mathrm{BW}_2^2$ lower-bounds $\mathcal{W}_2^2$ (Dowson & Landau, 1982), so the inequality in the brackets of (18) follows. A detailed discussion of the adopted metrics is given in Appendix C.2.

5.1 HIGH-DIMENSIONAL LOCATION-SCATTER EXPERIMENTS

Figure 1: Barycenter of a location-scatter Swiss roll population computed by three methods. Panels: (a) input distributions $\{\mathbb{P}_n\}$; (b) true barycenter $\overline{\mathbb{P}}$; (c) SCW2B, generated distribution $g\sharp\mathbb{S}$; (d) SCW2B, distributions $\nabla\psi_n\sharp\mathbb{P}_n$; (e) CRWB, distributions $\nabla\psi_n\sharp\mathbb{P}_n$; (f) CW2B, distributions $\nabla\psi_n\sharp\mathbb{P}_n$.

In this section, we consider $N=4$ with $(\alpha_1,\dots,\alpha_4)=(0.1,0.2,0.3,0.4)$ as weights. We consider the location-scatter family of distributions (Álvarez-Esteban et al., 2016, §4), whose true barycenter can be computed. Let $\mathbb{P}_0\in\mathcal{P}_{2,ac}$ and define the location-scatter family $\mathcal{F}(\mathbb{P}_0)=\{f_{S,u}\sharp\mathbb{P}_0\,|\,S\in\mathcal{M}^+_{D\times D},\,u\in\mathbb{R}^D\}$, where $f_{S,u}:\mathbb{R}^D\to\mathbb{R}^D$ is the linear map $f_{S,u}(x)=Sx+u$ with positive definite matrix $S\in\mathcal{M}^+_{D\times D}$. When $\{\mathbb{P}_n\}\subset\mathcal{F}(\mathbb{P}_0)$, their barycenter $\overline{\mathbb{P}}$ is also an element of $\mathcal{F}(\mathbb{P}_0)$ and can be computed via fixed-point iterations (Álvarez-Esteban et al., 2016); a minimal sketch of this fixed-point procedure is given below.

Figure 1a shows a 2-dimensional location-scatter family generated by using the Swiss roll distribution as $\mathbb{P}_0$. The true barycenter is shown in Figure 1b. The generated barycenter $g\sharp\mathbb{S}$ of [SCW2B] is given in Figure 1c. The pushforward measures $\nabla\psi_n\sharp\mathbb{P}_n$ of each method are provided in Figures 1d, 1e and 1f, respectively. In this example, the pushforward measures $\nabla\psi_n\sharp\mathbb{P}_n$ of all methods reasonably approximate $\overline{\mathbb{P}}$, whereas the generated barycenter $g\sharp\mathbb{S}$ of [SCW2B] (Figure 1c) visibly underfits.

For quantitative comparison, we consider two choices of $\mathbb{P}_0$: the $D$-dimensional standard Gaussian distribution and the uniform distribution on $[-\sqrt{3},\sqrt{3}]^D$. Each $\mathbb{P}_n$ is constructed as $f_{S_n^T\Lambda S_n,\,0}\sharp\mathbb{P}_0\in\mathcal{F}(\mathbb{P}_0)$, where $S_n$ is a random rotation matrix and $\Lambda$ is diagonal with entries $\tfrac12 b^0,\tfrac12 b^1,\dots,\tfrac12 b^{D-1}=2$, where $b=4^{1/(D-1)}$. We consider only centered distributions (i.e., zero mean), because the barycenter of non-centered $\{\mathbb{P}_n\}\subset\mathcal{P}_{2,ac}(\mathbb{R}^D)$ is the barycenter of $\{\mathbb{P}_n'\}$ shifted by $\sum_{n=1}^N\alpha_n\mu_{\mathbb{P}_n}$, where $\{\mathbb{P}_n'\}$ are centered copies of $\{\mathbb{P}_n\}$ (Álvarez-Esteban et al., 2016).

Results are shown in Tables 1 and 2. In these experiments, our method outperforms [CRWB] and [SCW2B]. For [CRWB], dimension 16 is the breaking point: the method does not scale well to higher dimensions. [SCW2B] scales better with increasing dimension, but its L2-UVP and BW2²-UVP errors are twice as high as ours. This is likely due to the generative approximation and the difficult min-max-min optimization in [SCW2B]. For completeness, we also compare our algorithm to the method proposed in Cuturi & Doucet (2014) (denoted [FCWB] in the tables), which approximates the barycenter by a discrete distribution on a fixed number of free-support points. In our experiment, similar to Li et al. (2020), we set the support size to 5000. As expected, the BW2²-UVP error of this method increases drastically as the dimension grows, and the method is outperformed by our approach.

To show the scalability of our method with the number of input distributions $N$, we conduct an analogous experiment with a high-dimensional location-scatter family for $N=20$.
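Throughout this subsection, the ground-truth barycenters are obtained with the fixed-point iteration of Álvarez-Esteban et al. (2016). The NumPy sketch below covers the zero-mean Gaussian case, whose covariance fixed point also underlies the location-scatter ground truth per Álvarez-Esteban et al. (2016, §4); the function name, initialization, and fixed iteration count are our own illustrative choices, and convergence checks are omitted.

```python
import numpy as np
from scipy.linalg import sqrtm

def bures_barycenter_cov(covs, alphas, iters=100):
    """Fixed-point iteration for the covariance of the W2 barycenter of
    zero-mean Gaussians N(0, covs[n]) with weights alphas."""
    S = sum(a * C for a, C in zip(alphas, covs))  # SPD initial guess
    for _ in range(iters):
        S_half = np.real(sqrtm(S))
        S_half_inv = np.linalg.inv(S_half)
        # T = sum_n alpha_n (S^{1/2} Sigma_n S^{1/2})^{1/2}
        T = sum(a * np.real(sqrtm(S_half @ C @ S_half)) for a, C in zip(alphas, covs))
        S = S_half_inv @ T @ T @ S_half_inv  # S <- S^{-1/2} T^2 S^{-1/2}
    return S

# Example: barycenter covariance of two 2-D Gaussians with equal weights.
covs = [np.array([[2.0, 0.0], [0.0, 0.5]]), np.array([[0.5, 0.0], [0.0, 2.0]])]
print(bures_barycenter_cov(covs, [0.5, 0.5]))
```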
We set $\alpha_n=\frac{2n}{N(N+1)}$ for $n=1,2,\dots,20$, choose the uniform distribution on $[-\sqrt{3},\sqrt{3}]^D$ as $\mathbb{P}_0$, and construct the distributions $\mathbb{P}_n\in\mathcal{F}(\mathbb{P}_0)$ as before. The results for dimensions 32, 64 and 128 are provided in Table 3. Similar to the results of Tables 1 and 2, we see that our method outperforms the alternatives.

Table 1: Comparison of UVP for the case $\{\mathbb{P}_n\}\subset\mathcal{F}(\mathbb{P}_0)$, $\mathbb{P}_0=\mathcal{N}(0,I_D)$, $N=4$.

| Metric | Method | D=2 | 4 | 8 | 16 | 32 | 64 | 128 | 256 |
|---|---|---|---|---|---|---|---|---|---|
| BW2²-UVP, % | [FCWB], Cuturi & Doucet (2014) | 0.7 | 0.68 | 1.41 | 3.87 | 8.85 | 14.08 | 18.11 | 21.33 |
| | [SCW2B], (Fan et al., 2020) | 0.07 | 0.09 | 0.16 | 0.28 | 0.43 | 0.59 | 1.28 | 2.85 |
| L2-UVP, % (potentials) | [SCW2B], (Fan et al., 2020) | 0.08 | 0.10 | 0.17 | 0.29 | 0.47 | 0.63 | 1.14 | 1.50 |
| | [CRWB], (Li et al., 2020) | 0.99 | 2.52 | 8.62 | 22.23 | 67.01 | >100 | - | - |
| | [CW2B], ours | 0.06 | 0.05 | 0.07 | 0.11 | 0.19 | 0.24 | 0.42 | 0.83 |

Table 2: Comparison of UVP for the case $\{\mathbb{P}_n\}\subset\mathcal{F}(\mathbb{P}_0)$, $\mathbb{P}_0=\mathrm{Uniform}\big([-\sqrt{3},\sqrt{3}]^D\big)$, $N=4$.

| Metric | Method | D=2 | 4 | 8 | 16 | 32 | 64 | 128 | 256 |
|---|---|---|---|---|---|---|---|---|---|
| BW2²-UVP, % | [FCWB], Cuturi & Doucet (2014) | 0.64 | 0.77 | 1.22 | 3.75 | 8.92 | 14.3 | 18.46 | 21.64 |
| | [SCW2B], (Fan et al., 2020) | 0.12 | 0.10 | 0.19 | 0.29 | 0.46 | 0.6 | 1.38 | 2.9 |
| L2-UVP, % (potentials) | [SCW2B], (Fan et al., 2020) | 0.17 | 0.12 | 0.2 | 0.31 | 0.47 | 0.62 | 1.21 | 1.52 |
| | [CRWB], (Li et al., 2020) | 0.58 | 1.83 | 8.09 | 21.23 | 55.17 | >100 | - | - |
| | [CW2B], ours | 0.17 | 0.08 | 0.06 | 0.1 | 0.2 | 0.25 | 0.42 | 0.82 |

Table 3: Comparison of UVP for the case $\{\mathbb{P}_n\}\subset\mathcal{F}(\mathbb{P}_0)$, $\mathbb{P}_0=\mathrm{Uniform}\big([-\sqrt{3},\sqrt{3}]^D\big)$, $N=20$.

| Metric | Method | D=32 | 64 | 128 |
|---|---|---|---|---|
| BW2²-UVP, % | [FCWB], Cuturi & Doucet (2014) | 14.09 | 26.21 | 38.43 |
| | [SCW2B], (Fan et al., 2020) | 0.62 | 0.93 | 1.83 |
| L2-UVP, % (potentials) | [SCW2B], (Fan et al., 2020) | 0.60 | 0.86 | 1.52 |
| | [CW2B], ours | 0.31 | 0.58 | 1.45 |

5.2 SUBSET POSTERIOR AGGREGATION

We apply our method to aggregate subset posterior distributions. The barycenter of subset posteriors converges to the true posterior (Srivastava et al., 2018). Thus, computing the barycenter of subset posteriors is an efficient alternative to obtaining the full posterior in the big-data setting (Srivastava et al., 2015; Staib et al., 2017; Li et al., 2020). Analogous to Li et al. (2020), we consider Poisson and negative binomial regressions for predicting the hourly number of bike rentals using features such as the day of the week and weather conditions.² We consider the posterior on the 8-dimensional regression coefficients for both Poisson and negative binomial regressions. We randomly split the data into $N=5$ equally-sized subsets and obtain $10^5$ samples from each subset posterior using the Stan library (Carpenter et al., 2017). This gives discrete uniform distributions $\{\mathbb{P}_n\}$ supported on the samples. As the ground-truth barycenter $\overline{\mathbb{P}}$, we consider the full-dataset posterior, also consisting of $10^5$ points. We use BW2²-UVP$(\widetilde{\mathbb{P}},\overline{\mathbb{P}})$ to compare the estimated barycenter $\widetilde{\mathbb{P}}$ (pushforward measure $\nabla\psi_n\sharp\mathbb{P}_n$ or generated measure $g\sharp\mathbb{S}$) with the true barycenter. The results are in Table 4. All considered methods perform well (UVP < 2%), but our method outperforms the alternatives.

Table 4: Comparison of UVP for recovered barycenters in our subset posterior aggregation task.

| Regression (BW2²-UVP, %) | SCW2B (Fan et al., 2020), $\widetilde{\mathbb{P}}=g\sharp\mathbb{S}$ | SCW2B (Fan et al., 2020), $\widetilde{\mathbb{P}}=\nabla\psi_n\sharp\mathbb{P}_n$ | [CRWB], (Li et al., 2020) | [CW2B], ours |
|---|---|---|---|---|
| Poisson | 0.67 | 0.41 | 1.53 | 0.1 |
| negative binomial | 0.15 | 0.15 | 1.26 | 0.11 |

5.3 COLOR PALETTE AVERAGING

For a qualitative study, we apply our method to averaging color palettes of images. For an RGB image $I$, its color palette is defined as the discrete uniform distribution $\mathbb{P}(I)$ of all its pixels in $[0,1]^3$. For 3 images $\{I_n\}$ we compute the barycenter $\overline{\mathbb{P}}$ of the color palettes $\mathbb{P}_n=\mathbb{P}(I_n)$ w.r.t. uniform weights $\alpha_n=\frac13$. We apply each computed potential $\nabla\psi_n$ pixel-wise to $I_n$ to obtain the pushforward image $\nabla\psi_n\circ I_n$ (a minimal sketch of this recoloring step is given below).
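The recoloring step amounts to pushing every pixel of $I_n$ through the learned map. A minimal sketch, assuming a trained potential network with a `push` (gradient) method as in the ICNN sketch of §4.3:

```python
import numpy as np
import torch

def recolor(image, potential):
    """Push the pixels of an RGB image through the learned map grad psi_n.

    image     : (H, W, 3) float array with values in [0, 1].
    potential : trained potential for psi_n; potential.push(x) returns grad psi_n(x).
    """
    h, w, _ = image.shape
    pixels = torch.tensor(image.reshape(-1, 3), dtype=torch.float32)
    pushed = potential.push(pixels).detach().numpy()
    return np.clip(pushed.reshape(h, w, 3), 0.0, 1.0)
```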
These pushforward images should be close to the barycenter P of {Pn}. (a) Original images {In}. (b) Color palettes {Pn} of original images. (c) Images with averaged color palette { ψ n In}. (d) Barycenter palettes { ψ n Pn}. Figure 2: Results of our method applied to averaging color palettes of images. The results are provided in Figure 2. Note that the image ψ 1 I1 inherits certain attributes of images I2 and I3: the sky becomes bluer and the trees becomes greener. On the other hand, the sunlight in images ψ 2 I2, ψ 3 I3 has acquired an orange tint, thanks to the dominance of orange in I1. ACKNOWLEDGMENTS The Skoltech Advanced Data Analytics in Science and Engineering Group acknowledges the support of Russian Foundation for Basic Research grant 20-01-00203, Skoltech-MIT NGP initiative and thanks the Skoltech CDISE HPC Zhores cluster staff for computing cluster provision. The MIT Geometric Data Processing group acknowledges the generous support of Army Research Office grant W911NF2010168, of Air Force Office of Scientific Research award FA9550-19-1-031, of National Science Foundation grant IIS-1838071, from the CSAIL Systems that Learn program, from the MIT IBM Watson AI Laboratory, from the Toyota CSAIL Joint Research Center, from a gift from Adobe Systems, from an MIT.nano Immersion Lab/NCSOFT Gaming Program seed grant, and from the Skoltech MIT Next Generation Program. 2http://archive.ics.uci.edu/ml/datasets/Bike+Sharing+Dataset Published as a conference paper at ICLR 2021 Martial Agueh and Guillaume Carlier. Barycenters in the wasserstein space. SIAM Journal on Mathematical Analysis, 43(2):904 924, 2011. Pedro C Álvarez-Esteban, E Del Barrio, JA Cuesta-Albertos, and C Matrán. A fixed-point approach to barycenters in wasserstein space. Journal of Mathematical Analysis and Applications, 441(2): 744 762, 2016. Brandon Amos, Lei Xu, and J Zico Kolter. Input convex neural networks. In Proceedings of the 34th International Conference on Machine Learning-Volume 70, pp. 146 155. JMLR. org, 2017. Yann Brenier. Polar factorization and monotone rearrangement of vector-valued functions. Communications on pure and applied mathematics, 44(4):375 417, 1991. Bob Carpenter, Andrew Gelman, Matthew D Hoffman, Daniel Lee, Ben Goodrich, Michael Betancourt, Marcus Brubaker, Jiqiang Guo, Peter Li, and Allen Riddell. Stan: A probabilistic programming language. Journal of statistical software, 76(1), 2017. Sinho Chewi, Tyler Maunu, Philippe Rigollet, and Austin J Stromme. Gradient descent algorithms for bures-wasserstein barycenters. ar Xiv preprint ar Xiv:2001.01700, 2020. Samuel Cohen, Michael Arbel, and Marc Peter Deisenroth. Estimating barycenters of measures in high dimensions. ar Xiv preprint ar Xiv:2007.07105, 2020. Marco Cuturi. Sinkhorn distances: Lightspeed computation of optimal transport. In Advances in neural information processing systems, pp. 2292 2300, 2013. Marco Cuturi and Arnaud Doucet. Fast computation of wasserstein barycenters. 2014. DC Dowson and BV Landau. The fréchet distance between multivariate normal distributions. Journal of multivariate analysis, 12(3):450 455, 1982. Pavel Dvurechenskii, Darina Dvinskikh, Alexander Gasnikov, Cesar Uribe, and Angelia Nedich. Decentralize and randomize: Faster algorithm for wasserstein barycenters. In Advances in Neural Information Processing Systems, pp. 10760 10770, 2018. Jiaojiao Fan, Amirhossein Taghvaei, and Yongxin Chen. Scalable computations of wasserstein barycenter via input convex neural networks. ar Xiv preprint ar Xiv:2007.04462, 2020. 
Werner Fenchel. On conjugate convex functions. Canadian Journal of Mathematics, 1(1):73 77, 1949. Aude Genevay, Marco Cuturi, Gabriel Peyré, and Francis Bach. Stochastic optimization for largescale optimal transport. In Advances in neural information processing systems, pp. 3440 3448, 2016. Sham Kakade, Shai Shalev-Shwartz, and Ambuj Tewari. On the duality of strong convexity and strong smoothness: Learning applications and matrix regularization. Unpublished Manuscript, http://ttic. uchicago. edu/shai/papers/Kakade Shalev Tewari09. pdf, 2(1), 2009. Leonid Kantorovitch. On the translocation of masses. Management Science, 5(1):1 4, 1958. Diederik P Kingma and Jimmy Ba. Adam: A method for stochastic optimization. ar Xiv preprint ar Xiv:1412.6980, 2014. Alexander Korotin, Vage Egiazarian, Arip Asadulaev, Alexander Safin, and Evgeny Burnaev. Wasserstein-2 generative networks. ar Xiv preprint ar Xiv:1909.13082, 2019a. Alexander Korotin, Vladimir V yugin, and Evgeny Burnaev. Integral mixability: a tool for efficient online aggregation of functional and probabilistic forecasts. ar Xiv preprint ar Xiv:1912.07048, 2019b. Lingxiao Li, Aude Genevay, Mikhail Yurochkin, and Justin Solomon. Continuous regularized wasserstein barycenters. ar Xiv preprint ar Xiv:2008.12534, 2020. Published as a conference paper at ICLR 2021 Ashok Vardhan Makkuva, Amirhossein Taghvaei, Sewoong Oh, and Jason D Lee. Optimal transport mapping via input convex neural networks. ar Xiv preprint ar Xiv:1908.10962, 2019. Robert J Mc Cann et al. Existence and uniqueness of monotone measure-preserving maps. Duke Mathematical Journal, 80(2):309 324, 1995. Youssef Mroueh. Wasserstein style transfer. ar Xiv preprint ar Xiv:1905.12828, 2019. Barak A Pearlmutter. Fast exact multiplication by the hessian. Neural computation, 6(1):147 160, 1994. Gabriel Peyré, Marco Cuturi, et al. Computational optimal transport. Foundations and Trends in Machine Learning, 11(5-6):355 607, 2019. Julien Rabin, Gabriel Peyré, Julie Delon, and Marc Bernot. Wasserstein barycenter and its application to texture mixing. In International Conference on Scale Space and Variational Methods in Computer Vision, pp. 435 446. Springer, 2011. Julien Rabin, Sira Ferradans, and Nicolas Papadakis. Adaptive color transfer with relaxed optimal transport. In 2014 IEEE International Conference on Image Processing (ICIP), pp. 4852 4856. IEEE, 2014. Vivien Seguy, Bharath Bhushan Damodaran, Rémi Flamary, Nicolas Courty, Antoine Rolet, and Mathieu Blondel. Large-scale optimal transport and mapping estimation. ar Xiv preprint ar Xiv:1711.02283, 2017. Justin Solomon, Fernando De Goes, Gabriel Peyré, Marco Cuturi, Adrian Butscher, Andy Nguyen, Tao Du, and Leonidas Guibas. Convolutional wasserstein distances: Efficient optimal transportation on geometric domains. ACM Transactions on Graphics (TOG), 34(4):1 11, 2015. Sanvesh Srivastava, Volkan Cevher, Quoc Dinh, and David Dunson. Wasp: Scalable bayes via barycenters of subset posteriors. In Artificial Intelligence and Statistics, pp. 912 920, 2015. Sanvesh Srivastava, Cheng Li, and David B Dunson. Scalable bayes via barycenter in wasserstein space. The Journal of Machine Learning Research, 19(1):312 346, 2018. Matthew Staib, Sebastian Claici, Justin M Solomon, and Stefanie Jegelka. Parallel streaming Wasserstein barycenters. In Advances in Neural Information Processing Systems, pp. 2647 2658, 2017. Cédric Villani. Topics in optimal transportation. Number 58. American Mathematical Soc., 2003. Jonathan Weed, Francis Bach, et al. 
Sharp asymptotic and finite-sample rates of convergence of empirical measures in wasserstein distance. Bernoulli, 25(4A):2620 2648, 2019. Yujia Xie, Minshuo Chen, Haoming Jiang, Tuo Zhao, and Hongyuan Zha. On scalable and efficient computation of large scale optimal transport. volume 97 of Proceedings of Machine Learning Research, pp. 6882 6892, Long Beach, California, USA, 09 15 Jun 2019. PMLR. URL http: //proceedings.mlr.press/v97/xie19a.html. Published as a conference paper at ICLR 2021 A THE ALGORITHM The numerical procedure for solving our final objective (14) is given below. Algorithm 1: Numerical Procedure for Optimizing Multiple Correlations (14) Input :Distributions P1, . . . , PN with sample access; Weights α1, . . . , αN 0 with PN n=1 αn = 1; Regularization distribution b P given by a sampler; Congruence regularizer coefficient τ 1; Balancing coefficient γ [0, 1]; Cycle-consistency regularizer coefficient λ > 0; 2N ICNNs {ψθn, ψωn}; Batch size K > 0; for t = 1, 2, . . . do 1. Sample batches Xn Pn for all n = 1, . . . , N; 2. Compute the pushforwards Yn = ψθn Xn for all n = 1, . . . , N; 3. Sample batch Y0 b P; 4. Compute the Monte-Carlo estimate of the congruence regularizer: LCongruence := 1 n =1 αn ψωn (y) y 2 where γ0 = γ and γn = αn (1 γ) for n = 1, 2, . . . , N; 5. Compute the Monte-Carlo estimate of the cycle-consistency regularizer: LCycle := 1 x Xn ψωn ψθn(x) x 2 2 6. Compute the Monte-Carlo estimate of multiple correlations: LMulti Corr := x, ψθn(x) ψωn( ψθn(x))] ; 7. Compute the total loss: LTotal := LMulti Corr + λ LCycle + τ LCongruence; 8. Perform a gradient step over {θn, ωn} by using LTotal {θn,ωn}; end Parametrization of the potentials. To parametrize potentials {ψθn, ψωn}, we use Dense ICNN (dense input convex neural network) with quadratic skip connections; see (Korotin et al., 2019a, Appendix B.2). As an initialization step, we pre-train the potentials to satisfy 2 and ψωn(y) y 2 Such pre-training provides a good start for the networks: each ψθn is approximately conjugate to the corresponding ψωn. On the other hand, the initial networks {ψθn} are approximate congruent according to (5). Computational Complexity. For a single training iteration, the time complexity of both forward (evaluation) and backward (computing the gradient with respect to the parameters) passes through the objective function (14) is O(NT). Here N is the number of input distributions and T is the time taken by evaluating each individual potential (parameterized as a neural network) on a batch of points sampled from either Pn or b P. This claim follows from the well-known fact that gradient evaluation θhθ(x) of hθ : RD R, when parameterized as a neural network, requires time proportional Published as a conference paper at ICLR 2021 to the size of the computational graph. Hence, gradient computation requires computational time proportional to the time for evaluating the function hθ(x) itself. The same holds when computing the derivative with respect to x. Then, for instance, computing the term ψ n ψ n(x) in (14) takes O(T) time. The gradient of this term with respect to θ also takes O(T) time: Hessian-vector products that appear can be calculated in O(T) time using the famous Hessian trick, see Pearlmutter (1994). In practice, we compute all the gradients using automatic differentiation. We empirically measured that for our Dense ICNN potentials, the computation of their gradient w.r.t. input x, i.e., ψ (x), requires roughly 3-4x more time than the computation of ψ (x). 
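The following is a condensed, illustrative PyTorch sketch of the training loop of Algorithm 1 (objective (14)). It assumes potential networks that expose a forward pass $\psi(x)$ and a gradient map `push(x)` (as in the ICNN sketch of §4.3), plus samplers for the input measures and the proposal $\widehat{\mathbb{P}}$; the pre-training of the potentials described above is omitted, and the dynamic update of $\widehat{\mathbb{P}}$ is folded into the mixture weights $\gamma_0=\gamma$, $\gamma_n=\alpha_n(1-\gamma)$. Function and argument names are our own.

```python
import torch

def train_cw2b(samplers, alphas, p_hat_sampler, psi, psi_bar,
               tau=5.0, lam=10.0, gamma=0.2, batch=1024, iters=50000, lr=1e-3):
    """Condensed sketch of Algorithm 1 for optimizing objective (14).

    samplers      : list of callables, samplers[n](K) -> (K, D) samples from P_n.
    p_hat_sampler : callable, p_hat_sampler(K) -> (K, D) samples from the proposal P_hat.
    psi, psi_bar  : lists of 2N potential networks for psi_n and its conjugate psi_bar_n.
    """
    params = [p for net in psi + psi_bar for p in net.parameters()]
    opt = torch.optim.Adam(params, lr=lr)
    N = len(samplers)
    for _ in range(iters):
        xs = [samplers[n](batch) for n in range(N)]   # X_n ~ P_n
        ys = [psi[n].push(xs[n]) for n in range(N)]   # Y_n = grad psi_n(X_n)
        y0 = p_hat_sampler(batch)                     # Y_0 ~ P_hat

        # Congruence regularizer over the mixture gamma * P_hat + (1 - gamma) * sum_n alpha_n Y_n
        mix_weights = [gamma] + [a * (1.0 - gamma) for a in alphas]
        congruence = sum(
            w * torch.relu(sum(a * pb(y) for a, pb in zip(alphas, psi_bar))
                           - 0.5 * (y ** 2).sum(dim=1, keepdim=True)).mean()
            for w, y in zip(mix_weights, [y0] + ys))

        # Cycle-consistency regularizer and approximate multiple correlation
        cycle, multicorr = 0.0, 0.0
        for n in range(N):
            x, y = xs[n], ys[n]
            cycle = cycle + alphas[n] * ((psi_bar[n].push(y) - x) ** 2).sum(dim=1).mean()
            multicorr = multicorr + alphas[n] * (
                (x * y).sum(dim=1) - psi_bar[n](y).squeeze(-1)).mean()

        loss = multicorr + lam * cycle + tau * congruence
        opt.zero_grad()
        loss.backward()
        opt.step()
```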
In this section, we prove our main Theorems 4.1 and 4.2. We use L2(RD RD, µ) to denote the Hilbert space of functions f : RD RD with integrable square w.r.t. a probability measure µ. The corresponding inner product for f1, f2 L2(RD RD, µ) is denoted by f1, f2 µ def = Z RD f1(x), f2(x) dµ(x), where f1(x), f2(x) is the Euclidean dot product. We use µ = p , µ to denote the norm induced by the inner product in L2(RD RD, µ). We also recall a useful property of lower semi-continuous convex function ψ : RD R: ψ(x) = arg max y RD y, x ψ(y) , (19) which follows from the fact that ˆy = arg max y RD y, x ψ(y) x ψ(ˆy) = 0. We begin with the proof of Theorem 4.1. Proof. We consider the difference between the estimated correlations and true ones: RD ψn(x)d Pn(x) RD ψ n(x)d Pn(x) = ψn(x), x ψn ψn(x)) d Pn(x) ψ n(x), x ψ n ψ n(x)) d Pn(x), (20) where we twice use (19) for f = ψn and f = ψ n. We note that RD ψ n(x), x d Pn(x) = RD y, ψ n(y) d P(y) = n=1 αn ψ n(y) d P(y) = Z RD y, y d P(y) = id RD 2 where we use of change-of-variable formula for ψ n Pn = P and (5). Analogously, RD ψ n ψ n(x))d Pn(x) = RD ψ n y)d P(y) = n=1 αnψ n y)d P(y) = Z 2 d P(y) = 1 Published as a conference paper at ICLR 2021 Since each ψn is B-smooth, we conclude that ψn is 1 B-strongly convex, see (Kakade et al., 2009). Thus, we have ψn ψ n(x))) ψn ψn(x))) + ψ n ψ n(x) | {z } =x , ψ n(x) ψn(x) + 1 2B ψ n(x) ψn(x) 2 = ψn ψn(x))) + x, ψ n(x) ψn(x) + 1 2B ψ n(x) ψn(x) 2, (23) or equivalently ψn ψn(x))) ψn ψ n(x)))+ x, ψ n(x) ψn(x) + 1 2B ψ n(x) ψn(x) 2. (24) We integrate (24) w.r.t. Pn and sum over n = 1, 2, . . . , N with weights αn: RD ψn ψn(x))d Pn(x) RD ψn ψ n(x))d Pn(x) + n=1 αn x, ψ n(x) Pn n=1 αn x, ψn(x) Pn + n=1 αn 1 2B ψ n(x) ψn(x) 2 Pn = n=1 αnψn y)d P(y) + n=1 αn x, ψ n(x) Pn n=1 αn x, ψn(x) Pn + n=1 αn 1 2B ψ n(x) ψn(x) 2 Pn. (25) We note that n=1 αnψn y)d P(y) = Z n=1 αnψn y) d P(y) Z n=1 αnψn y) d P(y) 1 Now we substitute (25), (26), (21) and (22) into (20) to obtain (9). Next, we prove Theorem 4.2. Proof. Since ψ n is β strongly convex, its conjugate ψ n is 1 β -smooth, i.e. has 1 β -Lipschitz gradient ψ n (Kakade et al., 2009). Thus, for all x, x RD: ψ n(x) ψ n(x ) 2 ( 1 β )2 x x 2. We substitute x = ψ n ψ n(y) = ψ n 1 ψ n(y) and obtain: ψ n(x) ψ n(x) 2 ( 1 β )2 x ψ n ψ n(x) 2. (27) Since the function ψ n is B -smooth, we have for all x RD: ψ n( ψ n(x)) ψ n( ψ n(x)) + ψ n ψ n(x) | {z } =x , ψ n(x) ψ n(x) + B 2 ψ n(x) ψ n(x) 2, Published as a conference paper at ICLR 2021 that is equivalent to: x, ψ n(x) ψ n( ψ n(x)) x, ψ n(x) ψ n( ψ n(x)) | {z } 2 ψ n(x) ψ n(x) 2. (28) We combine (28) with (27) to obtain x, ψ n(x) ψ n( ψ n(x)) ψ n(x) B 2(β )2 id RD ψ n ψ n 2. (29) For every n = 1, 2, . . . , N we integrate (29) w.r.t. Pn and sum up the corresponding cycle-consistency regularization term: Z x, ψ n(x) ψ n( ψ n(x))]d Pn(x) + λ ψ n ψ n id RD 2 Pn Z RD ψ (x)d Pn(x) + λ B 2(β )2 ψ n ψ n id RD 2 Pn | {z } RPn 2 (ψ n,ψ n) We sum (30) for n = 1, 2, . . . , N w.r.t. weights αn to obtain: x, ψ n(x) ψ n( ψ n(x))]d Pn(x) + λ n=1 αn RPn 2 (ψ n, ψ n) RD ψ (x)d Pn(x) Multi Corr({αn,Pn}|{ψ n}) 2(β )2 RPn 2 (ψ n, ψ n). We add τ Rb P 1({ψ n}) to both sides of (31) to get Multi Corr {αn, Pn} | {ψ n}, {ψ n}; τ, b P, λ Multi Corr({αn, Pn} | {ψ n}) + τ R b P 1({ψ n}) + 2(β )2 RPn 2 (ψ n, ψ n). 
(31) We substract Multi Corr({αn, Pn} | {ψ n}) from both sides and use Theorem 4.1 to obtain αnψ n(y) y 2 2 d P(y) + β n=1 αn ψ n(x) ψ n(x) 2 Pn + (32) τ R b P 1({ψ n}) + 2(β )2 RPn 2 (ψ n, ψ n) (33) 2(β )2 RPn 2 (ψ n, ψ n) + β n=1 αn ψ n(x) ψ n(x) 2 Pn. (34) In transition from (33) to (34), we explot the fact that the sum of the first term of (32) with the regularizer τ Rb P 1({ψ n}). Since λ > B 2(β )2 , from (34) we immediately conclude 0; i.e., the multiple correlations upper bound (15) holds true. On the other hand, for every n = 1, 2, . . . , N we have ψ n(x) ψ n(x) 2 Pn 2 αnβ and ψ n ψ n id RD 2 Pn 2 αn (λ B 2(β )2 ) . (35) We combine the second part of (35) with (27) integrated w.r.t. Pn: ψ n ψ n 2 Pn 2 αn (λ(β )2 B Published as a conference paper at ICLR 2021 Finally, we use the triangle inequality for Pn and conclude ψ n ψ n Pn ψ n ψ n Pn + ψ n ψ n Pn r W2 2( ψ n Pn, P) ψ n ψ n 2 Pn 2 where the first inequality follows from (Korotin et al., 2019a, Lemma A.2). C EXPERIMENTAL DETAILS AND EXTRA RESULTS In this section, we provide experimental details and additional results. In Subsection C.1, we demonstrate qualitative results of computed barycenters in the 2-dimensional space. In Subsection C.2, we discuss used metrics in more detail. In Subsection C.4, we list the used hyperparameters of our method (CW2B) and methods [SCW2B], [CRWB]. C.1 ADDITIONAL TOY EXPERIMENTS IN 2D We provide additional qualitative examples of computed barycenters of probability measures on R2. In Figure 3, we consider the location-scatter family F(P0) with P0 = Uniform[ 3]D. In principle, all the methods capture the true barycenter. However, the generated distribution g S of [SCW2B] (Figure 3c) provides samples that lies outside of the actual barycenter s support (Figure 3b). Also, in [CRWB] method, one of the potentials pushforward measure (top-right in Figure 3e) has visual artifacts. (a) Input distributions {Pn} (b) True barycenter P (c) SCW2B, generated distribution g S (d) SCW2B, distributions ψ n Pn (e) CRWB, distributions ψ n Pn (f) CW2B, distributions ψ n Pn Figure 3: Barycenter of a random location-scatter population computed by different methods. In Figure 4, we consider the Gaussian Mixture example by (Fan et al., 2020). The barycenter computed by [SCW2B] method (Figure 4b) suffers from the behavior similar to mode collapse. Published as a conference paper at ICLR 2021 (a) Inputs {Pn} (b) SCW2B g S (c) SCW2B ψ n Pn (d) CRWB ψ n Pn (e) CW2B ψ n Pn Figure 4: Barycenter of a two 2D Gaussian mixtures. C.2 METRICS The unexplained variance percentage (UVP) (introduced in Section 5) is a natural and straightforward metric to assess the quality of the computed barycenter. However, it is difficult to compute in high dimensions: it requires computation of the Wasserstein-2 distance. Thus, we use different but highly related metrics L2-UVP and BW2 2-UVP. To access the quality of the recovered potentials {ψ n} we use L2-UVP defined in (17). L2-UVP compares not just pushforward distribution ψ n Pn with the barycenter P, but also the resulting transport map with the optimal transport map ψ n. It bounds UVP( ψ n Pn) from above, thanks to (Korotin et al., 2019a, Lemma A.2). Besides, L2-UVP naturally admits unbiased Monte Carlo estimates using random samples from Pn. For measure-based optimization method, we also evaluate the quality of the generated measure g S using Bures-Wasserstein UVP defined in (18). 
For measures P, Q whose covariance matrices are not degenerate, BW2 2 is given by BW2 2(P, Q) = 1 2 µP µQ 2 + 1 2 Tr ΣP + 1 2 Tr ΣQ Tr(Σ 1 2 P ) 1 2 . Bures-Wasserstein metric compares P, Q by considering only their first and second moments. It is known that BW2 2(P, Q) is a lower bound for W2 2(P, Q), see (Dowson & Landau, 1982). Thus, we have BW2 2-UVP(g S) UVP(g S). In practice, to compute BW2 2-UVP(g S), we estimate means and covariance matrices of distributions by using 105 random samples. C.3 CYCLE CONSISTENCY AND CONGRUENCE IN PRACTICE To assess the effect of the regularization of cycle consistency and the congruence condition in practice, we run the following sanity checks. For cycle consistency, for each input distribution Pn we estimate (by drawing samples from Pn) the value ψ n ψ n(x) x 2 Pn/Var(Pn). This metric can be viewed as an analog of the L2-UVP that we used for assessing the resulting transport maps. In all the experiments, this value does not exceed 2%, which means that cycle consistency and hence conjugacy are satisfied well. For the congruence condition, we need to check that PN n=1 αnψ n(x) = x 2/2. However, we do not know any straightforward metric to check this exact condition that is scaled properly by the variance of the distributions. Thus, we propose to use an alternative metric to check a slightly weaker condition on gradients, e.g., that PN n=1 αn ψ n(x) = x. This is weaker due to the ambiguity of the additive constants. For this we can compute PN n=1 αn ψ n(x) x 2 P/Var(P), where the denominator is Published as a conference paper at ICLR 2021 the variance of the true barycenter. We computed this metric and found that it is also less than 2% in all the cases, which means that congruence condition is mostly satisfied. C.4 TRAINING HYPERPARAMETERS The code is written using the Py Torch framework. The networks are trained on a single GTX 1080Ti. C.4.1 WASSERSTEIN-2 CONTINUOUS BARYCENTERS (CW2B, OUR METHOD) Regularization. We use τ = 5 and ˆP = N(0, ID) in our congruence regularizer τ RˆP 1. We use λ = 10 for the cycle regularization λ RPn 2 for all n = 1, 2, . . . , N. Neural Networks (Potentials). To approximate potentials {ψ n, ψ n} in dimension D, we use Dense ICNN[2; max(64, 2D), max(64, 2D), max(32, D)] with CELU activation function. Dense ICNN is an input-convex dense architecture with additional convex quadratic skip connections. Here 2 is the rank of each input-quadratic skip-connection s Hessian matrix. Each following number max( , ) represents the size of a hidden dense layer in the sequantial part of the network. For detailed discussion of the architecture see (Korotin et al., 2019a, Section B.2). Training process. We perform training according to Algorithm 1 of Appendix A. We set batch size K = 1024 and balancing coefficient γ = 0.2. We use Adam optimizer by (Kingma & Ba, 2014) with a fixed learning rate 10 3. The total number of iterations is set to 50000. C.4.2 SCALABLE COMPUTATION OF WASSERSTEIN BARYCENTERS (SCW2B) Generator Neural Network. For the input noise distribution of the generative model we use S = N(0, ID). For the generative network g : RD RD we use a fully-connected sequential Re LU network with hidden layer sizes [max(100, 2D), max(100, 2D), max(100, 2D)]. Before the main optimization, we pre-train the network to satisfy g(z) z for all z RD. This has been empirically verified as a better option than random initialization of network s weights. Neural Networks (Potentials). We used exactly the same networks as in Subsection C.4.1. 
Training process. We perform training according to the min-max-min procedure described by (Fan et al., 2020, Algorithm 1). The batch size is set to 1024. We use Adam optimizer by (Kingma & Ba, 2014) with fixed learning rate 10 3 for potentials and 10 4 for generative network g. The number of iterations of the outer cycle (min-max-min) number of iterations is set to 15000. Following (Fan et al., 2020), we use 10 iterations per the middle cycle (min-max-min) and 6 iterations per the inner cycle (min-max-min). C.4.3 CONTINUOUS REGULARIZED WASSERSTEIN BARYCENTERS (CRWB) Regularization. [CRWB] method uses regularization to keep the potentials conjugate. The authors impose entropy or L2 regularization w.r.t. some proposal measure ˆP; see (Li et al., 2020, Section 3) for more details. Following the source code provided by the authors, we use L2 regularization (empirically shown as a more stable option than entropic regularization). The regularization measure ˆP is set to be the uniform measure on a box containing the support of all the source distributions, estimated by sampling. The regularization parameter ϵ is set to 10 4. Neural Networks (Potentials). To approximate potentials {ψ n, ψ n} in dimension D, we use fullyconnected sequential Re LU neural networks with layer sizes given by [max(128, 4D), max(128, 4D), max(128, 4D)]. We have also tried using Dense ICNN architecture, but did not experience any performance gain. Training process. We perform training according to (Li et al., 2020, Algorithm 1). We set batch size to 1024. We use Adam optimizer by (Kingma & Ba, 2014) with fixed learning rate 10 3. The total number of iterations is set to 50000.
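As noted in C.2, BW2²-UVP is computed in practice by estimating means and covariance matrices from $10^5$ samples. The NumPy sketch below shows one way to do this, using the closed-form Bures-Wasserstein expression of C.2 (stated for the $\tfrac12$-scaled $\mathcal{W}_2$ convention of this paper); function names are illustrative, not part of the released code.

```python
import numpy as np
from scipy.linalg import sqrtm

def bw2_squared(mean_p, cov_p, mean_q, cov_q):
    """Closed-form BW_2^2 between N(mean_p, cov_p) and N(mean_q, cov_q)."""
    cross = np.real(sqrtm(sqrtm(cov_p) @ cov_q @ sqrtm(cov_p)))
    return 0.5 * np.sum((mean_p - mean_q) ** 2) + 0.5 * (
        np.trace(cov_p) + np.trace(cov_q) - 2.0 * np.trace(cross))

def bw2_uvp(samples_est, samples_bar):
    """BW_2^2-UVP (18), in percent, from samples of the estimated and true barycenters."""
    mu_e, cov_e = samples_est.mean(0), np.cov(samples_est, rowvar=False)
    mu_b, cov_b = samples_bar.mean(0), np.cov(samples_bar, rowvar=False)
    var_bar = np.trace(cov_b)  # total variance of the true barycenter
    return 100.0 * bw2_squared(mu_e, cov_e, mu_b, cov_b) / (0.5 * var_bar)
```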