# MALTS: Matching After Learning to Stretch

Journal of Machine Learning Research 23 (2022) 1-42. Submitted 1/21; Revised 7/22; Published 8/22.

Harsh Parikh (harsh.parikh@duke.edu), Department of Computer Science, Duke University, Durham, NC 27708-0129, USA.

Cynthia Rudin (cynthia@cs.duke.edu), Department of Computer Science, Duke University, Durham, NC 27708-0129, USA.

Alexander Volfovsky (alexander.volfovsky@duke.edu), Department of Statistical Science, Duke University, Durham, NC 27710, USA.

Editor: Russ Greiner

## Abstract

We introduce a flexible framework that produces high-quality almost-exact matches for causal inference. Most prior work in matching uses ad-hoc distance metrics, often leading to poor quality matches, particularly when there are irrelevant covariates. In this work, we learn an interpretable distance metric for matching, which leads to substantially higher quality matches. The learned distance metric stretches the covariate space according to each covariate's contribution to outcome prediction: this stretching means that mismatches on important covariates carry a larger penalty than mismatches on irrelevant covariates. Our ability to learn flexible distance metrics leads to matches that are interpretable and useful for the estimation of conditional average treatment effects.

Keywords: causal inference, matching, nearest neighbors, distance metric learning

## 1. Introduction

Matching methods are used throughout the social and health sciences to make causal conclusions where access to randomized trials is scarce but observational data are widely available. Matching methods construct sets of similar individuals, some of whom select into treatment and some of whom select into control, allowing for direct comparison of outcomes between the samples from these populations. These methods are particularly interpretable since they allow fine-grained troubleshooting of the data.
For instance, examining a matched group of patients through chart review of their medical data and doctors' notes may allow an analyst to determine whether the matched groups are indeed trustworthy, and if not, determine what other factors should be included in the analysis. Having high-quality matches also allows the user to estimate nonlinear treatment effects with lower bias than parametric approaches.

As a concrete example of the importance of match group quality, Table 1 presents a series of matched groups from the Lalonde dataset (LaLonde 1986, Dehejia and Wahba 1999). A simple visual inspection of the matched groups produced by standard-bearer methods like propensity score matching and prognostic score matching reveals that the units being considered similar by these methods are not similar on underlying covariates. On the other hand, the matches generated by our proposed method are qualitatively (and quantitatively) better. The quality of the matches is our main consideration in this work.

© 2022 Harsh Parikh, Cynthia Rudin and Alexander Volfovsky. License: CC-BY 4.0, see https://creativecommons.org/licenses/by/4.0/. Attribution requirements are provided at http://jmlr.org/papers/v23/21-0053.html.

Table 1: Example control units in a matched group for a treated unit using (a) our approach (MALTS), (b) prognostic score (Hansen 2008), and (c) propensity score matching (Rosenbaum and Rubin 1983) for a query unit in the Lalonde dataset (top row). Our method matched closely on the covariates age, education, whether the person had an academic degree, and income in 1975. In contrast, prognostic and propensity scores did not match closely on these factors.
In the table, Treated is the treatment indicator, Age through Income-1975 are covariates, and Income-1978 is the outcome.

| Group | Unit ID | Treated | Age | Education | Black | Hispanic | Married | No-Degree | Income-1975 | Income-1978 |
|---|---|---|---|---|---|---|---|---|---|---|
| Query | 1 | Yes | 22 | 9 | No | Yes | No | Yes | \$0 | \$3596 |
| (a) Our Approach (MALTS) | 330 | No | 22 | 8 | No | Yes | No | Yes | \$0 | \$9921 |
| (a) Our Approach (MALTS) | 299 | No | 22 | 9 | Yes | No | No | Yes | \$0 | \$0 |
| (a) Our Approach (MALTS) | 416 | No | 22 | 9 | Yes | No | No | Yes | \$0 | \$12898 |
| (b) Prognostic Scores | 338 | No | 44 | 9 | Yes | No | No | Yes | \$0 | \$9722 |
| (b) Prognostic Scores | 340 | No | 22 | 12 | Yes | No | No | No | \$532 | \$1333 |
| (b) Prognostic Scores | 355 | No | 18 | 10 | No | Yes | No | Yes | \$0 | \$1859 |
| (c) Propensity Scores | 451 | No | 22 | 8 | Yes | No | No | Yes | \$0 | \$1391 |
| (c) Propensity Scores | 330 | No | 22 | 8 | No | Yes | No | Yes | \$0 | \$9921 |
| (c) Propensity Scores | 407 | No | 20 | 12 | Yes | No | No | No | \$1371 | \$20893 |

Typically, matching methods place units that are close together into the same matched group, where closeness is measured in terms of a pre-defined distance (e.g., exact, coarsened exact, Euclidean, etc.), while maintaining balance constraints between treatment and control units. Despite its merits, this classical paradigm has flaws, namely that it relies heavily on a prespecified distance metric. The distance metric cannot be determined without an understanding of the importance of the variables; for instance, the quality of matches for any prespecified distance that weighs all covariates equally will degrade as the number of irrelevant covariates increases. This is true irrespective of the matching methodology employed. This issue has previously been referred to as the toenail problem (Wang et al. 2021, Dieng et al. 2019), where the inclusion of irrelevant covariates (like "toenail length") with nonzero weights can worsen the metric for matching. A related concern is that the covariates may be scaled differently, where a given distance along one covariate has a different impact than the same distance along a different covariate; in this case, if the scaling or weights on the covariates are chosen poorly, the total distance metric can inadvertently be determined by less relevant covariates, again leading to lower quality matches.
Ideally, the distance metric would focus on important covariates that significantly contribute to the outcome, so that after matching, treatment effect estimates computed using the matched groups would be accurate. If the researcher knew how to choose the distance metric so that it yields accurate treatment effect estimates, the problem would be solved. However, there is no reason to believe that this is achievable in complex high-dimensional data settings. Producing high-dimensional functions to characterize data is a task at which humans are not naturally adept.

In this work, we propose a framework for matching where an interpretable distance measure between matched units is learned from a training set. As long as the distance metric generalizes from the training set to the full sample, we are able to compute high-quality matches and accurate estimates of conditional average treatment effects (CATEs) within the matched groups. One can use any form of distance metric to train, and in this work, we focus on exact matching for discrete variables and generalized Mahalanobis distances for continuous variables. By definition, the generalized Mahalanobis distance is determined by a matrix. If the matrix is diagonal, the distance calculation represents a stretch for each covariate. Irrelevant covariates will be compressed so that their values are always effectively zero. Highly relevant covariates will be stretched so that for two units to be considered a match, they must have very similar values for those covariates. In this way, diagonal matrices lead to very interpretable distance metrics. If the Mahalanobis distance matrix is not constrained to be diagonal, then it induces a stretch and rotation, leading to more flexible but less interpretable notions of distance. The new framework is called Learning-to-Match, and the algorithm introduced in this work is called Matching After Learning to Stretch (MALTS).
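To make the effect of a diagonal stretch concrete, here is a minimal sketch in Python. The covariates (age and toenail length), the units, and the stretch values are hypothetical and chosen only for illustration; they are not taken from the paper, and the stretch vector is set by hand rather than learned.

```python
import numpy as np

def stretched_distance(a, b, stretch):
    """Euclidean distance after scaling each covariate by its stretch factor,
    i.e. || diag(stretch) (a - b) ||_2."""
    a, b, stretch = (np.asarray(v, dtype=float) for v in (a, b, stretch))
    return float(np.linalg.norm(stretch * (a - b)))

# Hypothetical covariates: (age, toenail_length_mm).
query = [22.0, 3.0]
similar_age = [23.0, 30.0]       # close on the relevant covariate only
similar_toenail = [40.0, 3.0]    # close on the irrelevant covariate only

equal_weights = [1.0, 1.0]       # prespecified metric, all covariates equal
hand_set_stretch = [1.0, 0.0]    # illustrative: irrelevant covariate compressed to zero

d_equal_age = stretched_distance(query, similar_age, equal_weights)
d_equal_toenail = stretched_distance(query, similar_toenail, equal_weights)
d_stretch_age = stretched_distance(query, similar_age, hand_set_stretch)
d_stretch_toenail = stretched_distance(query, similar_toenail, hand_set_stretch)

# Equal weights prefer the unit that matches on toenail length;
# the stretched metric prefers the unit that matches on age.
print(d_equal_age > d_equal_toenail, d_stretch_age < d_stretch_toenail)
```

With equal weights the irrelevant covariate dominates the distance, whereas compressing it to zero makes the nearest neighbor the unit with a similar age.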
Figure 1 shows the main steps of MALTS, which are: divide the data into training and estimation sets, learn the distance metric on the training set, use the learned distance metric to perform nearest neighbor matching on the estimation set, and use those matched groups to estimate conditional average treatment effects (CATEs).

We tested MALTS against several other matching methods in simulation studies (Section 6), where ground truth CATEs are known. In these experiments, MALTS consistently achieves substantially better results than other matching methods, including Genmatch, propensity score matching, and prognostic score matching, for estimating CATEs. Even though our method is heavily constrained to produce interpretable matches, it performs at the same level as non-matching methods that are designed to fit extremely flexible but uninterpretable models directly to the response surface.

In Section 3, we introduce the learning-to-match framework and show that under a choice of smooth distance metric (Definition 1) we can estimate conditional average treatment effects accurately with high probability. Section 4 discusses the MALTS optimization setup and the training procedure that learns a smooth distance metric. In Section 5, we prove that the distance metric learned by MALTS is multi-robust (Definition 3) and generalizable (Definition 5). Thus, the distance metric estimated by the MALTS framework facilitates correct estimates of CATEs under SUTVA and positivity assumptions.

## 2. Related work

Since the 1970s, the causal inference literature on matching methods has been concentrated on dimension reduction techniques (e.g., Rubin 1973a,b, 1976, Cochran and Rubin 1973). In this literature, the leading approach for dimension reduction uses the propensity score, which is the conditional probability of treatment given covariate information.
Propensity score methods are designed for calculating average treatment effects (as opposed to conditional average treatment effects) and do not produce exact or almost-exact matches. When treatment is binary, they project data onto one dimension, and closeness of units in propensity score does not imply their closeness in covariate space. As a result, the matches cannot directly be used for estimating heterogeneous treatment effects.

Figure 1: Schematic drawing of the MALTS algorithm. The algorithm splits the data into random subsets and uses one of the subsets (training set) to learn a distance metric. It performs matching on the rest of the units (estimation set) using the learned distance metric to produce tightly matched groups and estimate conditional average treatment effects.

Other causal inference methods have been studied in the literature (Gu and Rosenbaum 1993, Imbens 2004), but almost all of them suffer from at least one of four possible problems: using a black box model that is uninterpretable (i.e., almost all machine learning methods), having a distance metric that is predefined (rather than learned), computational inefficiency, or not being applicable to CATE estimation (as we discussed with propensity scores). These issues cause the vast majority of matching methods to be ineffective in producing high quality interpretable CATE estimates.

Regression methods can be used for CATE estimation, but only when the regression model is correctly specified or, in the case of doubly robust estimation (e.g., Farrell 2015), when either the propensity model or the outcome model is correctly specified. Machine learning approaches generalize regression approaches and can create models that are extremely flexible and predict outcomes accurately for both treatment and control groups (Hill 2011, Chernozhukov et al. 2018, Hahn et al. 2020).
However, complicated regression methods lose the interpretability inherent to almost-exact matches and are difficult to troubleshoot and trust. In practice, MALTS performs similarly to (or better than) several machine learning methods in our experiments, despite being restricted to interpretable almost-exact matches with an interpretable distance metric.

A flexible setup for producing high-quality matches is provided by the optimal matching literature (Rosenbaum 2017). These methods are built on network flow algorithms and integer programming to produce matches that are constrained in user-defined ways (Zubizarreta 2012, Zubizarreta et al. 2014, Keele and Zubizarreta 2017, Resa and Zubizarreta 2016, Kallus 2017, Morucci et al. 2022). In all of these approaches, the user defines the distance metric (rather than learning it from data), potentially leading to poor quality matched groups.

An alternative to optimal matching is coarsened exact matching (CEM, Iacus et al. 2012), an approach that requires users to specify explicit bins for all covariates on which to construct matches. This requires users to know in advance that the outcomes are insensitive to movements within many high-dimensional bins, which is essentially equivalent to the user knowing the answer to the problem we investigate in this work. Large amounts of user choice to define these bins can also lead to unintentional user bias. By learning the stretching rather than asking the user to define it as in CEM, this bias is potentially reduced.

Zhao (2004) and Imbens (2004) discuss the choice of distance metric for matching. The approach by Zhao (2004) depends on the correlations between treatment choice, outcome and covariates. However, this approach assumes a model for the relationship between the outcome and covariates, or the treatment choice and covariates. Hence, under model misspecification, the estimator may not be consistent. MALTS learns a distance metric without any model assumptions.
The present work builds on the work of Wang et al. (2021) and Dieng et al. (2019), where a discrete distance metric is learned by considering the prediction quality of the covariate sets. That work does not pertain to continuous covariates, whereas ours does. There is substantial work on learning distance metrics (though not for causal inference, e.g., Goldberger et al. 2005, Weinberger et al. 2006, Weinberger and Saul 2009), where the goal is to learn a distance metric in latent space to separate different classes of data in supervised learning, often with a margin. This is different from our goal of matching for causal inference, but some of our proofs were inspired by this work in supervised learning. A sister work, developed in parallel, is that of Morucci et al. (2020), which learns adaptively-sized hyperboxes as matched groups.

MALTS was previously used on the ACIC 2018 Causal Inference Challenge Data (see Parikh et al. 2019). An extension of MALTS for multi-level treatments has been used to study the effect of seizures on the discharge status of critically ill patients (see Parikh et al. 2022).

## 3. Learning-to-Match Framework

Within this framework, we perform treatment effect estimation using the following three stages: 1) learning a distance metric, 2) matching samples, and 3) estimating CATEs. We denote the $p$-dimensional covariate vector space as $\mathcal{X} \subset \mathbb{R}^p$ and the unidimensional outcome space by $\mathcal{Y} \subset \mathbb{R}$. Let $\mathcal{T}$ be a finite label set of treatment indicators (in this paper we consider only the binary case). Let $\mathcal{Z} = \mathcal{X} \times \mathcal{Y} \times \mathcal{T}$ such that $z = (x, y, t) \in \mathcal{Z}$ means that $x \in \mathcal{X}$, $y \in \mathcal{Y}$ and $t \in \mathcal{T}$. Let $\mu$ be an unknown probability distribution over $\mathcal{Z}$ such that $\forall z \in \mathcal{Z}$, $\mu(z) > 0$. We assume that $\mathcal{X}$ is a compact convex space with respect to $\|\cdot\|_2$, thus there exists a constant $C_x$ such that $\|x\|_2 \le C_x$. Also, $|y| \le C_y$. A distance metric is a symmetric, positive definite function with two arguments from $\mathcal{X}$ such that $d : \mathcal{X} \times \mathcal{X} \to \mathbb{R}^+$. A distance metric must obey the triangle inequality.
Let $S_n$ denote a set of $n$ observed units $\{s_1, \dots, s_n\}$ drawn i.i.d. from $\mu$ such that $\forall i$, $s_i \in \mathcal{Z}$. We parameterize $d$ with parameter $\mathcal{M}(\cdot)$, explicitly calling it $d_{\mathcal{M}}$, and let $\mathcal{M}(S_n)$ denote the parameter learned using the MALTS methodology, which is described in Section 4. For ease of notation, we will denote the observed sample of treated units as $S_n^{(T)} := \{s_i^{(T)} = (x_i, y_i, t_i) \mid s_i^{(T)} \in S_n \text{ and } t_i = T\}$ and the observed sample of control units as $S_n^{(C)} := \{s_i^{(C)} = (x_i, y_i, t_i) \mid s_i^{(C)} \in S_n \text{ and } t_i = C\}$.

We assume no unobserved confounders and standard ignorability assumptions, i.e., $\forall i$, $(Y_i^{(T)}, Y_i^{(C)}) \perp T_i \mid (X_i = x_i)$ (Rubin 2005), where $Y_i^{(T)}$ and $Y_i^{(C)}$ are potential outcomes for unit $i$ under treatments $(T)$ and $(C)$ respectively, $T_i$ is unit $i$'s treatment choice and $X_i$ corresponds to the vector of covariates for unit $i$. For each individual unit $s_i = (x_i, y_i, t_i) \in \mathcal{Z}$ we define its conditional average treatment effect (or individualized treatment effect) as the difference of potential outcomes of unit $i$ under the treatment and control,

$$\tau(x_i) = E\left[Y_i^{(T)} - Y_i^{(C)} \mid X_i = x_i\right] = E\left[Y_i^{(T)} \mid X_i = x_i\right] - E\left[Y_i^{(C)} \mid X_i = x_i\right].$$

We use $\widehat{Y}^{(t)}_{x_i}$ to refer to the estimated conditional average potential outcome, $E\left[Y_i^{(t)} \mid X_i = x_i\right]$, for treatment $t \in \mathcal{T}$ and covariate level $x_i \in \mathcal{X}$. $\widehat{\tau}(x_i)$ refers to the estimated conditional average treatment effect for covariate value $x_i$. Our goal is to minimize the expected loss between estimated treatment effects $\widehat{\tau}(x)$ and true treatment effects $\tau(x)$ across the target population $\mu(z)$ (this can either be a finite or super-population). Let the population expected loss be:

$$E\left[\ell(\widehat{\tau}(x), \tau(x))\right] = \int \ell(\widehat{\tau}(x), \tau(x))\, d\mu = \int \ell\left(\widehat{Y}^{(T)}_{x} - \widehat{Y}^{(C)}_{x},\; E\left[Y^{(T)} - Y^{(C)} \mid X = x\right]\right) d\mu.$$

For a finite random i.i.d. sample $\{s_i = (x_i, y_i, t_i)\}_{i=1}^{n}$ from the distribution $\mu$, the finite sample version of the average loss can be written as

$$\frac{1}{n}\sum_{i=1}^{n} \ell\left(\widehat{Y}^{(T)}_{x_i} - \widehat{Y}^{(C)}_{x_i},\; E\left[Y_i^{(T)} \mid X_i = x_i\right] - E\left[Y_i^{(C)} \mid X_i = x_i\right]\right).$$
However, we do not observe the true values of $E\left[Y_i^{(T)} \mid X_i = x_i\right]$ and $E\left[Y_i^{(C)} \mid X_i = x_i\right]$. Instead, we could estimate the upper bound of the sample average loss as

$$\frac{1}{n}\sum_{i=1}^{n} t_i\, \ell\left(\widehat{Y}^{(T)}_{x_i}, y_i\right) + (1 - t_i)\, \ell\left(\widehat{Y}^{(C)}_{x_i}, y_i\right).$$

Here, we can use $y_i$ for $t_i = 1$ as the unbiased estimate of $E\left[Y_i^{(T)} \mid X_i = x_i\right]$, and similarly for $t_i = 0$.

For a unit $s_i$, we estimate the conditional average potential outcomes, $\widehat{Y}^{(T)}_{x_i}$ and $\widehat{Y}^{(C)}_{x_i}$, using the outcomes of the treated and control units in the matched group constructed from the observed data. The matched group MG of unit $s_i$ for treatment $t'$ under the distance metric $d_{\mathcal{M}}$ on covariate space is defined as a set of $K$ nearest neighbors of $s_i$ from the set $S_n^{(t')} = \{s_k \mid t_k = t', s_k \in S_n\}$:

$$\mathrm{MG}(s_i, d_{\mathcal{M}}, S_n^{(t')}, K) = \mathrm{KNN}^{S_n}_{\mathcal{M}}(x_i, t') := \left\{ s_k : \left[ \sum_{s_l \in S_n^{(t')}} \mathbb{1}\left( d_{\mathcal{M}}(x_l, x_i) < d_{\mathcal{M}}(x_k, x_i) \right) \right] < K \right\}. \tag{1}$$

We allow reuse of units in multiple matched groups. Thus for a chosen estimator $\phi$,

$$\widehat{Y}^{(t')}_{x_i} = \phi\left(\mathrm{MG}(s_i, d_{\mathcal{M}}, S_n^{(t')}, K)\right) \tag{2}$$

where $K$ is the size of the matched group $\mathrm{MG}(s_i, d_{\mathcal{M}}, S_n^{(t')}, K)$. A simple example of $\phi$ is the mean estimator, i.e. $\phi\left(\mathrm{MG}(s_i, d_{\mathcal{M}}, S_n^{(t')}, K)\right) = \frac{1}{K} \sum_{k \in \mathrm{MG}(s_i, d_{\mathcal{M}}, S_n^{(t')}, K)} y_k$. However, one can choose the estimator to be a weighted mean, linear regression or a non-parametric model like random forest, within the matched group.

Our framework performs honest causal inference by learning a distance metric from a separate training set of data (not the estimation data considered in the averages above), and we denote this training set by $S_{tr}$. To learn $d_{\mathcal{M}}$, we minimize the following:

$$\mathcal{M}(S_{tr}) \in \arg\min_{\mathcal{M}} \left[ \sum_{s_i \in S_{tr}^{(T)}} \left| y_i - \widehat{Y}^{(T)}_{x_i} \right| + \sum_{s_i \in S_{tr}^{(C)}} \left| y_i - \widehat{Y}^{(C)}_{x_i} \right| \right],$$

where $\widehat{Y}^{(C)}_{x_i}$ and $\widehat{Y}^{(T)}_{x_i}$ are defined by Equations (1) and (2), including their dependence on the distance $d_{\mathcal{M}}$, which is parameterized by $\mathcal{M}$, using the training data to create matched groups. Once $\mathcal{M}(S_{tr})$ is learned from the training set, it is used for matching (and estimation) on the estimation data.
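The matched-group construction of Equation (1) and the mean estimator of Equation (2) can be sketched as follows. This is a minimal illustration, assuming a distance of the form $\|\mathcal{M}a - \mathcal{M}b\|_2$ with a matrix `M` supplied by hand (not learned), binary treatment coded as 1 and 0, and a small made-up dataset; the function names `matched_group` and `cate_estimate` are hypothetical.

```python
import numpy as np

def matched_group(x_query, X, t, arm, M, K):
    """Indices of the K units in treatment arm `arm` that are nearest to
    x_query under d_M(a, b) = ||M a - M b||_2."""
    candidates = np.flatnonzero(t == arm)
    dists = np.linalg.norm((X[candidates] - x_query) @ M.T, axis=1)
    nearest = np.argsort(dists, kind="stable")[:K]
    return candidates[nearest]

def cate_estimate(x_query, X, y, t, M, K):
    """Mean outcome of the treated matched group minus that of the control one."""
    treated = matched_group(x_query, X, t, 1, M, K)
    control = matched_group(x_query, X, t, 0, M, K)
    return float(y[treated].mean() - y[control].mean())

# Made-up data: the first covariate matters, the second is irrelevant.
X = np.array([[-0.05, 1.0], [0.1, 5.0], [2.0, 0.0],
              [0.2, -4.0], [3.0, 1.0], [0.05, 9.0]])
t = np.array([1, 1, 1, 0, 0, 0])
y = np.array([5.0, 6.0, 9.0, 2.0, 7.0, 3.0])
M = np.diag([1.0, 0.0])          # hand-set stretch, for illustration only
x_query = np.array([0.0, 0.0])

print(matched_group(x_query, X, t, 1, M, K=2))   # treated neighbors
print(matched_group(x_query, X, t, 0, M, K=2))   # control neighbors
print(cate_estimate(x_query, X, y, t, M, K=2))
```

Units may appear in several matched groups, which mirrors the reuse of units allowed in the text. Swapping the mean for a weighted mean or a regression inside `cate_estimate` corresponds to a different choice of $\phi$.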
### 3.1 Smooth Distance Metric and Treatment Effect Estimation

In this subsection, we show that if a distance metric is smooth, then we can estimate the individualized treatment effect using a finite sample with high probability. First, let us define a smooth distance metric.

Definition 1 (Smooth Distance Metric) $d_{\mathcal{M}} : \mathcal{X} \times \mathcal{X} \to \mathbb{R}^+$ is a smooth distance metric if there exists a monotonically increasing bounded function $\delta_{d_{\mathcal{M}}}(\cdot)$ with zero intercept, such that $\forall z_i, z_j \in \mathcal{Z}$, if $t_i = t_j$ and $d_{\mathcal{M}}(x_i, x_j) \le a$ then

$$\left| E[Y_i \mid X_i = x_i, T_i = t_i] - E[Y_j \mid X_j = x_j, T_j = t_j] \right| \le \delta_{d_{\mathcal{M}}}(a).$$

The concept of the smooth distance metric is analogous to the commonly assumed Lipschitz continuity in the matching literature (Abadie and Imbens 2006). Note that because the range of $Y$ is bounded, there always exists a choice of the function $\delta_{d_{\mathcal{M}}}(\cdot)$ such that a distance metric $d_{\mathcal{M}}$ is smooth. This choice of $\delta_{d_{\mathcal{M}}}(\cdot)$ controls the quality of inference from the matching, as we see in Theorem 1 below.

Theorem 1 (Basic CATE Bound for Smooth Distance Metrics) Let $\{S_n\}_{n=1}^{\infty}$ be a sequence of nested datasets, each of which includes $n$ i.i.d. samples from $\mu(\mathcal{Z})$, $n = 1, \dots, \infty$. Given a smooth distance metric $d_{\mathcal{M}}$, covariate vector $x$, and $\alpha > 0$, if there exists a small enough value of $a$ and a large enough value of $N$ such that $K_n^{(t')}(x) = \{z_k : d_{\mathcal{M}}(X_k, x) < a,\ T_k = t',\ z_k \in S_n\}$ is non-empty and $\alpha > 2\delta_{d_{\mathcal{M}}}(a)$ for all $n \ge N$ and $t' \in \mathcal{T}$, then

$$P_{\{Y_i\}_{i=1}^{n} \sim \mu(\mathcal{Y}^n)}\left( |\widehat{\tau}(x) - \tau(x)| \ge \alpha \right) \le 4 \exp\left( -K_n(x) \left( \frac{\alpha}{2} - \delta_{d_{\mathcal{M}}}(a) \right)^2 \right)$$

where $\widehat{\tau}(x)$ is the estimated conditional average treatment effect using the matched sets $K_n^{(1)}(x)$ and $K_n^{(0)}(x)$, $\tau(x)$ is the true conditional average treatment effect, $K_n(x) = \min_{t'} |K_n^{(t')}(x)|$, and $\delta_{d_{\mathcal{M}}}(a)$ is the bound from Definition 1 (definition of smooth distance metric).
Theorem 1 directly follows from Lemma 5 in Appendix A, which proves that for all $t' \in \mathcal{T}$ and $x \in \mathcal{X}$, we can estimate average conditional potential outcomes, $E[Y^{(t')} \mid X = x]$, correctly with high probability using nearest neighbor matching under any smooth distance metric, and Lemma 6 in Appendix A, which proves that estimating average conditional potential outcomes correctly with high probability leads to estimating CATEs, $\tau$, correctly with high probability. Our setup and Definition 1 are similar to the one described by Kara et al. (2017). Our result in Lemma 5 proves the consistency for a uniform weighted nearest neighbor estimator where the weights are probability weights. The result is in congruence with consistency results by Stone (1977) and Jiang (2019); those works handled the special case where the weights are uniform probability weights instead of any probability weights.

Note that matching using any type of stretch norm that induces a smooth distance metric, including Mahalanobis distance (or its special case with an identity covariance matrix, the L2 distance), to adjust for confounding produces consistent estimates of average treatment effects. Prognostic score (Hansen 2008) and other approaches that induce a smooth distance metric also produce consistent estimates of ATE.

## 4. Matching After Learning to Stretch (MALTS)

MALTS performs weighted nearest neighbors matching, where the weights for the nearest neighbors can be learned by minimizing the following objective. This objective is simply the loss of the in-sample nearest neighbor estimator:

$$W \in \arg\min_{\widetilde{W}} \left[ \sum_{s_i \in S_{tr}^{(T)}} \left| y_i - \sum_{s_l \in S_{tr}^{(T)},\, i \neq l} \widetilde{W}_{i,l}\, y_l \right| + \sum_{s_i \in S_{tr}^{(C)}} \left| y_i - \sum_{s_l \in S_{tr}^{(C)},\, i \neq l} \widetilde{W}_{i,l}\, y_l \right| \right] + \mathrm{Reg}(\widetilde{W}), \tag{3}$$

where $\mathrm{Reg}(\cdot)$ is a regularization function. We let $\widetilde{W}_{i,l}$ be a function of $d_{\mathcal{M}}(x_i, x_l)$. For example, the $\widetilde{W}_{i,l}$ can encode whether $l$ belongs to $i$'s $K$-nearest neighbors. Alternatively, they can encode soft KNN weights where $\widetilde{W}_{i,l} \propto e^{-d_{\mathcal{M}}(x_i, x_l)}$.
Thus, the intuition is to learn $W$ such that the in-sample nearest-neighbors estimator is as accurate as possible.

As a reminder of our notation, we consider a distance metric $d_{\mathcal{M}}$ parameterized by a set of parameters $\mathcal{M}$. We use Euclidean distances for continuous covariates, namely distances of the form $\|\mathcal{M} x_a - \mathcal{M} x_b\|_2$, where $\mathcal{M}$ encodes the orientation of the data. In the past, $\mathcal{M}$ has been hard-coded rather than learned; an example in the causal inference literature is the classical Mahalanobis distance ($\mathcal{M}$ is fixed as the inverse covariance matrix for the observed covariates). This approach has been demonstrated to perform well in settings where all covariates are observed and the inferential target is the average treatment effect (Stuart 2010). We are interested instead in individualized treatment effects, and just as the choice of Euclidean norm in Mahalanobis distance matching depends on the estimand of interest, the stretch metric needs to be amended for this new estimand. We propose learning the parameters of a distance metric, $\mathcal{M}$, directly from the observed data rather than setting it beforehand. The parameters of the distance metric $\mathcal{M}$ can be learned such that $W$ minimizes the objective function on the training set.

In our framework, we can define "approximate closeness" differently for discrete covariates if desired. For continuous covariates, MALTS uses Euclidean distance, which is also a reasonable metric to use for binary data (e.g., Mahalanobis-distance-matching papers recommend converting unordered categorical variables to binary indicators, see Stuart 2010); however, there are benefits to using other metrics, such as weighted Hamming distances, for comparison among sets of binary covariates.
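As a worked illustration of the hard-coded case, the sketch below checks numerically that the classical Mahalanobis distance fits the form $\|\mathcal{M} x_a - \mathcal{M} x_b\|_2$. It assumes that $\mathcal{M}$ is taken as a matrix square root of the inverse covariance, so that $\mathcal{M}^\top \mathcal{M} = \Sigma^{-1}$; the simulated data and variable names are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)
# Simulated correlated covariates (illustrative only).
mixing = np.array([[2.0, 0.0, 0.0],
                   [0.5, 1.0, 0.0],
                   [0.0, 0.3, 0.5]])
X = rng.normal(size=(200, 3)) @ mixing

cov = np.cov(X, rowvar=False)
precision = np.linalg.inv(cov)

# Cholesky gives precision = L @ L.T, so M = L.T satisfies M.T @ M = precision.
M = np.linalg.cholesky(precision).T

a, b = X[0], X[1]
d_stretch = np.linalg.norm(M @ a - M @ b)            # ||M a - M b||_2
d_mahal = np.sqrt((a - b) @ precision @ (a - b))     # textbook Mahalanobis distance

print(np.isclose(d_stretch, d_mahal))
```

The two distances agree, which shows that Mahalanobis matching is one fixed choice of stretch and rotation; MALTS replaces this fixed choice with parameters learned from outcomes.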
To accommodate a combination of Euclidean and Hamming distances, we parameterize our distance metric in terms of two components: one is a learned weighted Euclidean distance for continuous covariates while the other is a learned weighted Hamming distance for discrete covariates as in the FLAME and DAME algorithms (Wang et al. 2021, Dieng et al. 2019). These components are separately parameterized by matrices $\mathcal{M}_c$ and $\mathcal{M}_d$ respectively, $\mathcal{M} = [\mathcal{M}_c, \mathcal{M}_d]$ (here "c" indicates continuous, and "d" indicates discrete). Let $a = (a_c, a_d)$ and $b = (b_c, b_d)$ be the covariates for two individuals split into continuous and discrete pairs respectively.

Operationalizing Equation (3): To perform the step called Distance Metric Learning in Figure 1, we propose the following form for the distance metric:

$$d_{\mathcal{M}}(a, b) = d_{\mathcal{M}_c}(a_c, b_c) + d_{\mathcal{M}_d}(a_d, b_d),$$

where

$$d_{\mathcal{M}_c}(a_c, b_c) = \|\mathcal{M}_c a_c - \mathcal{M}_c b_c\|_2, \qquad d_{\mathcal{M}_d}(a_d, b_d) = \sum_{j} \mathcal{M}_d^{(j,j)}\, \mathbb{1}\left[a_d^{(j)} \neq b_d^{(j)}\right],$$

and $\mathbb{1}[A]$ is the indicator that event $A$ occurred. We thus perform learned Hamming distance matching on the discrete covariates and learned-Mahalanobis-distance matching for continuous covariates. MALTS performs honest causal inference by splitting the observed sample dataset $S_n$ into a training set $S_{tr}$ (not for matching) and an estimation set $S_{est}$ (for matching). We learn $\mathcal{M}(S_{tr})$ using the training set $S_{tr}$ such that in Equation (3),

$$\widetilde{W}_{i,l} = \frac{e^{-d_{\mathcal{M}}(x_i, x_l)}}{\sum_{s_k \in S_{tr}^{(t_i)}} e^{-d_{\mathcal{M}}(x_i, x_k)}} \quad \text{and} \quad \mathrm{Reg}(\widetilde{W}) = \|\mathcal{M}\|_F,$$

which defines MALTS' main implemented optimization problem:

$$\mathcal{M}(S_{tr}) \in \arg\min_{\mathcal{M}} \left( c\|\mathcal{M}\|_F + \Delta^{(C)}_{S_{tr}}(\mathcal{M}) + \Delta^{(T)}_{S_{tr}}(\mathcal{M}) \right) \tag{4}$$

where $\|\cdot\|_F$ is the Frobenius norm of the matrix, and:

$$\Delta^{(t)}_{S_{tr}}(\mathcal{M}) := \frac{1}{|S_{tr}^{(t)}|} \sum_{s_i \in S_{tr}^{(t)}} \left| y_i - \sum_{s_l \in S_{tr}^{(t)}} \frac{e^{-d_{\mathcal{M}}(x_i, x_l)}}{\sum_{s_k \in S_{tr}^{(t)}} e^{-d_{\mathcal{M}}(x_i, x_k)}}\, y_l \right| = \frac{1}{|S_{tr}^{(t)}|} \sum_{s_i \in S_{tr}^{(t)}} \left| \sum_{s_l \in S_{tr}^{(t)}} \frac{e^{-d_{\mathcal{M}}(x_i, x_l)}}{\sum_{s_k \in S_{tr}^{(t)}} e^{-d_{\mathcal{M}}(x_i, x_k)}}\, (y_i - y_l) \right|. \tag{5}$$

Matching and Estimation: To perform the step called Nearest Neighbor Matching, which produces Matched Groups that are used to estimate CATEs in Figure 1, we use the learned distance metric $\mathcal{M}(S_{tr})$.
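The combined distance on continuous and discrete covariates, and the normalized soft KNN weights used in the training objective, can be sketched as below. This is an illustrative sketch only: the parameter values are set by hand rather than learned, the diagonal of the discrete-part matrix is passed as a vector `m_d`, and the function names are hypothetical.

```python
import numpy as np

def mixed_distance(a_c, b_c, a_d, b_d, M_c, m_d):
    """d_M(a, b) = ||M_c a_c - M_c b_c||_2 + sum_j m_d[j] * 1[a_d[j] != b_d[j]],
    where m_d holds the diagonal entries of the discrete-part matrix."""
    a_c, b_c = np.asarray(a_c, dtype=float), np.asarray(b_c, dtype=float)
    continuous = np.linalg.norm(M_c @ (a_c - b_c))
    mismatches = np.asarray(a_d) != np.asarray(b_d)
    discrete = np.sum(np.asarray(m_d, dtype=float) * mismatches)
    return float(continuous + discrete)

def soft_knn_weights(dists):
    """Weights proportional to exp(-d), normalized to sum to one."""
    dists = np.asarray(dists, dtype=float)
    w = np.exp(-(dists - dists.min()))   # shifting by the minimum avoids underflow
    return w / w.sum()

# Hand-set parameters, for illustration only.
M_c = np.diag([2.0, 0.5])    # stretch the first continuous covariate, shrink the second
m_d = [1.0, 0.0]             # penalize a mismatch on the first discrete covariate only

d = mixed_distance([1.5, 8.0], [0.0, 0.0], [1, 0], [0, 1], M_c, m_d)
weights = soft_knn_weights([0.0, 1.0, 6.0])
print(d, weights)
```

In this toy case the continuous part contributes 5 and the single penalized mismatch contributes 1. Closer units receive larger weights, and the weights play the role of $\widetilde{W}_{i,l}$ within one treatment arm.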
To estimate conditional average treatment effects (CATEs) for each unit in the estimation set, we use its nearest neighbors from the same estimation set. Specifically, for any given unit $s$ in the estimation set, we construct a $K$-nearest neighbor matched group $\mathrm{MG}(s, d_{\mathcal{M}(S_{tr})}, S_{est}, K)$ using control set $S_{est}^{(C)}$ and treatment set $S_{est}^{(T)}$. For a choice of estimator $\phi$, the estimated CATE for a treated unit $s = (x_s, y_s, t_s = t')$ is calculated as follows:

$$\widehat{\tau}(x_s) = \phi\left(\mathrm{MG}(s, d_{\mathcal{M}(S_{tr})}, S_{est}^{(T)}, K)\right) - \phi\left(\mathrm{MG}(s, d_{\mathcal{M}(S_{tr})}, S_{est}^{(C)}, K)\right).$$

A simple example of $\phi$ is the empirical mean, i.e.,

$$\phi\left(\mathrm{MG}(s, d_{\mathcal{M}}, S_n^{(t)}, K)\right) = \frac{1}{K} \sum_{k \in \mathrm{MG}(s, d_{\mathcal{M}}, S_n^{(t)}, K)} y_k.$$

However, one can choose the estimator to be a weighted mean, linear regression or a nonparametric model like Random Forest. Particular choices of $\phi$ can also play a role in bias-adjustment to improve the matching estimator of the ATE as in Abadie and Imbens (2011). For $\phi\left(\mathrm{MG}(s, d_{\mathcal{M}}, S_n, K)\right) = \sum_{k \in \mathrm{MG}(s, d_{\mathcal{M}}, S_n, K)} \widetilde{W}_k y_k$, if $\widetilde{W}_k$ is chosen to be proportional to $e^{-d_{\mathcal{M}}(x, x_k)}$, then it leads to multi-robust (defined shortly) and generalizable CATE estimates via soft KNN (as shown in Theorem 2 and Theorem 4 below), while letting $\widetilde{W}_k$ be proportional to $\mathbb{1}\left[s_k \in \mathrm{KNN}^{S_{est}^{(C)}}_{\mathcal{M}(S_{tr})}\right]$ produces interpretable matched groups.

Hyperparameter choice: MALTS has four main hyperparameters:

1) $K$, the number of nearest neighbors used to estimate the counterfactual, which can be chosen by cross-validation.
2) $n$, the size of the training set, i.e., the size of the split on the left of Figure 1. This can be chosen based on the amount of data relative to the number of features, though typically we choose it to be 10% of the data.
3) The maximum allowed diameter or caliper to prune bad matched groups. If the matches have a larger diameter, the matches are not tight and we may not be able to trust their estimates.
The maximum diameter can be chosen by domain knowledge; the user specifies how far apart points may be before a matched group is no longer interpretable.
4) The number of repeats, i.e., the number of times we shuffle the data and re-partition it for the MALTS training and estimation procedure. A larger number of repeats of the whole process helps with smoothing out the estimates over different train/test splits.

## 5. Robustness and Generalization of MALTS

In this section, we show that the MALTS framework correctly estimates the distance metric, facilitating correct estimates of CATEs under SUTVA and a positivity assumption. After basic definitions, and after showing that the learned distance metric and objective are bounded, we introduce and define the concepts of multi-robustness and generalizability of the learned distance metric. Multi-robustness implies that for any possible pair of points the empirical average loss is not far away from the population average loss. Theorem 2 proves that the distance metric learned by the MALTS algorithm is multi-robust. We use these results along with the error bound shown in Lemma 3 to show that the MALTS distance metric is generalizable, i.e., the population average loss and the empirical average loss on the observed data for the learned distance metric are close with high probability. Lastly, we show that the MALTS distance metric is asymptotically generalizable and that the empirical average loss approaches the population average loss as the size of the dataset goes to infinity.

Basic definitions of empirical loss and population loss. First, we define a pairwise loss for $s_i$ and $s_l$ so that it is only finite for treatment-treatment or control-control matched pairs,

$$\mathrm{loss}[\mathcal{M}, s_i, s_l] := \begin{cases} e^{-d_{\mathcal{M}}(x_i, x_l)}\, |y_i - y_l| & \text{if } t_i = t_l \\ \infty & \text{otherwise.} \end{cases}$$

This loss is high for pairs of points that are close (i.e., with small $d_{\mathcal{M}}(x_i, x_l)$) when the outcome values $y_i$ and $y_l$ are very different.
Further, we define an empirical average pairwise loss over a finite sample $S_n$ of size $n$ as

$$L_{emp}(\mathcal{M}, S_n) := \frac{1}{n^2} \sum_{(s_i, s_l) \in (S_n \times S_n)} \mathrm{loss}[\mathcal{M}, s_i, s_l]$$

and define an average loss over population $\mathcal{Z}$ as

$$L_{pop}(\mathcal{M}, \mathcal{Z}) := E_{z_i, z_l \overset{i.i.d.}{\sim} \mu(\mathcal{Z})} \left[ \mathrm{loss}[\mathcal{M}, z_i, z_l] \right].$$

The search space over distance metrics is bounded. We show a basic result about the optimization-based approach we take to learn the distance metric. Specifically, we show that the learned distance metric will be in a bounded region of the search space. Because the learned $\mathcal{M}(S_{tr})$ on the set $S_{tr}$ is the distance metric that minimizes the given objective function, we know that the following inequality is true, which states that the learned parameter has a lower training objective than that of the trivial parameter $\mathbf{0}$:

$$c\|\mathcal{M}(S_{tr})\|_F + \Delta^{(C)}_{S_{tr}}(\mathcal{M}(S_{tr})) + \Delta^{(T)}_{S_{tr}}(\mathcal{M}(S_{tr})) \le c\|\mathbf{0}\|_F + \Delta^{(C)}_{S_{tr}}(\mathbf{0}) + \Delta^{(T)}_{S_{tr}}(\mathbf{0}) =: g_0. \tag{6}$$

Denoting the right hand side of the inequality by $g_0$, we note that, because the terms $\Delta^{(C)}_{S_{tr}}$ and $\Delta^{(T)}_{S_{tr}}$ are non-negative, we can limit our search space to distance metrics $\mathcal{M}$ that satisfy the following inequality:

$$\|\mathcal{M}\|_F \le \frac{g_0}{c}.$$

The objective function terms are bounded. The objective terms $\Delta^{(C)}_{S_{tr}}$ and $\Delta^{(T)}_{S_{tr}}$ (defined in Equation (5)) for learning the distance metric are also bounded, although it is not easy to see this directly because their denominators are somewhat complicated, involving a sum over exponential terms. Here, we point out that because the learned distance metric is bounded, the objective's terms ($\Delta^{(C)}_{S_{tr}}$ and $\Delta^{(T)}_{S_{tr}}$) are also bounded. Specifically, their upper bound is proportional to the empirical average pairwise losses $L_{emp}(\mathcal{M}, S_{tr}^{(C)})$ and $L_{emp}(\mathcal{M}, S_{tr}^{(T)})$, defined above. Further, in Theorem 4, we show that for $t' \in \{T, C\}$ the empirical average loss $L_{emp}(\mathcal{M}, S_{tr}^{(t')})$ is close to the population average pairwise loss $L_{pop}(\mathcal{M}(S_n), \mathcal{Z}^{(t')})$ with high probability. Following Equations (7) and (8) and Theorem 4, the objective terms $\Delta^{(C)}_{S_{tr}}(\mathcal{M})$ and $\Delta^{(T)}_{S_{tr}}(\mathcal{M})$ are upper-bounded by a term proportional to the population average pairwise loss with high probability.
$$\Delta^{(C)}_{S_{tr}}(M) \le \frac{1}{|S^{(C)}_{tr}|} \sum_{s_i \in S^{(C)}_{tr}} \Bigg| \sum_{s_l \in S^{(C)}_{tr}} \frac{e^{-d_M(x_i, x_l)}}{\sum_{s_k \in S^{(C)}_{tr}} e^{-d_M(x_i, x_k)}} (y_i - y_l) \Bigg| \le \frac{1}{|S^{(C)}_{tr}|} \sum_{s_i \in S^{(C)}_{tr}} \sum_{s_l \in S^{(C)}_{tr}} \frac{\text{loss}[M, s_i, s_l]}{\sum_{s_k \in S^{(C)}_{tr}} e^{-d_M(x_i, x_k)}}.$$

We know that, for all $i, k$:

$$d_M(x_i, x_k) = \big( (x_i - x_k)^\top M^\top M (x_i - x_k) \big)^{1/2} \le \|x_i - x_k\|_2 \, \|M\|_F \le \frac{g_0 C_x^2}{c}.$$

Together, the two previous lines imply:

$$\Delta^{(C)}_{S_{tr}}(M) \le \frac{1}{\exp(-g_0 C_x^2 / c)} \cdot \frac{1}{|S^{(C)}_{tr}|^2} \sum_{(s_i, s_l) \in S^{(C)}_{tr} \times S^{(C)}_{tr}} \text{loss}[M, s_i, s_l] = \frac{L_{emp}(M, S^{(C)}_{tr})}{\exp(-g_0 C_x^2 / c)}.$$

Similarly, for the treatment units, we have

$$\Delta^{(T)}_{S_{tr}}(M) \le \frac{L_{emp}(M, S^{(T)}_{tr})}{\exp(-g_0 C_x^2 / c)}.$$

Now, we define a few concepts important for our results, including covering number, multi-robustness, and generalizability. The following definitions and results closely align with the theoretical guarantees of distance metric learning algorithms in Bellet and Habrard (2015) and Xu and Mannor (2012). Our work extends these results to learn a distance metric for causal inference.

Definition 2 (Covering Number) Let $(U, d)$ be a metric space. Consider a subset $V$ of $U$; then $\hat{V} \subseteq V$ is called a $\gamma$-cover of $V$ if for any $v \in V$ we can always find a $\hat{v} \in \hat{V}$ such that $d(v, \hat{v}) \le \gamma$. Further, the $\gamma$-covering number of $V$ under the distance metric $d$ is defined by $N(\gamma, V, d) := \min\{ |\hat{V}| : \hat{V} \text{ is a } \gamma\text{-cover of } V \}$. Note that $N(\gamma, V, d)$ is finite if $U$ is compact.

Definition 3 (Robustness) A learned distance metric $M(\cdot)$ is $(K, \epsilon(\cdot))$-robust for a given $K$ and $\epsilon(\cdot) : (Z \times Z)^n \to \mathbb{R}$ if we can partition $X$ into $K$ disjoint sets $\{C_i\}_{i=1}^K$ such that, for any subsample $S_{tr}$ and its corresponding pair set $S_{tr}^2 := S_{tr} \times S_{tr}$, we have, for any pair of training units $s_1 = (x_1, y_1, t_1), s_2 = (x_2, y_2, t_2) \in S_{tr}^2$, for any pair of units in the support $z_1 = (x'_1, y'_1, t'_1), z_2 = (x'_2, y'_2, t'_2) \in Z^2$, and for all $i, l \in \{1, \ldots, K\}$: if $x_1, x'_1 \in C_i$ and $x_2, x'_2 \in C_l$ such that $t_1 = t'_1 = t_2 = t'_2$, then

$$\big| \text{loss}[M(S_{tr}), s_1, s_2] - \text{loss}[M(S_{tr}), z_1, z_2] \big| \le \epsilon(S_{tr}).$$

Intuitively, robustness means that for any possible unit in the support, the loss is not far from the loss of nearby units in the training set, should some training units exist nearby.
(This terminology is aligned with the distance metric learning literature, e.g., Bellet and Habrard 2015 and Xu and Mannor 2012; it differs from robustness to model misspecification, which frequently appears in the causal inference literature in terms such as "doubly robust estimator.")

Definition 4 (Multi-Robustness) A learned distance metric $M(\cdot)$ is $(K, \epsilon(\cdot))$-multi-robust for a given $K$ and $\epsilon(\cdot) : Z^n \to \mathbb{R}$ if we can partition $X$ into $K$ disjoint sets $C = \{C_i\}_{i=1}^K$ such that, for any subsample $S_n$ and its corresponding pair set $S_n^2 := S_n \times S_n$, we have, for all $s_1 = (x_1, y_1, t_1), s_2 = (x_2, y_2, t_2) \in S_n^2$, all $z_1 = (x'_1, y'_1, t'_1), z_2 = (x'_2, y'_2, t'_2) \in Z^2$, and all $i, l \in \{1, \ldots, K\}$, given

$$\widehat{\text{loss}}[M(S_n), C^{(t')}_i, C^{(t')}_l] := \frac{1}{|C^{(t')}_i|\,|C^{(t')}_l|} \sum_{(s_i, s_l) \in C^{(t')}_i \times C^{(t')}_l} \text{loss}[M(S_n), s_i, s_l]$$

and

$$\text{loss}[M(S_n), C^{(t')}_i, C^{(t')}_l] := \mathbb{E}\big[ \text{loss}(M, Z_i, Z_l) \mid X_i \in C^{(t')}_i, X_l \in C^{(t')}_l \big],$$

that for all $C_i, C_l \in C$,

$$\big| \widehat{\text{loss}}[M(S_n), C^{(t')}_i, C^{(t')}_l] - \text{loss}[M(S_n), C^{(t')}_i, C^{(t')}_l] \big| \le \epsilon(S_n).$$

Intuitively, multi-robustness means that for any possible pair of points from any two partitions of $X$, the empirical average loss over training points is not far from the population average loss. Because the training procedure aims at minimizing the total loss, a multi-robust method will not perform poorly out of sample.

Definition 5 (Generalizability) A learned distance metric $M(\cdot)$ is said to generalize with respect to the given training sample $S_n$ if

$$P\Big( \big| L_{pop}(M(S_n), Z^{(t')}) - L_{emp}(M(S_n), S^{(t')}_n) \big| \ge \epsilon \Big) \le \delta_\epsilon,$$

where $\delta_\epsilon$ is a decreasing function of $\epsilon$ with zero intercept.

Definition 6 (Asymptotic Generalizability) A learned distance metric $M(\cdot)$ is said to asymptotically generalize with respect to the given training sample $S_n$ if

$$\lim_{n \to \infty} \big| L_{pop}(M(S_n), Z^{(t')}) - L_{emp}(M(S_n), S^{(t')}_n) \big| = 0.$$

Given these definitions, we first show in Theorem 2 that the distance metric learned using MALTS is multi-robust, and we then extend the argument in Theorem 4 to show that it is also generalizable.
Theorem 2 (MALTS learned distance metric is multi-robust) With probability greater than [...], the distance metric $M(\cdot)$ learned using MALTS is $(N(\gamma, X, \|\cdot\|_2), \beta)$-multi-robust for arbitrarily chosen values of $\gamma > 0$ and $\beta \ge 0$, where $B$ is $\max_{z_1, z_2} \text{loss}(M(S_n), z_1, z_2)$, $\{C_i\}_{i=1}^K$ is the partition of $X$ into non-empty sets $C_i$ such that $K$ is the $\gamma$-covering number of $X$, $C^{(t')}_i = \{ z_j = (x_j, y_j, t_j) : t_j = t', x_j \in C_i \}$, and $\rho^{(t')}_\gamma = \min_i |C^{(t')}_i|$.

Proof (Theorem 2). Given $Z = X \times Y \times T$, we consider the following definition of a minimum-sized $\gamma$-cover $\hat{V}$ of the set $X$ under the distance metric $\|\cdot\|_2$: partition the set into $K$ disjoint subsets $C_\gamma = \{C_i\}_{i=1}^K$ such that $K$ is the $\gamma$-covering number of $X$ under $\|\cdot\|_2$ (which is exactly equal to $|\hat{V}|$), where each $C_i$ is contained in the $\gamma$-neighborhood of the corresponding $\hat{v}_i \in \hat{V}$ and each $C_i$ contains at least one control and one treated sample. Note that if $X$ is a compact convex set, then such a cover and the corresponding packing $C_\gamma$ exist and $K = |C_\gamma|$ is finite.

For any arbitrary $C_i$ and $C_l$ in $C_\gamma$, consider the empirical average loss for all training units $s_i \in C_i$ and $s_l \in C_l$ with treatment $t'$:

$$\widehat{\text{loss}}\big[M(S_n), C^{(t')}_i, C^{(t')}_l\big] = \frac{1}{|C^{(t')}_i \times C^{(t')}_l|} \sum_{(s_i, s_l) \in C^{(t')}_i \times C^{(t')}_l} \text{loss}[M(S_n), s_i, s_l]$$

and the expected loss for units $Z_i$ and $Z_l$:

$$\text{loss}\big[M(S_n), C^{(t')}_i, C^{(t')}_l\big] = \mathbb{E}\big[ \text{loss}(M, Z_i, Z_l) \mid X_i \in C^{(t')}_i, X_l \in C^{(t')}_l \big].$$

Let $f$ be a function of the set of independent random variables such that

$$f\big(s_1, \ldots, s_{|C^{(t')}_i|}, s_{|C^{(t')}_i|+1}, \ldots, s_{|C^{(t')}_i| + |C^{(t')}_l|}\big) = \frac{1}{|C^{(t')}_i \times C^{(t')}_l|} \sum_{j=1}^{|C^{(t')}_i|} \; \sum_{i=|C^{(t')}_i|+1}^{|C^{(t')}_i| + |C^{(t')}_l|} \text{loss}[M(S_n), s_i, s_j].$$

Thus, $f\big(s_1, \ldots, s_{|C^{(t')}_i| + |C^{(t')}_l|}\big) = \widehat{\text{loss}}\big[M(S_n), C^{(t')}_i, C^{(t')}_l\big]$. Now, let $\rho^{(t')}_\gamma$ be the density of the $\gamma$-cover for treatment $t'$, defined as the number of units with treatment $t'$ in the smallest partition set, $\rho^{(t')}_\gamma = \min_i |C^{(t')}_i|$, and let $B = \max_{z_1, z_2} \text{loss}(M(S_n), z_1, z_2)$. Now, we show that $f(\cdot)$ has bounded differences.
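Without loss of generality, consider an index $j \le |C^{(t')}_i|$; then

$$\big| f(s_1, \ldots, s_j, \ldots, s_{|C^{(t')}_i| + |C^{(t')}_l|}) - f(s_1, \ldots, s'_j, \ldots, s_{|C^{(t')}_i| + |C^{(t')}_l|}) \big|$$
$$= \frac{1}{|C^{(t')}_i \times C^{(t')}_l|} \Bigg| \sum_{i=|C^{(t')}_i|+1}^{|C^{(t')}_i| + |C^{(t')}_l|} \big( \text{loss}[M(S_n), s_i, s_j] - \text{loss}[M(S_n), s_i, s'_j] \big) \Bigg|$$
$$\le \frac{1}{|C^{(t')}_i \times C^{(t')}_l|} \sum_{i=|C^{(t')}_i|+1}^{|C^{(t')}_i| + |C^{(t')}_l|} \big| \text{loss}[M(S_n), s_i, s_j] - \text{loss}[M(S_n), s_i, s'_j] \big|$$
$$\le \frac{1}{|C^{(t')}_i \times C^{(t')}_l|} \sum_{i=|C^{(t')}_i|+1}^{|C^{(t')}_i| + |C^{(t')}_l|} \big( |\text{loss}[M(S_n), s_i, s_j]| + |\text{loss}[M(S_n), s_i, s'_j]| \big)$$
$$\le \frac{2\,|C^{(t')}_l|}{|C^{(t')}_i \times C^{(t')}_l|}\, B = \frac{2B}{|C^{(t')}_i|} \le \frac{2B}{\rho^{(t')}_\gamma}.$$

Similarly, for any $j > |C^{(t')}_i|$,

$$\big| f(s_1, \ldots, s_j, \ldots, s_{|C^{(t')}_i| + |C^{(t')}_l|}) - f(s_1, \ldots, s'_j, \ldots, s_{|C^{(t')}_i| + |C^{(t')}_l|}) \big| \le \frac{2B}{\rho^{(t')}_\gamma}.$$

As $f(\cdot)$ is a function of $|C^{(t')}_i| + |C^{(t')}_l|$ independent random variables, by McDiarmid's inequality:

$$P\Big( \big| \widehat{\text{loss}}\big[M(S_n), C^{(t')}_i, C^{(t')}_l\big] - \text{loss}[M(S_n), C^{(t')}_i, C^{(t')}_l] \big| \ge \beta \Big) \le 2\exp\Bigg( - \frac{2\beta^2}{\sum_{i=1}^{|C^{(t')}_i| + |C^{(t')}_l|} \frac{4B^2}{(\rho^{(t')}_\gamma)^2}} \Bigg) = 2\exp\Bigg( - \frac{2\beta^2 (\rho^{(t')}_\gamma)^2}{4\,(|C^{(t')}_i| + |C^{(t')}_l|)\,B^2} \Bigg) = 2\exp\Bigg( - \frac{\beta^2 (\rho^{(t')}_\gamma)^2}{2\,(|C^{(t')}_i| + |C^{(t')}_l|)\,B^2} \Bigg).$$

We will need the following lemma to prove Theorem 4. The lemma provides a bound for a particular treatment assignment, while the theorem sums over all treatment assignments.

Lemma 3 (Error Bound) Given a sample $S_n \overset{iid}{\sim} \mu(Z)$, where $n^{(t')}$ is the number of units with $t_i = t'$ in $S_n$, and choosing $B > 0$ for which $\text{loss}[\cdot, z_i, z_l] \le B$ for all $z_i, z_l \in Z$ ($B$ is finite because $X$ is compact and $Y$ is bounded): if a learning algorithm provides a distance metric $M(S_n)$ that is $(K, \epsilon(\cdot))$-multi-robust with probability $p_{mr}(\epsilon)$, then for any $E > 0$, with probability greater than or equal to $(1 - E)\,(p_{mr}(\epsilon))^{K^2}$, we have, for all $t' \in T$,

$$\big| L_{pop}(M(S_n), Z^{(t')}) - L_{emp}(M(S_n), S^{(t')}_n) \big| \le \epsilon(S^{(t')}_n) + 2B \sqrt{\frac{2K\ln(2) + 2\ln(1/E)}{n^{(t')}}}.$$

Theorem 4 (MALTS distance metric is generalizable) The distance metric $M(\cdot)$ learned using the data $S_n$ and the MALTS algorithm is generalizable and asymptotically generalizable, as follows: 1.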
Generalizability: With probability at least [...] with respect to the random draw of data,

$$\big| L_{pop}(M(S_n), Z^{(t')}) - L_{emp}(M(S_n), S^{(t')}_n) \big| \le 2|T|\beta + \sum_{t' \in T} 2B \sqrt{\frac{2K\ln(2) + 2\ln(1/E)}{n^{(t')}}}$$

for arbitrarily chosen constants $\gamma > 0$, $E > 0$, and $\beta \ge 0$, where $B$ is $\max_{z_1, z_2} \text{loss}(M(S_n), z_1, z_2)$, $\{C_i\}_{i=1}^K$ is the partition of $X$ into non-empty sets $C_i$ such that $K$ is the $\gamma$-covering number of $X$, $C^{(t')}_i = \{ z_j = (x_j, y_j, t_j) : t_j = t', x_j \in C_i \}$, and $\rho_\gamma = \min_{i, t'} |C^{(t')}_i|$.

2. Asymptotic Generalizability:

$$\lim_{n \to \infty} \Big( \big| L_{pop}(M(S_n), Z^{(C)}) - L_{emp}(M(S_n), S^{(C)}_n) \big| + \big| L_{pop}(M(S_n), Z^{(T)}) - L_{emp}(M(S_n), S^{(T)}_n) \big| \Big) = 0.$$

Having established these theoretical properties of MALTS, we next compare its performance with that of other methods on different datasets.

6. Experiments

In this section, we compare the performance of MALTS with other competing methods on several simulation setups with continuous covariates, discrete covariates, and mixed (continuous and discrete) covariates. Lastly, we demonstrate the performance of MALTS for estimating the ATE on LaLonde's NSW and PSID-2 data samples (LaLonde 1986, Dehejia and Wahba 1999). MALTS performs an $\eta$-fold honest causal inference procedure, with the estimator $\varphi$ inside each matched group being linear regression. We split the observed samples $S_n$ into $\eta$ equal parts such that the ratio of treated to control units in each part is similar. For each fold, we use one of the $\eta$ partitions as the training set $S_{tr}$ (not used for matching) and the remaining $\eta - 1$ partitions as the estimation set $S_{est}$. Using the output from each of the $\eta$ folds, we calculate the estimated CATE for each unit (averaged across folds), the estimated distance metric (averaged across folds), and a weighted unified matched group for each unit $s_i \in S_n$. The weight of each matched unit $s_k$ corresponds to the number of times that unit $s_k$ was in the matched group of unit $s_i$ across the $\eta - 1$ constructed matched groups. Here, $\eta$ was chosen to be 5 in our experiments.
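The fold construction described above can be sketched as follows. This is a minimal illustration rather than the authors' implementation: the function name `honest_folds`, the round-robin stratification, and the fixed seed are assumptions made for the example.

```python
import random

def honest_folds(treatments, eta=5, seed=0):
    """Split unit indices into eta parts with similar treated/control ratios.

    Yields (train_idx, est_idx): one part is used to learn the distance
    metric, and the remaining eta - 1 parts form the estimation set that is
    used for matching.
    """
    rng = random.Random(seed)
    parts = [[] for _ in range(eta)]
    for arm in sorted(set(treatments)):
        idx = [i for i, t in enumerate(treatments) if t == arm]
        rng.shuffle(idx)
        # Deal the units of this arm round-robin so every part gets a similar share.
        for pos, i in enumerate(idx):
            parts[pos % eta].append(i)
    for f in range(eta):
        train = sorted(parts[f])
        est = sorted(i for g in range(eta) if g != f for i in parts[g])
        yield train, est

# Hypothetical sample: 40 treated and 60 control units.
t = [1] * 40 + [0] * 60
folds = list(honest_folds(t, eta=5))
```

Each unit appears in the training part of exactly one fold and in the estimation set of the other $\eta - 1$ folds, which is why each unit ends up with $\eta - 1$ matched groups to combine.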
For interpretability, we let $M_c$ be a diagonal matrix, which allows stretches of the continuous covariates. (Note that $M_d$, the stretch matrix over discrete covariates, is always set to be diagonal.) This way, the magnitude of an entry in $M_c$ or $M_d$ gives the relative importance of the indicated covariate for the causal inference problem. We further analyze strategies for variance estimation for MALTS in Section 6.8, as well as performance under limited overlap between the covariate distributions of the treated and control groups and sensitivity to unobserved confounding; detailed results are shown in Appendix B.

The main result of these experiments is that the performance of MALTS is on par with existing state-of-the-art methods for causal inference, including black box methods. MALTS tends to have fairly consistent performance, even if the training set is fairly small or the number of irrelevant covariates is large. Further, MALTS provides interpretable distance metrics and matched groups that black box machine learning methods do not provide.

6.1 Data Generation Processes

In this subsection we describe the data generation processes (DGPs) used in the simulation experiments. We use two main data-generation processes. The first DGP has a linear baseline with linear and quadratic treatment effects, while the second DGP is an extension of Friedman's function, which was introduced to test the performance of prediction algorithms (Friedman 1991). This second DGP, also termed Friedman's DGP, has a scaled cosine treatment effect.

6.1.1 Quadratic DGP

This simulation includes both linear and quadratic terms. Let $x_{i,p} = \{x_{i,p_c}, x_{i,p_d}\}$ be a $p$-dimensional covariate vector composed of $|p_c|$ continuous covariates and $|p_d|$ discrete ones. There are $k = k_c \cup k_d$ relevant covariates, and the rest of the dimensions are irrelevant. Here, $p_c$, $k_c$, $p_d$, and $k_d$ refer to the subsets of indices of the covariates: all continuous, relevant continuous, all discrete, and relevant discrete, respectively.
$x_{i,k_c}$ and $x_{i,k_d}$ refer to the vectors of relevant continuous and discrete covariates, respectively, and $x_{i,k}$ refers to all $|k|$ relevant covariates. $\kappa_c \subseteq k_c$ is the set of continuous covariates and $\kappa_d \subseteq k_d$ is the set of discrete covariates that are relevant in determining the treatment choice. The potential outcomes and treatment assignment are determined as follows:

$$x_{i,p_c} \overset{iid}{\sim} N(\mu, \Sigma), \quad \{x_{i,j}\}_{j \in p_d} \overset{iid}{\sim} \text{Bernoulli}(\psi),$$
$$\epsilon_{i,0}, \epsilon_{i,1} \overset{iid}{\sim} N(0, 1), \quad \epsilon_{i,treat} \overset{iid}{\sim} N(0, \sigma^2),$$
$$s_1, \ldots, s_{|k|} \overset{iid}{\sim} \text{Uniform}\{-1, 1\}, \quad \alpha_j \mid s_j \overset{iid}{\sim} N(10 s_j, 9), \quad \beta_1, \ldots, \beta_{|k|} \overset{iid}{\sim} N(1, 0.25),$$
$$y^{(0)}_i = \sum_{j \in k_c \cup k_d} \alpha_j x_{i,j} + \epsilon_{i,0},$$
$$y^{(1)}_i = \sum_{j \in k_c \cup k_d} \alpha_j x_{i,j} + \sum_{j \in k_c \cup k_d} \beta_j x_{i,j} + \sum_{j, j' \in k_c \cup k_d} x_{i,j} x_{i,j'} + \epsilon_{i,1},$$
$$t_i = \mathbb{1}\Big[ \text{expit}\Big( \sum_{j \in \kappa_c} x_{i,j} + \sum_{j \in \kappa_d} x_{i,j} - (|\kappa_c|\mu + |\kappa_d|\psi) + \epsilon_{i,treat} \Big) > 0.5 \Big],$$
$$y_i = t_i y^{(1)}_i + (1 - t_i) y^{(0)}_i.$$

Here, $\text{expit}(z) = \exp(z) / (1 + \exp(z))$. The variance of $\epsilon_{i,treat}$ determines how much confounding and overlap there is in the dataset: higher values of the variance make the dataset look like a randomized experiment with good overlap, while very small values of the variance lead to poor overlap and an observational study that is very hard to analyze. We explore these issues in detail in Appendix B.

6.1.2 Friedman's DGP

The data generation process of Friedman (1991) was first proposed to assess the performance of prediction methods. We augmented Friedman's simulation setup to evaluate causal inference methods. The potential outcome under control is Friedman's function as provided by Friedman (1991) and Chipman et al. (2010). The expected treatment effect we study is equal to the cosine of the product of the first two covariates scaled by the third covariate.

$$x_{i,1}, \ldots, x_{i,10} \overset{iid}{\sim} U(0, 1), \quad \epsilon_{i,0}, \epsilon_{i,1} \sim N(0, 1), \quad \epsilon_{i,treat} \overset{iid}{\sim} N(0, 1),$$
$$y^{(0)}_i = 10 \sin(\pi x_{i,1} x_{i,2}) + 20 (x_{i,3} - 0.5)^2 + 10 x_{i,4} + 5 x_{i,5} + \epsilon_{i,0},$$
$$y^{(1)}_i = 10 \sin(\pi x_{i,1} x_{i,2}) + 20 (x_{i,3} - 0.5)^2 + 10 x_{i,4} + 5 x_{i,5} + x_{i,3} \cos(\pi x_{i,1} x_{i,2}) + \epsilon_{i,1},$$
$$t_i = \mathbb{1}\big[ \text{expit}(x_{i,0} + x_{i,1} - 0.5 + \epsilon_{i,treat}) > 0.5 \big],$$
$$y_i = t_i y^{(1)}_i + (1 - t_i) y^{(0)}_i.$$
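As an illustration of Friedman's DGP, the following sketch draws units one at a time using only the standard library. It is not the authors' simulation code: the function name `friedman_unit`, the dictionary layout, and the seed are hypothetical, and the treatment rule is assumed to use the first two covariates (stored at 0-based positions 0 and 1).

```python
import math
import random

def expit(z):
    return 1.0 / (1.0 + math.exp(-z))

def friedman_unit(rng):
    """Draw one unit: covariates, both potential outcomes, treatment, observed outcome."""
    x = [rng.random() for _ in range(10)]          # ten U(0, 1) covariates, 0-based
    base = (10 * math.sin(math.pi * x[0] * x[1]) + 20 * (x[2] - 0.5) ** 2
            + 10 * x[3] + 5 * x[4])                # Friedman's function
    tau = x[2] * math.cos(math.pi * x[0] * x[1])   # expected treatment effect
    y0 = base + rng.gauss(0, 1)
    y1 = base + tau + rng.gauss(0, 1)
    # Assumption: the treatment rule depends on the first two covariates.
    t = int(expit(x[0] + x[1] - 0.5 + rng.gauss(0, 1)) > 0.5)
    y = t * y1 + (1 - t) * y0                      # only one potential outcome is observed
    return {"x": x, "y0": y0, "y1": y1, "tau": tau, "t": t, "y": y}

rng = random.Random(0)
sample = [friedman_unit(rng) for _ in range(1000)]
```

Because the simulation records both potential outcomes and the noise-free effect `tau`, the absolute error of any CATE estimate can be computed unit by unit, which is not possible with real observational data.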
6.2 Continuous Covariates

We use the data-generation process described in Section 6.1.1 to generate 2500 units with no discrete covariates, 15 important continuous covariates, and 25 irrelevant continuous covariates. We set the parameters of the DGP as follows: $\mu = 1$, $\Sigma = 1.5I$, $\psi = 0.5$, $\sigma^2 = 1$, and $\kappa_c = \{0, 1\}$. We estimate the CATE for each unit using matching methods (propensity score matching, prognostic score matching, and genetic matching) and non-matching (uninterpretable) methods (causal forest and BART). Figure 2 shows the performance of these methods. The performance of MALTS is on par with existing state-of-the-art non-matching methods, and MALTS outperforms all other matching methods for continuous covariates in the quadratic data generation process.

Figure 2: MALTS performs well with respect to other methods for continuous data. Letter-box plots of CATE absolute error relative to the true ATE on the test set for several methods.

6.3 Discrete Covariates

We use the data-generation process described in Section 6.1.1 to generate 2500 units with no continuous covariates, 15 important discrete covariates, and 10 irrelevant discrete covariates. We set the parameters of the DGP as follows: $\sigma^2 = 1$, $c = 2$, and $\kappa_d = \{0, 1\}$. We used the weighted Hamming distance metric for this experiment. Figure 3 shows the performance comparison, again showing that the performance of MALTS is on par with existing state-of-the-art non-matching methods. MALTS also performs better than FLAME (a state-of-the-art matching method for discrete data), as it is able to provide additional smoothing in this relatively small-$n$ setting. Hence, MALTS performs well for discrete covariates in the quadratic data generation process.

Figure 3: MALTS performs well with respect to other methods for discrete data. Letter-box plots of CATE absolute error relative to the true ATE on the test set for several methods.
6.4 Mixed Covariates

We use the data-generation process used for the experiments on continuous and discrete covariates (described in Section 6.1.1) to generate 2500 units with 5 relevant continuous covariates, 15 relevant discrete covariates, 10 irrelevant continuous covariates, and 10 irrelevant discrete covariates. We used the same set of parameters for the DGP as in the previous two experiments. As in the previous two experiments, Figure 4 shows that MALTS performs on par with the state-of-the-art non-matching methods and outperforms all matching methods that can handle mixed covariates for the quadratic data generation process.

6.5 Number of Covariates

We studied the performance of various causal inference methods in estimating CATEs as the number of covariates ($p$) changes, keeping the number of relevant covariates ($|k|$) constant and equal to 8. We simulated the data using the DGP described in Section 6.1.1. The number of units is constant ($n = 2048$), while the number of covariates ($p$) changes from 8 to 256. The performance of MALTS is on par with or better than that of other causal inference methods as the number of irrelevant covariates increases (see Figure 5). This indicates that MALTS can be used to help reduce the effects of the curse of dimensionality.

Figure 4: MALTS performs well on data with mixed (continuous and discrete) covariates. Letter-box plots of CATE absolute error relative to the true ATE on the test set for several methods.

6.6 Number of Units

We studied the change in CATE estimation error rates as the number of units in a dataset increases. We simulated the data using the DGP described in Section 6.1.1, keeping the number of covariates constant and equal to 20 (all of them relevant in outcome determination). We changed the number of units from $2^8$ to $2^{12}$.
The performance of MALTS is on par with or better than that of BART, and its error rate is significantly lower than that of the other causal inference methods (see Figure 6).

6.7 Friedman's Setup

We further compare the performance of MALTS and other flexible methods on data generated using the process described in Section 6.1.2. This DGP is particularly interesting because the potential outcomes are highly non-linear functions with trigonometric expressions. As shown in Figure 7, MALTS performs on par with causal forest, while BART's error rate is significantly higher (worse) than that of MALTS on Friedman's data generation process.

Figure 5: MALTS performs on par with other methods for a range of values of $p$. Comparative performance in estimating CATE using causal inference methods as the number of covariates increases, keeping the number of relevant covariates constant and equal to 8. The number of units is fixed: $n = 2^{11}$. (For the given $n$, BART does not return CATE estimates for some units when $p > 2^6$. Prognostic scores use BART for $p < 2^7$ and gradient boosted trees for $p > 2^6$.)

6.8 Coverage Study

We use the DGP described in Section 6.1.1 with 2 relevant continuous covariates and no irrelevant covariates for the coverage study. We set the parameters of the DGP as follows: $\mu = 1$, $\Sigma = 1.5I$, $\psi = 0.5$, and $c = 2$. We selected 9 reference points in a grid from the covariate space, as shown in Figure 8(b), and conducted an experiment that considered these reference points over 100 repetitions. We compared coverage for CATEs estimated using MALTS for different values of the variance, ranging from 1.0 to 4.0, of the noise terms $\epsilon_0$ and $\epsilon_1$ in the potential outcomes function. Variance estimation is notoriously hard in matching problems, even for overall quantities such as the average treatment effect (Abadie and Imbens 2006). We consider both a conservative variance estimator (Wang et al. 2021) and estimators that sacrifice some interpretability for better coverage.
Specifically, we consider the CATEs estimated using MALTS and study how well an uninterpretable method can predict those estimates in order to obtain a variance estimate. We use the predictive variance from gradient boosting regression, Gaussian process regression, and Bayesian ridge regression on the covariates at which we estimated CATEs to quantify the variance of each CATE estimate. Based on Figure 8(a), the coverage for each of the nine points of interest is between 0.85 and 1 for most values of the variance using any of the three variance estimation approaches.

Figure 6: MALTS consistently performs on par with or better than non-interpretable approaches. Trend plots of average CATE absolute error for several methods, for different numbers of units in the datasets.

Figure 7: MALTS performs well on Friedman's setup. Letter-box plots of CATE absolute error relative to the true ATE for MALTS and other causal inference methods.

6.9 LaLonde Data

The LaLonde data pertain to the National Supported Work Demonstration (NSW) temporary employment program and its effect on the income level of the participants (LaLonde 1986). This dataset is frequently used as a benchmark for the performance of methods for observational causal inference. We employ the male sub-sample from the NSW in our analysis, as well as the PSID-2 control sample of male household heads under age 55 who did not classify themselves as retired in 1975 and who were not working when surveyed in the spring of 1976 (Dehejia and Wahba 1999). The outcome variable for both experimental and observational analyses is earnings in 1978, and the considered variables are age, education, whether a respondent is Black, is Hispanic, is married, has a degree, and their earnings in 1975.
Previously, it has been demonstrated that almost any adjustment during the analysis of the experimental and observational variants of these data (both by modeling the outcome and by modeling the treatment variable) can lead to extreme bias in the estimate of average treatment effects (LaLonde 1986).

Figure 8: (a) Coverage of the 95 percent confidence interval for 9 points: (1.0,1.0), (2.5,2.5), (-0.5,-0.5), (2.5,-0.5), (-0.5,2.5), (4.0,4.0), (-3.0,-3.0), (4.0,-3.0) and (-3.0,4.0). (b) Covariate space showing the positions of the 9 points of interest as black stars, with other points color-coded according to their treatment assignments.

Table 2: Estimated ATE for different methods on Lalonde's NSW experimental dataset. The MALTS estimate of ATE is closer to the true ATE than other methods. We provide estimates for MALTS before and after pruning the matched groups with large diameters. The threshold to prune was chosen by rule of thumb on diameters of matched groups, as shown in Figure 9(b).

| Method | ATE Estimate | Estimation Bias (%) |
|---|---|---|
| Truth | 886 | - |
| MALTS | 881.67 | -0.49 |
| MALTS (pruned) | 888.53 | 0.29 |
| GenMatch | 859.72 | -2.97 |
| Propensity Score | 513.30 | -42.06 |
| Prognostic Score | 943.81 | 6.52 |
| BART-CV | 1164.72 | 31.46 |
| Causal Forest-CV | 509.32 | -42.51 |

Table 3: Estimated ATE for different methods on Lalonde's NSW experimental data and PSID-2 observational dataset. We provide estimates for MALTS before and after pruning the matched groups with large diameters. The threshold to prune was chosen by rule of thumb on diameters of matched groups, as shown in Figure 9(b).

| Method | ATE Estimate | Estimation Bias (%) |
|---|---|---|
| Truth | 886 | - |
| MALTS | 608.37 | -31.34 |
| MALTS (pruned) | 891.75 | 0.65 |
| GenMatch | 549.53 | -37.98 |
| Propensity Score | 513.79 | -42.01 |
| Prognostic Score | -897.76 | -201.33 |
| BART-CV | 713.20 | -19.50 |
| Causal Forest-CV | -179.98 | -120.31 |

Performance results: Tables 2 and 3 present the average treatment effect estimates based on MALTS, state-of-the-art modeling methods, and matching methods.
MALTS (after appropriately pruning low-quality matched groups) is able to achieve accurate ATE estimation on both experimental and observational datasets. Figure 9 illustrates how the matched groups were pruned. There was a clear visual separation between high-quality matched groups, which had low diameters, and low-quality matched groups, which had larger diameters.

Model Interpretability: One difference between MALTS and the other methods is that its solution can be described concisely: MALTS produces a total of seven numbers that define the distance metric on the LaLonde data. The distribution of the learned distance metric values across folds is shown in Figure 9(a). Once the researcher has these seven numbers, along with the value of $k$ in the $k$-nearest neighbors used to train MALTS, they know precisely which units should be matched. In contrast, causal forest and BART require a model whose size depends on the number of trees, where each tree is several levels deep (in this case, 2000 trees and 150 trees, respectively).

Interpretability of Matched Groups: To examine the interpretability of MALTS matched groups, we present two of the matched groups from MALTS for the observational Lalonde dataset in Table 4, corresponding to two query individuals in the dataset. Query 1 is a 22-year-old with no income in 1975. MALTS was able to construct a tight matched group for this individual (both in control and in treatment). In contrast, Query 2 is a 42-year-old high-income individual without a degree, which is an extremely unlikely scenario, leading to a matched group with a very large diameter that should probably not be used during analysis. Such granular analysis is not possible for regression methods like BART or for matching methods like prognostic score or propensity score matching. This further highlights the troubleshooting capabilities of interpretable matching methods: by identifying units that are poorly matched, we know exactly which units to study in more detail.
In this case, it is possible that the degree field might have a data error, which means it would be better not to match this unit and to potentially follow up on the veracity of responses to the survey. 7. Conclusion and Discussion This paper introduces the MALTS algorithm, which learns a distance metric on the covariate space for use with matching. The learned metric stretches important covariates and compresses irrelevant covariates for outcome prediction in order to produce high-quality matches. Unlike other methods, MALTS can handle a large number of irrelevant covariates by compressing them to the point where they are effectively eliminated, which helps handle the curse of dimensionality. Unlike black-box machine learning methods, MALTS produces interpretable matched groups and returns the stretch matrix on covariates for counterfactual prediction. The stretch matrix is chosen here to be diagonal, so that it can be represented using only a few stretch numbers that determine the importance of each covariate in determining the matched groups. Whereas deep neural networks mainly show improvements over other methods for problems that do not have natural data representations (computer vision, speech, etc.), we conjecture that the stretch/almost-exact match combination should suffice for most datasets. A natural extension, however, is to use neural networks to learn a flexible distance metric in a latent space, thus allowing us to match on medical records, images, and text documents. This will allow us to incorporate complex data structures by introducing a flexible learning framework (e.g., interpretable neural networks) for coding the data. 
That is, we can redefine the distance metric via

$$d_M(x_i, x_j) = \langle \omega_M(x_i), \omega_M(x_j) \rangle \quad \text{or} \quad d_M(x_i, x_j) = \big( \omega_M(x_i) - \omega_M(x_j) \big)^2,$$

where $\omega_M$ is a summary of relevant data features learned using a complex modeling framework. In the future, the MALTS framework could be extended to deal with missing covariates and could be adapted to instrumental variables.

Table 4: Learned distance metric and examples of matched groups on the Lalonde experimental treatment and observational control datasets for two example query points drawn from the same datasets. Query 1 represents a high-quality (low-diameter) matched group, while Query 2 represents a poor-quality (high-diameter) matched group that could be discarded during analysis.

Stretch Matrix

| | Age | Education | Black | Hispanic | Married | No-Degree | Income-1975 |
|---|---|---|---|---|---|---|---|
| mean(Diag(M)) | 0.780 | 1.786 | 1.254 | 1.110 | 1.205 | 1.229 | 1.001 |
| std(Diag(M)) | 0.361 | 0.778 | 0.641 | 0.577 | 0.614 | 0.618 | 0.512 |

Two Matched Groups

| | Unit-ID | Treated | Age | Education | Black | Hispanic | Married | No-Degree | Income-1975 | Income-1978 |
|---|---|---|---|---|---|---|---|---|---|---|
| Query-1: | 1 | Yes | 22 | 9 | No | Yes | No | Yes | $0 | $3595 |
| | 94 | Yes | 23 | 8 | No | Yes | No | Yes | $0 | $3881 |
| | 330 | No | 22 | 8 | No | Yes | No | Yes | $0 | $9920 |
| | 299 | No | 22 | 9 | Yes | No | No | Yes | $0 | $0 |
| | 5 | Yes | 22 | 9 | Yes | No | No | Yes | $0 | $4056 |
| | 82 | Yes | 21 | 9 | Yes | No | No | Yes | $0 | $0 |
| | 416 | No | 22 | 9 | Yes | No | No | Yes | $0 | $12898 |
| | 333 | No | 21 | 9 | Yes | No | No | Yes | $0 | $3343 |
| | 292 | Yes | 20 | 9 | Yes | No | No | Yes | $0 | $8881 |
| | 17 | Yes | 23 | 10 | Yes | No | No | Yes | $0 | $7693 |
| | 116 | Yes | 24 | 10 | Yes | No | No | Yes | $0 | $0 |

| | Unit-ID | Treated | Age | Education | Black | Hispanic | Married | No-Degree | Income-1975 | Income-1978 |
|---|---|---|---|---|---|---|---|---|---|---|
| Query-2: | 968 | No | 42 | 11 | No | No | Yes | Yes | $44758 | $54675 |
| | 274 | Yes | 35 | 9 | Yes | No | Yes | Yes | $13830 | $12803 |
| | 141 | Yes | 25 | 8 | Yes | No | No | Yes | $37431 | $2346 |
| | 967 | No | 50 | 17 | No | No | Yes | No | $30435 | $25860 |
| | 948 | No | 35 | 12 | No | No | Yes | No | $26854 | $29554 |
| | 210 | Yes | 25 | 8 | No | No | No | Yes | $23096 | $6421 |
| | 241 | Yes | 24 | 15 | Yes | No | No | No | $13008 | $14683 |
| | 311 | No | 28 | 12 | Yes | No | Yes | No | $29009 | $10067 |
| | 183 | Yes | 23 | 10 | Yes | No | No | Yes | $15709 | $5665 |
| | 182 | Yes | 23 | 12 | Yes | No | Yes | No | $15079 | $10283 |
Acknowledgments

We gratefully acknowledge funding from the National Science Foundation under grants III 1703431, CCF 1934964, IIS 2130250, IIS 2147061 (with Amazon), and CAREER DMS 2046880, and the National Institutes of Health under grants NIDA DA054994 and R01EB025021. We also acknowledge funding from an Amazon Graduate fellowship.

References

A. Abadie and G. W. Imbens. Large sample properties of matching estimators for average treatment effects. Econometrica, 74(1):235-267, 2006.
A. Abadie and G. W. Imbens. Bias-corrected matching estimators for average treatment effects. Journal of Business & Economic Statistics, 29(1):1-11, 2011.
A. Bellet and A. Habrard. Robustness and generalization for metric learning. Neurocomputing, 151:259-267, 2015.
V. Chernozhukov, D. Chetverikov, M. Demirer, E. Duflo, C. Hansen, W. Newey, and J. Robins. Double/debiased machine learning for treatment and structural parameters. The Econometrics Journal, 21(1):C1-C68, 2018.
H. A. Chipman, E. I. George, and R. E. McCulloch. BART: Bayesian additive regression trees. Annals of Applied Statistics, pages 266-298, 2010.
W. G. Cochran and D. B. Rubin. Controlling bias in observational studies: A review. Sankhyā: The Indian Journal of Statistics, Series A, pages 417-446, 1973.
R. H. Dehejia and S. Wahba. Causal effects in nonexperimental studies: Reevaluating the evaluation of training programs. Journal of the American Statistical Association, 94(448):1053-1062, 1999.
A. Dieng, Y. Liu, S. Roy, C. Rudin, and A. Volfovsky. Interpretable almost-exact matching for causal inference. Proceedings of Machine Learning Research (Proceedings of AISTATS), 89:2445, 2019.
V. Dorie, H. Chipman, R. McCulloch, A. Dadgar, R. C. Team, G. U. Draheim, M. Bosmans, C. Tournayre, M. Petch, R. de Lucena Valle, et al. Package 'dbarts'. 2019.
M. H. Farrell. Robust inference on average treatment effects with possibly more covariates than observations. Journal of Econometrics, 189(1):1-23, 2015.
J. H. Friedman.
Multivariate adaptive regression splines. The Annals of Statistics, pages 1-67, 1991.
J. Goldberger, G. E. Hinton, S. T. Roweis, and R. R. Salakhutdinov. Neighbourhood components analysis. In Advances in Neural Information Processing Systems, pages 513-520, 2005.
X. S. Gu and P. R. Rosenbaum. Comparison of multivariate matching methods: Structures, distances, and algorithms. Journal of Computational and Graphical Statistics, 2(4):405-420, 1993.
P. R. Hahn, J. S. Murray, and C. M. Carvalho. Bayesian regression tree models for causal inference: Regularization, confounding, and heterogeneous effects. Bayesian Analysis, 15(3), September 2020.
B. B. Hansen. The prognostic analogue of the propensity score. Biometrika, 95(2):481-488, 2008.
J. L. Hill. Bayesian nonparametric modeling for causal inference. Journal of Computational and Graphical Statistics, 20(1):217-240, 2011.
D. E. Ho, K. Imai, G. King, and E. A. Stuart. MatchIt: Nonparametric preprocessing for parametric causal inference. Journal of Statistical Software, 42(8):1-28, 2011.
S. M. Iacus, G. King, and G. Porro. Causal inference without balance checking: Coarsened exact matching. Political Analysis, 20(1):1-24, 2012.
G. W. Imbens. Nonparametric estimation of average treatment effects under exogeneity: A review. Review of Economics and Statistics, 86(1):4-29, 2004.
H. Jiang. Non-asymptotic uniform rates of consistency for k-nn regression. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 33, pages 3999-4006, 2019.
N. Kallus. A framework for optimal matching for causal inference. In A. Singh and J. Zhu, editors, Proceedings of the 20th International Conference on Artificial Intelligence and Statistics, volume 54 of Proceedings of Machine Learning Research, pages 372-381, Fort Lauderdale, FL, USA, 20-22 Apr 2017.
L.-Z. Kara, A. Laksaci, M. Rachdi, and P. Vieu. Data-driven kNN estimation in nonparametric functional data analysis.
Journal of Multivariate Analysis, 153:176 188, 2017. L. Keele and J. R. Zubizarreta. Optimal multilevel matching in clustered observational studies: A case study of the school voucher system in Chile. Journal of the American Statistical Association, 112(518):547 560, 2017. R. J. La Londe. Evaluating the Econometric Evaluations of Training Programs with Experimental Data. American Economic Review, 76(4):604 620, September 1986. M. Morucci, V. Orlandi, S. Roy, C. Rudin, and A. Volfovsky. Adaptive hyper-box matching for interpretable individualized treatment effect estimation. Conference on Uncertainty in Artificial Intelligence (UAI), 2020. M. Morucci, M. Noor-E-Alam, and C. Rudin. A robust approach to quantifying uncertainty in matching problems of causal inference. INFORMS Journal on Data Science, 2022. accepted. H. Parikh, C. Rudin, and A. Volfovsky. An application of matching after learning to stretch (MALTS) to the ACIC 2018 causal inference challenge data. Observational Studies, 5:118 130, 2019. H. Parikh, K. Hoffman, H. Sun, W. Ge, J. Jing, R. Amerineni, L. Liu, J. Sun, S. Zafar, A. Struck, et al. Why interpretable causal inference is important for high-stakes decision making for critically ill patients and how to do it. ar Xiv preprint ar Xiv:2203.04920, 2022. M. Resa and J. R. Zubizarreta. Evaluation of subset matching methods and forms of covariate balance. Statistics in Medicine, 2016. P. R. Rosenbaum. Imposing minimax and quantile constraints on optimal matching in observational studies. Journal of Computational and Graphical Statistics, 26(1), 2017. P. R. Rosenbaum and D. B. Rubin. The central role of the propensity score in observational studies for causal effects. Biometrika, 70(1):41 55, 1983. D. B. Rubin. Matching to remove bias in observational studies. Biometrics, pages 159 183, 1973a. D. B. Rubin. The use of matched sampling and regression adjustment to remove bias in observational studies. Biometrics, pages 185 203, 1973b. D. B. Rubin. 
Multivariate matching methods that are equal percent bias reducing, I: Some examples. Biometrics, pages 109–120, 1976.
D. B. Rubin. Causal inference using potential outcomes: Design, modeling, decisions. Journal of the American Statistical Association, 100:322–331, 2005.
C. J. Stone. Consistent nonparametric regression. The Annals of Statistics, pages 595–620, 1977.
E. A. Stuart. Matching methods for causal inference: A review and a look forward. Statistical Science, 25(1):1, 2010.
T. Wang, M. Morucci, M. U. Awan, Y. Liu, S. Roy, C. Rudin, and A. Volfovsky. FLAME: A fast large-scale almost matching exactly approach to causal inference. Journal of Machine Learning Research, 22(31):1–41, 2021.
K. Q. Weinberger and L. K. Saul. Distance metric learning for large margin nearest neighbor classification. Journal of Machine Learning Research, 10(2), 2009.
K. Q. Weinberger, J. Blitzer, and L. K. Saul. Distance metric learning for large margin nearest neighbor classification. In Advances in Neural Information Processing Systems, pages 1473–1480, 2006.
H. Xu and S. Mannor. Robustness and generalization. Machine Learning, 86(3):391–423, 2012.
Z. Zhao. Using matching to estimate treatment effects: Data requirements, matching metrics, and Monte Carlo evidence. The Review of Economics and Statistics, 86(1):91–107, 2004.
J. R. Zubizarreta. Using mixed integer programming for matching in an observational study of kidney failure after surgery. Journal of the American Statistical Association, 107(500):1360–1371, 2012.
J. R. Zubizarreta, R. D. Paredes, and P. R. Rosenbaum. Matching for balance, pairing for heterogeneity in an observational study of the effectiveness of for-profit and not-for-profit high schools in Chile. The Annals of Applied Statistics, 8(1):204–231, 2014.

Appendix A.

In this section we provide proofs for theorems and lemmas discussed in Section 5.

Proof (Lemma 3). If $(D_1, \dots, D_K)$ is the multinomially distributed random vector with parameters $d$ and $p_1, \dots, p_K$ then, by the Bretagnolle-Huber-Carol inequality, for any $\lambda > 0$,

$$P\left(\sum_{i=1}^{K}\left|\frac{D_i}{d} - p_i\right| \ge \lambda\right) \le 2^K \exp\left(-\frac{d\lambda^2}{2}\right).$$

Thus, for our case, we can consider $N_i$ corresponding to the set of indices of units in sample $S^{(t')}_n$ such that their $x$'s are contained in the partition $C_i$ as in Theorem 2. Hence, by the Bretagnolle-Huber-Carol inequality, we know that, with probability at least $1 - E$,

$$\sum_{i=1}^{K}\left|\frac{|N_i|}{n^{(t')}} - \mu(C_i)\right| \le \sqrt{\frac{2K\ln(2) + 2\ln(1/E)}{n^{(t')}}}.$$

Now, for some arbitrary $t' \in T$ let us consider $\left|L_{pop}(\mathcal{M}(S_n), Z^{(t')}) - L_{emp}(\mathcal{M}(S_n), S^{(t')}_n)\right|$. For brevity, write

$$\bar{\ell}_{ij} := E_{z_1,z_2}\left[\mathrm{loss}(\mathcal{M}(S_n), z_1 = (x'_1, y'_1, t'_1), z_2 = (x'_2, y'_2, t'_2)) \mid x'_1 \in C_i,\ x'_2 \in C_j\right].$$

We know that

$$\begin{aligned}
&\left|L_{pop}(\mathcal{M}(S_n), Z^{(t')}) - L_{emp}(\mathcal{M}(S_n), S^{(t')}_n)\right| \\
&= \left|\sum_{i,j=1}^{K} \bar{\ell}_{ij}\,\mu(C_i)\mu(C_j) - \frac{1}{(n^{(t')})^2}\sum_{s_1,s_2 \in S^{(t')}_n} \mathrm{loss}(\mathcal{M}(S_n), s_1, s_2)\right| \\
&\le \left|\sum_{i,j=1}^{K} \bar{\ell}_{ij}\,\mu(C_i)\mu(C_j) - \sum_{i,j=1}^{K} \bar{\ell}_{ij}\,\mu(C_i)\frac{|N_j|}{n^{(t')}}\right|
 + \left|\sum_{i,j=1}^{K} \bar{\ell}_{ij}\,\mu(C_i)\frac{|N_j|}{n^{(t')}} - \sum_{i,j=1}^{K} \bar{\ell}_{ij}\,\frac{|N_i|}{n^{(t')}}\frac{|N_j|}{n^{(t')}}\right| \\
&\quad + \left|\sum_{i,j=1}^{K} \bar{\ell}_{ij}\,\frac{|N_i|}{n^{(t')}}\frac{|N_j|}{n^{(t')}} - \frac{1}{(n^{(t')})^2}\sum_{s_1,s_2 \in S^{(t')}_n} \mathrm{loss}(\mathcal{M}(S_n), s_1, s_2)\right| \\
&= \left|\sum_{i,j=1}^{K} \bar{\ell}_{ij}\,\mu(C_i)\left(\mu(C_j) - \frac{|N_j|}{n^{(t')}}\right)\right|
 + \left|\sum_{i,j=1}^{K} \bar{\ell}_{ij}\,\frac{|N_j|}{n^{(t')}}\left(\mu(C_i) - \frac{|N_i|}{n^{(t')}}\right)\right| \\
&\quad + \left|\sum_{i,j=1}^{K} \bar{\ell}_{ij}\,\frac{|N_i|}{n^{(t')}}\frac{|N_j|}{n^{(t')}} - \frac{1}{(n^{(t')})^2}\sum_{s_1,s_2 \in S^{(t')}_n} \mathrm{loss}(\mathcal{M}(S_n), s_1, s_2)\right| \\
&\le 2B\sum_{i=1}^{K}\left|\frac{|N_i|}{n^{(t')}} - \mu(C_i)\right|
 + \left|\sum_{i,j=1}^{K} \bar{\ell}_{ij}\,\frac{|N_i|}{n^{(t')}}\frac{|N_j|}{n^{(t')}} - \frac{1}{(n^{(t')})^2}\sum_{s_1,s_2 \in S^{(t')}_n} \mathrm{loss}(\mathcal{M}(S_n), s_1, s_2)\right|,
\end{aligned}$$

where $B$ is $\max_{z_1,z_2} \mathrm{loss}(\mathcal{M}(S_n), z_1, z_2)$. Recall, $\mathcal{M}(S_n)$ is $(K, \epsilon(\cdot))$-multi-robust with probability $p_{mr}(\epsilon)$. Thus,

$$\begin{aligned}
&P\left(\left|\sum_{i,j=1}^{K} \bar{\ell}_{ij}\,\frac{|N_i|}{n^{(t')}}\frac{|N_j|}{n^{(t')}} - \frac{1}{(n^{(t')})^2}\sum_{s_1,s_2 \in S^{(t')}_n} \mathrm{loss}(\mathcal{M}(S_n), s_1, s_2)\right| \le \epsilon(S^{(t')}_n)\right) \\
&\ge P\left(\forall\, i,j:\ \left|\bar{\ell}_{ij} - \frac{1}{|N_i||N_j|}\sum_{(s_1,s_2) \in (C_i \times C_j)} \mathrm{loss}(\mathcal{M}(S_n), s_1, s_2)\right| \le \epsilon(S^{(t')}_n)/K^2\right) \\
&\ge \left(p_{mr}(\epsilon/K^2)\right)^{K^2}.
\end{aligned}$$

Hence, by combining the above results, we can conclude that for all $t' \in T$, with probability at least $(1 - E)\left(p_{mr}(\epsilon/K^2)\right)^{K^2}$, we have

$$\left|L_{pop}(\mathcal{M}(S_n), Z^{(t')}) - L_{emp}(\mathcal{M}(S_n), S^{(t')}_n)\right| \le \epsilon(S^{(t')}_n) + 2B\sqrt{\frac{2K\ln(2) + 2\ln(1/E)}{n^{(t')}}}.$$

Lemma 5 (Used for proof of Theorem 1) Let $\{S_n\}_{n=1}^{\infty}$ be a sequence of nested datasets, each of which includes $n$ i.i.d. samples from $\mu(Z)$, $n = 1, \dots, \infty$. Given a smooth distance metric $d_M$, covariate vector $x$, and $\alpha > 0$, if there exists a small enough value of $a$ and a large enough value of $N$ such that $K^{(t')}_n(x) = \{z_k : d_M(x_k, x) < a,\ t_k = t',\ z_k \in S_n\}$ is non-empty and $\alpha > \delta_{d_M}(a)$ for all $n \ge N$ then,

$$P\left(\left|E[Y(t') \mid x] - \widehat{Y}^{(t')}_x\right| > \alpha\right) \le \exp\left(-\frac{|K^{(t')}_n(x)|\,(\alpha - \delta_{d_M}(a))^2}{2C_y}\right),$$

where $\delta_{d_M}(a)$ is the bound from Definition 1 (definition of smooth distance metric). As the above choice of $a$ holds for all $n \ge N$, we have that the bound goes to zero as $n \to \infty$.

Proof (Lemma 5). $K^{(t')}_n(x)$ is a matched group of nearest neighbors $z_k$ of unit $z$ such that $d_M(x, x_k) < a$ and treatment indicator $t_k = t'$, i.e., $K^{(t')}_n(x) = \{z_k : d_M(x, x_k) < a \text{ and } t_k = t'\}$. We estimate the conditional average potential outcome for treatment choice $t'$ and $X = x$ as $\frac{1}{|K^{(t')}_n(x)|}\sum_{z_k \in K^{(t')}_n(x)} Y_k$. If $d_M$ is a smooth distance metric then, as all the units in $K^{(t')}_n(x)$ have distance to $x$ less than $a$, for every $z_k$ in $K^{(t')}_n(x)$ we have $\left|E[Y(t') \mid X = x] - E[Y(t') \mid X = x_k]\right| < \delta_{d_M}(a)$.
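Consider $\alpha$ such that $\delta_{d_M}(a) < \alpha$ and $\alpha \le C_y$, then

$$\begin{aligned}
&P\left(\left|E[Y(t') \mid X = x] - \frac{1}{|K^{(t')}_n(x)|}\sum_{z_k \in K^{(t')}_n(x)} Y_k\right| > \alpha\right) \\
&= P\left(\left|\frac{1}{|K^{(t')}_n(x)|}\sum_{z_k \in K^{(t')}_n(x)} \left(E[Y(t') \mid X = x] - Y_k\right)\right| > \alpha\right) \\
&= P\left(\left|\sum_{z_k \in K^{(t')}_n(x)} \left(E[Y(t') \mid X = x] - Y_k\right)\right| > \alpha\,|K^{(t')}_n(x)|\right) \\
&= P\left(\left|\sum_{z_k \in K^{(t')}_n(x)} \left(E[Y(t') \mid X = x] - E[Y(t') \mid X = X_k] + E[Y(t') \mid X = X_k] - Y_k\right)\right| > |K^{(t')}_n(x)|\,\alpha\right),
\end{aligned}$$

and using the triangle inequality,

$$\le P\left(\left|\sum_{z_k \in K^{(t')}_n(x)} \left(E[Y(t') \mid X = x] - E[Y(t') \mid X = X_k]\right)\right| + \left|\sum_{z_k \in K^{(t')}_n(x)} \left(E[Y(t') \mid X = X_k] - Y_k\right)\right| > |K^{(t')}_n(x)|\,\alpha\right).$$

By the definition of smooth distance metric,

$$\begin{aligned}
&\le P\left(\sum_{z_k \in K^{(t')}_n(x)} \delta_{d_M}(a) + \left|\sum_{z_k \in K^{(t')}_n(x)} \left(E[Y(t') \mid X = X_k] - Y_k\right)\right| > |K^{(t')}_n(x)|\,\alpha\right) \\
&= P\left(\left|\sum_{z_k \in K^{(t')}_n(x)} \left(E[Y(t') \mid X = X_k] - Y_k\right)\right| > |K^{(t')}_n(x)|\,(\alpha - \delta_{d_M}(a))\right),
\end{aligned}$$

and by Hoeffding's inequality,

$$P\left(\left|E[Y(t') \mid X = x] - \frac{1}{|K^{(t')}_n(x)|}\sum_{z_k \in K^{(t')}_n(x)} Y_k\right| > \alpha\right) \le \exp\left(-\frac{|K^{(t')}_n(x)|\,(\alpha - \delta_{d_M}(a))^2}{2C_y}\right).$$

As $n \to \infty$, for a constant $a$ (and hence constant $\delta_{d_M}(a)$), a constant $\alpha$, and letting the number of units matched to the target unit go to infinity, $|K^{(t')}_n(x)| \to \infty$, we have $\left|E[Y(t') \mid X = x] - \frac{1}{|K^{(t')}_n(x)|}\sum_{z_k \in K^{(t')}_n(x)} Y_k\right| \le \alpha$ with probability approaching 1.

Lemma 6 (Also used for proof of Theorem 1) If we can estimate the conditional average potential outcomes using a finite sample $S_n \overset{iid}{\sim} \mu(Z^n)$ such that for all $t'$, $E[Y(t') \mid X = x]$ and the estimate $\hat{Y}^{(t')}_x$ are farther than $\epsilon'$ with probability less than $\delta'(\epsilon', n)$ for any given $z = (x, y, t) \in Z$ and $t' \in T$, then the estimated conditional average treatment effect $\hat{\tau}(x)$ using a finite sample $S_n \overset{iid}{\sim} \mu(Z^n)$ and the true conditional average treatment effect $\tau(x)$ are farther than $\epsilon'$ with probability less than $2\delta'(\epsilon'/2, n)$:

$$\forall\, t' \in T,\ P_{S_n \sim \mu(Z^n)}\left(\left|\hat{Y}^{(t')}_x - E[Y(t') \mid X = x]\right| \ge \epsilon'\right) \le \delta'(\epsilon', n) \implies P_{S_n \sim \mu(Z^n)}\left(\left|\hat{\tau}(x) - \tau(x)\right| \ge \epsilon'\right) \le 2\delta'\left(\frac{\epsilon'}{2}, n\right).$$

Proof (Lemma 6). We are given in the statement that for any $\epsilon' > 0$, we can find a $\delta'(\epsilon', n)$ such that we can estimate outcomes well, i.e., $\forall z \in Z$, $\forall t' \in T$, $P_{S_n \sim \mu(Z^n)}\left(\left|\hat{Y}^{(t')}_x - E[Y(t') \mid X = x]\right| \ge \epsilon'\right) \le \delta'(\epsilon', n)$. We can further deduce from the union bound that

$$P\left(\bigcup_{t' \in T}\left\{\left|\hat{Y}^{(t')}_x - E[Y(t') \mid X = x]\right| \ge \epsilon'\right\}\right) \le |T|\,\delta'(\epsilon', n). \quad (9)$$

By the triangle inequality, we also know that

$$\left|\sum_{t' \in T}\left(\hat{Y}^{(t')}_x - E[Y(t') \mid X = x]\right)\right| \le \sum_{t' \in T}\left|\hat{Y}^{(t')}_x - E[Y(t') \mid X = x]\right|. \quad (10)$$

From Equation 9, we have

$$P\left(\sum_{t' \in T}\left|\hat{Y}^{(t')}_x - E[Y(t') \mid X = x]\right| \ge |T|\,\epsilon'\right) \le |T|\,\delta'(\epsilon', n).$$

Applying the triangle inequality from Equation 10,

$$P\left(\left|\sum_{t' \in T}\left(\hat{Y}^{(t')}_x - E[Y(t') \mid X = x]\right)\right| \ge |T|\,\epsilon'\right) \le |T|\,\delta'(\epsilon', n).$$

Considering the case where $T = \{0, 1\}$,

$$P\left(\left|\hat{\tau}(x) - \tau(x)\right| \ge 2\epsilon'\right) \le 2\delta'(\epsilon', n).$$

Hence, we can conclude that

$$P\left(\left|\hat{\tau}(x) - \tau(x)\right| \ge \epsilon'\right) \le 2\delta'\left(\frac{\epsilon'}{2}, n\right).$$

Proof (Theorem 1). The proof of Theorem 1 follows directly by substituting the result of Lemma 5 into Lemma 6.

Proof (Theorem 4). By Theorem 2, we know that the distance metric $\mathcal{M}(\cdot)$ learned using MALTS is $(N(\gamma, \mathcal{X}, \|\cdot\|_2), \beta)$-multirobust with high probability (more than one minus the exponential term stated in Theorem 2). Also, inferring from Lemma 3, for any arbitrary $t' \in T$ and $E > 0$ we have that, with a probability bound depending on $\beta^2$ and $(\rho^{(t')}_\gamma)^2$ with respect to the random draw of data to form $S_n$,

$$\left|L_{pop}(\mathcal{M}(S_n), Z^{(t')}) - L_{emp}(\mathcal{M}(S_n), S^{(t')}_n)\right| \le \beta + 2B\sqrt{\frac{2K\ln(2) + 2\ln(1/E)}{n^{(t')}}}.$$

Let $\rho_\gamma = \min_{t'} \rho^{(t')}_\gamma$. Then, summing over all possible $t' \in T$, the corresponding bound on the summed differences (with leading term $2|T|\beta$ plus the sum over $t'$ of the square-root terms above) holds with a probability bound of the same form with respect to the random draw of data to form $S_n$.

$\gamma$ in Theorem 2 was arbitrary, allowing us to take it to 0 in such a way that $K$ increases at a rate smaller than $\min_{t'} n^{(t')}$ increases, and $\rho_\gamma$ increases to $\infty$ at a rate faster than or equal to $O(n)$ as $n$ approaches $\infty$. Thus we can reduce $\beta^2$ to 0 at a rate slower than $1/\rho_\gamma^2$. $E$ was also set arbitrarily, allowing us to take it to 0 slowly enough such that as $n \to \infty$, each of the $n^{(t')} \to \infty$, and we thus have

$$\left|L_{pop}(\mathcal{M}(S_{tr}), Z^{(t')}) - L_{emp}(\mathcal{M}(S_{tr}), S^{(t')}_{tr})\right| \to 0.$$

Appendix B.

Limited Overlap and Performance

We use the DGP described in Section 6.1.1 with 2 relevant continuous covariates and no irrelevant covariates for the limited overlap experiments. Further, we set the parameters of the DGP as follows: $\mu = 1$, $\Sigma = 1.5$, $\phi = 0.5$, and $c = 2$.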
We performed experiments on the overlap by changing the standard deviation of the noise term $\epsilon_{treat}$ from 0.001 to 100 in the treatment assignment equation of the DGP and measured CATE estimation error for MALTS in comparison with other methods for each of the scenarios. A lower variance leads to small overlap, i.e., a large standardized difference of means, whereas a large variance creates a small standardized difference of means and high overlap. Figure 10 shows the relationship between $\epsilon_{treat}$ and the standardized difference of means between the covariate sets of treated and control units. Figure 11 shows the performance of MALTS in comparison with other methods in predicting CATEs for multiple dataset sizes $n \in \{500, 2000, 4000\}$ with $p = 20$, and different levels of overlap. MALTS' performance is largely insensitive to limited overlap; however, it deteriorates if the control and the treated units are very different. The primary reason MALTS performs worse than BART under almost no overlap is that matching methods like MALTS can be conceptualized as interpolation, unlike regression approaches, which explicitly model the potential outcome surfaces and are closer to extrapolation.

Sensitivity Analysis

We performed a sensitivity analysis of MALTS on a data generation setup with a constant treatment effect, two observed relevant covariates, and an unobserved confounder affecting the probability distributions of the outcome as well as the choice of treatment. The unobserved confounder has a linear relationship with the outcome, with coefficient equal to the sensitivity parameter of the outcome ($\gamma_Y$), and with the choice of treatment, with coefficient equal to the sensitivity parameter of the treatment ($\gamma_T$).
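As a hedged illustration, the data generation process specified by the equations in this subsection can be simulated in a few lines. This is a minimal sketch, not the authors' code: the function names (`simulate`, `sample_ate`, `naive_difference`) are hypothetical, the offset inside the logistic link is assumed to be $-2$, and the naive difference in means is included only to show why adjustment for the observed covariates is needed. It is not the MALTS estimator.

```python
import math
import random


def expit(v):
    """Logistic function, written to avoid overflow for large |v|."""
    if v >= 0:
        return 1.0 / (1.0 + math.exp(-v))
    e = math.exp(v)
    return e / (1.0 + e)


def simulate(n, gamma_y, gamma_t, seed=0):
    """Draw n units from the sensitivity-analysis DGP (constant effect of 1)."""
    rng = random.Random(seed)
    units = []
    for _ in range(n):
        # Two observed covariates and one unobserved confounder, all N(0, 1).
        x1, x2, conf = (rng.gauss(0.0, 1.0) for _ in range(3))
        eps0, eps1 = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
        y0 = x1 + x2 + gamma_y * conf + eps0
        y1 = x1 + x2 + gamma_y * conf + 1.0 + eps1
        # Assumption: the offset inside expit is -2.
        p_treat = expit(x1 + x2 + gamma_t * conf - 2.0)
        t = 1 if rng.random() < p_treat else 0
        # The confounder is deliberately not stored: it is unobserved.
        # y0 and y1 are kept only so the simulation can be checked.
        units.append({"x1": x1, "x2": x2, "t": t,
                      "y": y1 if t == 1 else y0, "y0": y0, "y1": y1})
    return units


def mean(values):
    values = list(values)
    return sum(values) / len(values)


def sample_ate(units):
    """Average of the unit-level effects y1 - y0 (known only in simulation)."""
    return mean(unit["y1"] - unit["y0"] for unit in units)


def naive_difference(units):
    """Unadjusted difference in observed mean outcomes (confounded)."""
    treated = [unit["y"] for unit in units if unit["t"] == 1]
    control = [unit["y"] for unit in units if unit["t"] == 0]
    return mean(treated) - mean(control)


if __name__ == "__main__":
    data = simulate(5000, gamma_y=0.5, gamma_t=0.5, seed=42)
    print("sample ATE:", round(sample_ate(data), 3))
    print("naive difference in means:", round(naive_difference(data), 3))
```

Sweeping `gamma_y` and `gamma_t` over a grid and replacing `naive_difference` with a matching estimator would give the kind of surface summarized in Figure 12.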
$$\begin{aligned}
x_{i,1},\, x_{i,2},\, u_i &\overset{iid}{\sim} N(0, 1) \\
\epsilon_{i,0},\, \epsilon_{i,1} &\overset{iid}{\sim} N(0, 1) \\
y^{(0)}_i &= x_{i,1} + x_{i,2} + \gamma_Y u_i + \epsilon_{i,0} \\
y^{(1)}_i &= x_{i,1} + x_{i,2} + \gamma_Y u_i + 1 + \epsilon_{i,1} \\
t_i &= \mathrm{Bernoulli}\left(\mathrm{expit}\left(x_{i,1} + x_{i,2} + \gamma_T u_i - 2\right)\right) \\
y_i &= t_i\, y^{(1)}_i + (1 - t_i)\, y^{(0)}_i
\end{aligned}$$

Figure 12 shows the contour plot of ATE estimates produced by MALTS as we change the sensitivity parameters in the data generation process. Here, the true ATE is equal to 1. The plot indicates that as long as the unmeasured confounder is no more than approximately half as important as either of the two observed covariates ($\gamma_T$ or $\gamma_Y$ is below 0.5), MALTS' performance is stable. At the extreme, which is when the unobserved confounder is as important as the total of the two observed covariates ($\gamma_T$ or $\gamma_Y$ is 2), the performance (expectedly) degrades.

Appendix C.

In this section, we discuss our implementation of existing causal inference methods: genmatch, propensity score matching, BART, causal forest, difference of random forest, prognostic score matching, and FLAME. In Section 6, we compare the performance of each of these methods with MALTS.

We used MatchIt's implementation of genmatch and propensity score matching as it is commonly used by empiricists (Ho et al. 2011). We allowed matching with replacement for creating match groups and estimating CATEs. Because MatchIt returns match groups and CATE estimates only for treated units (and not control units), we estimated CATEs for control units by flipping the treatment indicators and estimating negative CATEs, since flipping the indicator reverses the sign of the estimated effect. We merged the CATE estimates for the treated units and control units to get the CATE estimates for every unit in the dataset.

We used the causal forest algorithm as implemented in the grf package in R.
The settings for causal forest were the defaults chosen by the grf developers, with the number of trees equal to 2000 and $\sqrt{p} + 20$ variables tried for each split. We performed the same 5-fold CATE estimation procedure for causal forest, analogous to the one used for estimating CATEs using MALTS. We estimated CATEs for both the treated and control units in each estimation set.

We used Vincent Dorie's R implementation of BART (Dorie et al. 2019). We performed the same 5-fold CATE estimation using BART that we used for MALTS. For each of the $\eta$ folds, we trained two BART models on the training set: one to learn the response function for the potential outcome under control and the other to learn the response function for the potential outcome under treatment. The CATEs were estimated by taking the difference of the estimated treated and control response functions for the units in the estimation set.

We also implemented a 5-fold FLAME CATE estimation procedure analogous to the one used by MALTS.

Lastly, we implemented 5-fold prognostic score matching using a random forest approach to model the prognostic score function. We fit a model for control units and a model for treated units using the data in the training set. To estimate the CATE for a treated unit in the estimation set, we found the k-nearest neighbors in the control set with a similar estimated prognostic score. Analogously, we estimated the CATEs for the control units in the estimation set using the k-nearest treated units, with similarity measured using the prognostic score.

Appendix D.

In this section, we show expanded matched groups (including all treatment and control units, not just control units) using propensity score matching, prognostic score matching and MALTS for unit ID 1 from Table 5.
MALTS and prognostic score matching are implemented as described in Appendix C with $K = 10$ nearest neighbors, where $K$ was selected by cross-validation for MALTS and used also for the other methods. Propensity score matching was implemented using the MatchIt package and operationalized with 10 nearest neighbor matching. Here, MALTS produces a high quality matched group, as shown in Table 4. While MALTS matches units based on a learned distance metric over the covariate space, prognostic score matching matches units based on a single prognostic value and propensity score matching matches units based on a single propensity value. We note that there is only one matched unit in common between the matched groups from prognostic score matching and MALTS (unit 116) and there are two units in common between the matched groups from propensity score matching and MALTS (units 330 and 416); the matched groups are almost entirely different between the three methods.

Appendix E.

In this section, we study the computation time required by MALTS to learn an optimal distance metric and estimate CATEs using matching. We study this by increasing the number of covariates from 8 to 136 (where the number of relevant covariates is 8 and the others are irrelevant) while keeping the number of samples constant at 2048. We compare the performance of MALTS for unrestricted stretch matrices (referred to as full M) with the case where the stretch matrix is restricted to diagonal matrices (referred to as diagonal M). Figure 13 shows that the difference in computational time between the two cases increases approximately quadratically as the number of covariates increases. This is because inverting a full M is $O(p^2)$ more costly than inverting a diagonal M.

Table 5: Example Matched Group using (a) our approach, (b) prognostic score (Hansen 2008), and (c) propensity score matching (Rosenbaum and Rubin 1983) for a query unit in the Lalonde dataset (top rows).
Our approach matched closely on almost all covariates such as age, education, marital status, whether the person had an academic degree, and income in 1975. In contrast, prognostic and propensity scores did not match closely on factors such as education, age, marital status and income. Bold is used to denote disagreement between the query unit and its matched group.

Column groups: Treatment (Treated); Covariates (Age through Income-1975); Outcome (Income-1978).

| Unit ID | Treated | Age | Education | Black | Hispanic | Married | No-Degree | Income-1975 | Income-1978 |
|---|---|---|---|---|---|---|---|---|---|
| Query: 1 | Yes | 22 | 9 | No | Yes | No | Yes | $0 | $3596 |
| (a) Our Approach (MALTS) | | | | | | | | | |
| 94 | Yes | 23 | 8 | No | Yes | No | Yes | $0 | $3881 |
| 330 | No | 22 | 8 | No | Yes | No | Yes | $0 | $9921 |
| 299 | No | 22 | 9 | Yes | No | No | Yes | $0 | $0 |
| 5 | Yes | 22 | 9 | Yes | No | No | Yes | $0 | $4056 |
| 82 | Yes | 21 | 9 | Yes | No | No | Yes | $0 | $0 |
| 416 | No | 22 | 9 | Yes | No | No | Yes | $0 | $12898 |
| 333 | No | 21 | 9 | Yes | No | No | Yes | $0 | $3343 |
| 292 | Yes | 20 | 9 | Yes | No | No | Yes | $0 | $8882 |
| 17 | Yes | 23 | 10 | Yes | No | No | Yes | $0 | $7693 |
| 116 | Yes | 24 | 10 | Yes | No | No | Yes | $0 | $0 |
| (b) Prognostic Scores | | | | | | | | | |
| 154 | Yes | 22 | 10 | Yes | No | No | Yes | $1071 | $7315 |
| 56 | Yes | 30 | 11 | Yes | No | Yes | Yes | $0 | $591 |
| 100 | Yes | 17 | 10 | Yes | No | No | Yes | $0 | $0 |
| 109 | Yes | 18 | 9 | Yes | No | No | Yes | $0 | $4483 |
| 141 | Yes | 25 | 8 | Yes | No | No | Yes | $37432 | $2347 |
| 286 | Yes | 23 | 12 | No | Yes | No | No | $1117 | $559 |
| 338 | No | 44 | 9 | Yes | No | No | Yes | $0 | $9722 |
| 340 | No | 22 | 12 | Yes | No | No | No | $532 | $1333 |
| 355 | No | 18 | 10 | No | Yes | No | Yes | $0 | $1859 |
| 116 | Yes | 24 | 10 | Yes | No | No | Yes | $0 | $0 |
| (c) Propensity Scores | | | | | | | | | |
| 416 | No | 22 | 9 | Yes | No | No | Yes | $0 | $12898 |
| 451 | No | 22 | 8 | Yes | No | No | Yes | $0 | $1391 |
| 330 | No | 22 | 8 | No | Yes | No | Yes | $0 | $9921 |
| 407 | No | 20 | 12 | Yes | No | No | No | $1371 | $20893 |
| 626 | No | 18 | 10 | Yes | No | No | Yes | $2682 | $0 |
| 774 | No | 21 | 13 | No | No | No | No | $693 | $2660 |
| 402 | No | 22 | 11 | Yes | No | Yes | Yes | $0 | $1698 |
| 925 | No | 21 | 12 | Yes | No | No | No | $716 | $22166 |
| 879 | No | 22 | 12 | Yes | No | Yes | No | $0 | $665 |
| 788 | No | 22 | 12 | Yes | No | Yes | No | $0 | $0 |

(a) Distance metric learned on Lalonde data across 250 folds (50 repeats and 5 splits within each repeat). Here, on average, education is stretched more than other variables, which means it is more important to match closely on education.
(b) Criteria for pruning low-quality matched groups with large diameter.
Figure 9: (a) Box-plot of distance metric stretch values corresponding to each covariate in Lalonde data learned over 5 folds and 20 repeats. (b) Criteria to prune low-quality matched groups with large diameter from Lalonde data.

Figure 10: Standardized difference of means between covariates of treated and control units decreases as std($\epsilon_{treat}$) increases. We increase the value of std($\epsilon_{treat}$) in the DGP for treatment allocation, which increases the overlap between the treated and control groups. We generate data with $p$ equal to 20 covariates and for values of $n \in \{500, 2000, 4000\}$.

Figure 11: (a) Trend plot comparison of MALTS' performance, measured as mean relative error for CATE estimation, under different levels of overlap, measured as a function of the standard deviation of $\epsilon_{treat}$ (the scale of noise in the treatment allocation process). Higher values of the standard deviation of $\epsilon_{treat}$ correspond to more overlap between the control and the treated groups. (b) Scatterplot comparing MALTS' performance, as mean CATE estimation error, under different levels of overlap. Overlap is measured as standardized difference of means for $n = 4000$. Larger values of standardized difference of means correspond to less overlap between the control and treated groups.

Figure 12: Sensitivity analysis contour plot of the ATE estimation using MALTS. The best performance is when there is no unmeasured confounding, which is when both sensitivity parameters are 0. When the sensitivity parameter for the outcome is 2, the unmeasured confounder is approximately as important as the total of the two known covariates and performance degrades.

Figure 13: Run time for MALTS when the distance metric is constrained to have diagonal M, compared with a distance metric where M is a full-rank positive semi-definite matrix.
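To make the distinction between a diagonal and a full stretch matrix in Figure 13 concrete, the following sketch computes a stretched distance and a nearest-neighbor matched group under each form. It is an illustration under stated assumptions rather than the MALTS implementation: the distance is assumed to take the form $\|M(x - z)\|_2$, the stretch values are fixed by hand instead of learned, and the function names (`stretched_distance`, `diagonal_distance`, `nearest_neighbors`) are hypothetical.

```python
import math


def stretched_distance(M, x, z):
    """Return ||M (x - z)||_2 for a square matrix M given as a list of rows."""
    diff = [a - b for a, b in zip(x, z)]
    stretched = [sum(m * d for m, d in zip(row, diff)) for row in M]
    return math.sqrt(sum(s * s for s in stretched))


def diagonal_distance(weights, x, z):
    """Same distance when M = diag(weights); one multiply per covariate."""
    return math.sqrt(sum((w * (a - b)) ** 2 for w, a, b in zip(weights, x, z)))


def nearest_neighbors(query, candidates, k, distance):
    """Indices of the k candidates closest to query under the given distance."""
    order = sorted(range(len(candidates)),
                   key=lambda i: distance(query, candidates[i]))
    return order[:k]


if __name__ == "__main__":
    query = [0.0, 0.0]
    candidates = [[0.1, 5.0], [1.0, 0.0], [0.2, -3.0]]

    # Hand-picked stretch: the first covariate matters, the second is ignored.
    weights = [2.0, 0.0]
    full_M = [[2.0, 0.0], [0.0, 0.0]]

    stretched = nearest_neighbors(
        query, candidates, 2, lambda q, c: diagonal_distance(weights, q, c))
    unstretched = nearest_neighbors(
        query, candidates, 2, lambda q, c: diagonal_distance([1.0, 1.0], q, c))
    print("matched group with stretch:", stretched)
    print("matched group without stretch:", unstretched)
    print("full and diagonal forms agree:",
          abs(stretched_distance(full_M, query, candidates[0])
              - diagonal_distance(weights, query, candidates[0])) < 1e-12)
```

In this sketch, one distance evaluation touches every entry of a full M but only the $p$ diagonal entries of a diagonal M, which is one simple way to see why the restricted form is cheaper per evaluation.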