Generalization Bounds and Representation Learning for Estimation of Potential Outcomes and Causal Effects
Fredrik D. Johansson, Uri Shalit, Nathan Kallus, David Sontag
Chalmers University of Technology; Technion, Israel Institute of Technology; Cornell Tech; Massachusetts Institute of Technology
Abstract
Practitioners in diverse fields such as healthcare, economics and education are eager to apply machine learning to improve decision making. The cost and impracticality of performing experiments and a recent monumental increase in electronic record keeping have brought attention to the problem of evaluating decisions based on non-experimental observational data. This is the setting of this work. In particular, we study estimation of individual-level causal effects, such as a single patient's response to alternative medication, from recorded contexts, decisions and outcomes. We give generalization bounds on the error in estimated effects based on distance measures between groups receiving different treatments, allowing for sample re-weighting. We provide conditions under which our bound is tight and show how it relates to results for unsupervised domain adaptation. Led by our theoretical results, we devise representation learning algorithms that minimize our bound by regularizing the representation's induced treatment group distance and that encourage sharing of information between treatment groups. We extend these algorithms to simultaneously learn a weighted representation to further reduce treatment group distances. Finally, an experimental evaluation on real and synthetic data shows the value of our proposed representation architecture and regularization scheme.
Evaluating intervention decisions is a key question in many diverse fields, including medicine, economics, and education. In medicine, an optimal choice of treatment for a patient in the intensive care unit may mean the difference between life and death. In public policy, job reforms have an impact on the unemployment rate and the economy of a nation. To evaluate such interventions, we must study their causal effect: the difference in an outcome of interest under alternative choices of intervention. Since only one option may be carried out at a time, any data to support such evaluations only reveals the outcome of the action taken and never the outcome of the action not taken, which remains an unknown counterfactual. To estimate causal effects, we must therefore infer what would have happened had we made another decision. Furthermore, to decide on personalized interventions, such as tailoring treatments to patients, we must understand individual-level causal effects, conditioned on the available information on an individual recorded prior to intervention.

In this work, we focus on estimating individual-level causal effects from non-experimental, observational data. An observational dataset consists of historical records of interventions, the contexts in which they were made, and the observed outcomes. Our running example is that of patients represented by their medical history, the medication they were prescribed, and the outcome of treatment, such as mortality. An individual-level effect measures the causal effect of medication choice, conditioned on what is known about the patient. Finally, though we know which interventions took place, the policy by which interventions were chosen in this data is typically unknown to us.

Working with observational data is our best bet when experiments such as randomized controlled trials (RCTs) are infeasible, impractical or prohibitively expensive.
While cheaper and easier to implement, observational studies come with new, fundamental difficulties. Perhaps the most challenging of these is confounding: the influence of variables that causally affect both the intervention and the outcome, which may introduce spurious, non-causal correlations between the two. For example, richer patients might both have more access to certain medications and have better outcomes regardless of medication, making such medications appear better than they really are. Similarly, job training might only be given to those motivated enough to seek it. Naïve estimates of causal effects may therefore be biased by subsuming the effect of confounding variables on the outcome. Here, we make the common assumption that confounding variables, such as wealth or motivation in the examples above, have been measured and can be adjusted for in our estimation. This, however, introduces another difficulty: contending with the systematic differences in such variables between different treatment groups. Moreover, if these groups only partially overlap in terms of variables that cause the outcome, consistent estimation of causal effects (estimates that converge asymptotically to the true effect) may not always be guaranteed.

Causal estimation from observational data has been studied extensively in the statistics, econometrics, and computer science literature, but until fairly recently it has focused on average effects in a population or on simple models of heterogeneity such as linear regression. With sights set on personalization based on rich data, more flexible models are required, and machine learning is increasingly considered for the task. When we can no longer make strong assumptions such as linearity or low dimensionality, new questions arise: How well will our models generalize? How should we regularize them? What assumptions are necessary for good performance guarantees or asymptotic consistency? What can be said when these assumptions are not met?
In this work, we begin to answer these questions.

Sample weighting plays an important role in methods for estimating both average and individual-level (conditional) effects (Rosenbaum and Rubin, 1983). At the heart of such approaches is the propensity score: the probability for a subject to receive treatment under the observed policy, given their characteristics. In practice, the propensity score is typically unknown.¹ Replacing the propensity score with estimates thereof is prone to introducing bias due to model misspecification, or variance due to small sample sizes. In this work, we

¹ This has some notable exceptions, such as advertising, in which an existing policy for serving ads was designed by and known to the advertiser (Swaminathan and Joachims, 2015; Lefortier et al., 2016).
Notation and terminology
Random variables are denoted with capital roman letters $A, B, C, \ldots$ and observations thereof with a corresponding indexed lower-case letter $a_i, b_i, c_i, \ldots$. The empirical density of a draw of $m$ samples from a density $p$ is denoted $\hat{p}^m$. Unless stated otherwise, all random variables are distributed according to a fixed distribution $p(A, B, \ldots)$. Expectations over a variable $X$ distributed according to $p(X)$ are denoted $\mathbb{E}_X[\cdot]$, and conditional expectations over $X$ given $Y$, distributed according to $p(X \mid Y)$, are denoted $\mathbb{E}_{X \mid Y}[\cdot \mid Y = t]$. When expectations are defined w.r.t. a density $q$ different from $p$, the notation $\mathbb{E}_{X \sim q}[\cdot]$ is used.

We introduce the problem of estimating conditional average treatment effects (CATE) from observational data. Throughout the paper, we adopt the running example of estimating the effect of a medical treatment on a patient. This informs our choice of terminology and notation and serves to give intuition for the mathematical quantities involved. However, the applicability of the theory and algorithms described is in no way limited to this application.

We consider a simple random sample with replacement of size $m$ from a population distributed according to $p$. Using the Neyman-Rubin potential outcome framework (Imbens and Rubin, 2015), we associate each unit $i = 1, \ldots, m$ with the following variables:

• An observed context $X_i \in \mathcal{X} \subseteq \mathbb{R}^d$, defined by all information observed about a patient before the choice of treatment is determined. These covariates may influence both the treatment choice and the outcome of an experiment. The context is represented as a $d$-dimensional real-valued vector throughout this paper.
• An observed treatment $T_i \in \mathcal{T} = \{0, 1\}$, which is an intervention performed in an observed context. Treatments are assumed to be binary variables throughout this paper, where $T_i = 1$ is usually referred to as "treatment" and $T_i = 0$ as "control."
• An outcome $Y_i \in \mathcal{Y} \subseteq \mathbb{R}$, measuring an aspect of interest of a patient, such as blood pressure or mortality, after the administration of a treatment, represented by a real-valued variable.
• Unobserved potential outcomes $Y_i(0), Y_i(1) \in \mathcal{Y} \subseteq \mathbb{R}$ that correspond to the outcomes that would have been observed for unit $i$ under treatments $T_i = 0$ and $T_i = 1$, respectively. We assume that $Y_i = Y_i(T_i)$ throughout the paper, capturing both consistency and non-interference, also known in conjunction as the stable unit treatment value assumption (SUTVA), which is key to the existence of potential outcomes and the relevance of this hypothetical construct to the actual data (Rubin, 2005).

[Figure 1 panels: (a) Outcomes and CATE; (b) Misspecified models; (c) Error and risk]
Figure 1: Illustrative example of bias in regression adjustment of expected potential outcomes $\mu_0(x), \mu_1(x)$ and CATE $\tau(x)$. In (a), we show the two potential outcomes and the two treatment groups $p_0, p_1$ in dashed blue and solid red lines, respectively, as well as samples of each group. In (b), we show the best linear models $\hat{m}(p_0), \hat{m}(p_1)$ of the potential outcome under treatment, $\mu_1(x)$, fit to the control group and the treated group, respectively. In (c), we illustrate the difference in weighted error (bias) for the model fit to the treated group, $\hat{m}(p_1)$, evaluated in the control group and the treated group.

We drop the subscript $i$ when dealing with the distribution of any single such random draw from this population. Note that since potential outcomes are unobserved, our data consist only of $(X_1, T_1, Y_1), \ldots, (X_m, T_m, Y_m)$, and $Y(0), Y(1)$ remain unobserved. These data are assumed to be sampled iid from some population distribution $p(X, T, Y)$. Following a long tradition, we refer to the conditional density $p(X \mid T = 0)$ as the control group and $p(X \mid T = 1)$ as the treatment group.
With slight abuse of terminology, we use these labels also in reference to the empirical quantities. For convenience, for $t \in \{0, 1\}$, we introduce the shorthands $p_t(X) := p(X \mid T = t)$ and $\hat{p}_t(X)$ analogously.

The potential outcome of the treatment not administered, $Y_i(1 - T_i)$, is an unobserved counterfactual, which is the key impediment to assessing the individual treatment effect, $Y_i(1) - Y_i(0)$. That we do not observe what would have happened if we did something differently is often termed the fundamental problem of causal inference. We are interested in estimating the following quantities.

Definition 1.
The expected potential outcome $\mu_t$, conditioned on $X = x$, is
$$\mu_t(x) := \mathbb{E}_{Y(t)}[Y(t) \mid X = x], \quad \text{for } t \in \{0, 1\}, \tag{1}$$
and the conditional average treatment effect (CATE) given a context $x$ is
$$\tau(x) := \mathbb{E}_{Y(0), Y(1) \mid X}[Y(1) - Y(0) \mid X = x] = \mu_1(x) - \mu_0(x). \tag{2}$$

The CATE is an object of key interest, as it is the best prediction (in mean squared error) of the effect on an individual given only their context variables. This has a variety of uses. One important use is the personalization of treatment, to make sure that the treatment is effective for the target. We illustrate $\mu_0$, $\mu_1$ and $\tau$ in Figure 1.

Aside: Individual treatment effect.
The conditional average treatment effect conditioned on everything that is known about a subject captures individual-level rather than population-level causal effects; it is therefore sometimes called the individual treatment effect (ITE) (Johansson et al., 2016; Shalit et al., 2017). While this terminology aligns with concepts used in machine learning, it overloads an existing definition in causal inference of the ITE as the difference $Y_i(1) - Y_i(0)$. The distinction between this ITE and CATE is that the ITE is unique to an individual and may not be described exactly by any set of features $X$. For this reason, we adopt the more precise label CATE for the feature-conditional treatment effect function $\tau(x)$.

Potential outcomes and causal effects are said to be identifiable if they can be uniquely computed from the distribution $p(X, T, Y)$ of the observed data. This is important because that distribution is the most one could ever hope to learn from an iid sample from $p$, and so anything that cannot be learned from the distribution cannot be learned from iid samples from it. Without additional assumptions, $\tau(x)$ may not be identifiable. To see this, consider an observational study where treatment was given only to subjects over the age of 30, and the control group consists only of subjects under the age of 30. If age has an effect on the outcome of interest, there is no guarantee that the effect of treatment can be separated from the effect of age using such data.

Sufficient conditions for identification have been studied both in very general settings (Pearl, 2009; Rubin, 2005) and in special cases that are commonly accepted in real-world applications. In parts of this work, we adopt and refer to the following assumptions.

Assumption 1 (Ignorability). The potential outcomes $Y(0), Y(1)$ and the treatment $T$ are conditionally independent given $X$,
$$Y(0), Y(1) \perp\!\!\!\perp T \mid X.$$
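To make confounding and the role of ignorability concrete, consider the following minimal simulation. It is a hypothetical example, not taken from the paper: a single binary covariate $X$ (think of the "wealth" example from the introduction) raises both the probability of treatment and the outcome, and the effect sizes, treatment probabilities and noise level are arbitrary assumptions. Both potential outcomes are generated for every unit, but only the factual outcome $Y = Y(T)$ is used for estimation, as in the setup above. Because $X$ is observed, ignorability holds by construction, and because both treatment probabilities (0.2 and 0.8) lie strictly between 0 and 1, the overlap condition stated next also holds. The naive difference in group means absorbs the effect of $X$, whereas comparing treated and control units within each stratum of $X$ recovers $\tau(x)$.

```python
import numpy as np

# Hypothetical data-generating process (all constants are illustrative assumptions).
rng = np.random.default_rng(1)
m = 200_000
X = rng.binomial(1, 0.5, size=m)                  # binary confounder, e.g. "wealthy"
T = rng.binomial(1, np.where(X == 1, 0.8, 0.2))   # wealthy units are treated more often
tau_true = np.where(X == 1, 1.0, 2.0)             # assumed CATE tau(x)
Y0 = 3.0 * X + rng.normal(size=m)                 # potential outcome under control
Y1 = Y0 + tau_true                                # potential outcome under treatment
Y = np.where(T == 1, Y1, Y0)                      # only the factual Y = Y(T) is observed

# Naive comparison of group means: biased, because X affects both T and Y.
naive = Y[T == 1].mean() - Y[T == 0].mean()

def cate_by_stratum(X, T, Y, x):
    """Difference in mean observed outcome between treated and control units with X == x."""
    s = X == x
    return Y[s & (T == 1)].mean() - Y[s & (T == 0)].mean()

cate_hat = {x: cate_by_stratum(X, T, Y, x) for x in (0, 1)}
print(f"naive difference in means: {naive:.2f}")
for x, v in cate_hat.items():
    print(f"stratified estimate of tau({x}): {v:.2f}")
```

Under these assumptions, the naive difference is close to 3, far from both stratum effects (2 for $X = 0$ and 1 for $X = 1$), because the treated group consists mostly of units with $X = 1$, whose outcomes are higher regardless of treatment.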
Ignorability is also known as the no unmeasured confounders assumption, as it holds only if all confounding variables, which affect both treatment and potential outcomes, are included in the observed variable $X$.

Assumption 2 (Overlap). In any context $x \in \mathcal{X}$, any treatment $t \in \{0, 1\}$ has a non-zero probability of being observed in the data,
$$\forall x \in \mathcal{X},\ t \in \{0, 1\}: \quad p(T = t \mid X = x) > 0.$$

Overlap is sufficient to ensure that knowledge of the outcomes in one treatment group may be generalized to the opposite group given access to a large enough sample size. Note that overlap only requires that the supports of the treatment groups are equal, not that they have similar densities. The degree to which treatment group densities are equal on this support is sometimes referred to as balance.

Under Assumptions 1 and 2 and SUTVA ($Y = Y(T)$), the conditional average treatment effect is identifiable, as can be seen by the simple identity
$$\begin{aligned} \tau(x) &= \mathbb{E}_{Y(0), Y(1) \mid X}[Y(1) - Y(0) \mid X = x] \\ &= \mathbb{E}_{Y \mid X}[Y \mid X = x, T = 1] - \mathbb{E}_{Y \mid X}[Y \mid X = x, T = 0]. \end{aligned}$$
Here, ignorability allows conditioning on $T = t$ inside $\mathbb{E}[Y(t) \mid X = x]$, SUTVA replaces $Y(t)$ with the observed $Y$ within that subgroup, and overlap ensures that both conditional expectations are defined.

Remark 1 (Plausibility of the identifiability assumptions). Both ignorability and SUTVA are assumptions that are fundamentally untestable given observational data alone. Despite this, they are often made in practice to justify subsequent analysis, or to make clear its potential limitations. A common heuristic motivation for ignorability is related to the richness of the variable $X$: the richer the data, the more likely they are to cover all confounding variables. It should be noted that even if all confounders are measured, adjusting for some of them may nonetheless introduce additional estimation bias or variance (Ding et al., 2017). Furthermore, the overlap assumption becomes increasingly difficult both to satisfy and to check as the dimensionality of $X$ grows (D'Amour et al., 2017).

Research into causal inference from observational data may be broadly grouped into two distinct categories: causal discovery and causal effect estimation.
In the former, the direction and presence of causal relationships between observed variables are unknown, and the task is to infer them from data (Geiger et al., 2015; Spirtes and Glymour, 1991; Hoyer et al., 2009; Eberhardt, 2008; Hyttinen et al., 2014; Silva et al., 2006). In the latter, which is the setting of this work, the structure of causal relationships is assumed to be known: confounders $X$ are causal of treatment $T$ and outcome $Y$; treatment $T$ is causal of $Y$ (unless the effect is 0); and any unmeasured variable is causal of only one of $X$, $T$, or $Y$. We are primarily interested in estimating the conditional average treatment effect of the treatment $T$ on the outcome $Y$ conditioned on the context $X$ (Johansson et al., 2016; Shalit et al., 2017; Athey and Imbens, 2016; Wager and Athey, 2017; Pearl, 2017; Abrevaya et al., 2015; Bertsimas et al., 2017; Green and Kern, 2010; Alaa and Schaar, 2018).

Estimation of causal effects from observational data is mostly performed under the assumptions of ignorability, treatment group overlap and consistency, as effects are otherwise generally unidentifiable. In this work, we are motivated by the ignorability assumption throughout, but give several results that hold in its absence. For work on CATE estimation without ignorability, see, e.g., Kallus and Zhou (2018); Kallus et al. (2018b,a); Louizos et al. (2017); Rosenbaum (2002). In contrast, we focus on the case where overlap is only partially satisfied. Lack of overlap is widely acknowledged as a problem (D'Amour et al., 2017), but estimation in this setting has received considerably less attention in the literature.

Under ignorability, CATE is given by the difference of two regressions: the expected outcome among treated units given covariates minus the expected outcome among control units given covariates. Estimating CATE by fitting two separate regressions is sometimes known as the "T-learner" (Künzel et al., 2017), where T stands for two.
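The T-learner just described can be sketched in a few lines. The sketch is illustrative only and is not the method proposed in this paper: it assumes ordinary least squares (via NumPy's `np.linalg.lstsq`) as the base regression for each treatment group, and a hypothetical linear data-generating process in which the CATE is $\tau(x) = 0.5 + 2x_2$ by construction. The helper names `fit_linear` and `t_learner` are hypothetical. The returned function is the plug-in difference of the two fitted regressions.

```python
import numpy as np

def fit_linear(X, y):
    """Ordinary least squares with an intercept; returns a prediction function."""
    A = np.column_stack([np.ones(len(X)), X])
    coef, *_ = np.linalg.lstsq(A, y, rcond=None)
    return lambda X_new: np.column_stack([np.ones(len(X_new)), X_new]) @ coef

def t_learner(X, T, Y, fit=fit_linear):
    """Fit one regression per treatment group and return x -> f1(x) - f0(x)."""
    f0 = fit(X[T == 0], Y[T == 0])   # trained on control units only
    f1 = fit(X[T == 1], Y[T == 1])   # trained on treated units only
    return lambda X_new: f1(X_new) - f0(X_new)

# Hypothetical confounded data: the treatment probability increases with x_1.
rng = np.random.default_rng(2)
m = 5_000
X = rng.normal(size=(m, 2))
T = rng.binomial(1, 1.0 / (1.0 + np.exp(-X[:, 0])))
mu0 = X @ np.array([1.0, -1.0])
mu1 = mu0 + 0.5 + 2.0 * X[:, 1]      # so tau(x) = 0.5 + 2 * x_2
Y = np.where(T == 1, mu1, mu0) + rng.normal(scale=0.5, size=m)

tau_hat = t_learner(X, T, Y)
print(tau_hat(np.array([[0.0, 0.0], [0.0, 1.0]])))
```

Because both regressions are well specified in this example, the two printed estimates should be close to 0.5 and 2.5. Each base model only ever sees its own treatment group, so with a misspecified base model its errors on the other group are not controlled by the fit.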
A simpler approach is learning one regression model from the covariates $X$ and the treatment $T$ to the outcome $Y$. CATE is then estimated by evaluating the difference between the predictions for $(X, T = 1)$ and $(X, T = 0)$. This approach is known as the "S-learner", where S stands for single. It has been argued that these methods are prone to compounding bias when applied in high-dimensional, small-sample settings that require significant regularization (Nie and Wager, 2017). Rather than estimating these regressions jointly or separately, a variety of work has studied directly estimating their difference, e.g., using trees (Athey and Imbens, 2016) and forests (Wager and Athey, 2017). Other work has studied meta-learners that combine different base learners for the underlying regression functions using methods other than simple differencing (Robins et al., 2000; Künzel et al., 2017; Nie and Wager, 2017). A large body of work has shown that under the assumption of having a well-specified (consistent) model for each regression, CATE estimation is asymptotically consistent, efficient, and/or asymptotically normal (Chernozhukov et al., 2017; Belloni et al., 2014).

Results proving asymptotic consistency provide little insight into the case of model misspecification: what if we do not know a parametric class of functions that can exactly fit the outcome in terms of high-dimensional baseline covariates and treatment? A line of research that clearly addresses model misspecification in the setting of standard supervised learning is agnostic learning. Agnostic machine learning focuses on finding best-in-class models and bounding the generalization error of any model, whether well-specified or not (Vapnik, 2013; Cortes et al., 2010). However, in the causal inference setting, under model misspecification, regression methods may suffer from additional bias when generalizing across populations subject to different treatments.
Our work addresses this issue by extending specification-agnostic generalization bounds to the CATE estimation problem. These bounds motivate our algorithms in the same way that standard supervised learning generalization bounds motivate structural risk minimization (Vapnik, 1998).

A complement to regression estimation of CATE is importance sampling, where the goal is to alleviate systematic differences in baseline covariates across treatment groups. This idea is used in propensity-score methods (Austin, 2011; Rosenbaum and Rubin, 1983), which use the observed treatment policy to re-weight samples for causal effect estimation, and more generally in re-weighted regression; see, e.g., Freedman and Berk (2008). Two major drawbacks of these methods are the need to estimate the propensity score when it is unknown, and the high variance introduced when propensities are so extreme that small estimation errors lead to division by near-zero values. To address this, others (Gretton et al., 2009; Kallus, 2016, 2017) have proposed learning sample weights to minimize a distributional distance between samples, but rely on specifying the data representation a priori, without regard for which aspects of the data matter for outcome prediction. We build on the importance sampling literature by developing theory and algorithms for weighted risk minimization for potential outcomes and CATE, both in a fixed representation and in one learned from data.

Our work on representation learning has conceptual ties to the idea of the prognostic score (Hansen, 2008). A prognostic score is any function $\Phi(X)$ of the context $X$ that Markov separates $Y(0)$ and $X$, such that $Y(0) \perp\!\!\!\perp X \mid \Phi(X)$. An extreme example is $\Phi(X) = X$. If $Y(0)$ follows a generalized linear model, then $\Phi(X) = \mathbb{E}[Y(0) \mid X]$ is also a prognostic score. The prognostic score is a form of dimension reduction which, under certain assumptions, is sufficient for causal inference.
Note that unlike the propensity score, the prognostic score might very well be vector-valued. One can view our approach as attempting to find approximate non-linear prognostic functions for both $Y(0)$ and $Y(1)$. We stress approximate, because in fact we trade off how well our learned representation $\Phi(\cdot)$ suffices to explain the potential outcomes against a balancing objective, which we show is important for good finite-sample estimation of CATE.

Our goal is to accurately estimate the CATE, $\tau$ in eq. (2), by an estimator $\hat{\tau}$ that depends on the data, without making parametric assumptions on the functional form of the true potential outcomes. For this reason, we adopt the risk minimization approach to learning and search for best-in-class hypotheses for $\tau$, rather than striving for point identification. This requires generalizing both the treated and control outcomes to the general population. In this section, we derive bounds on the risk of such an estimator in the following steps:

1. We define prediction risk with respect to potential outcomes and relate it to the expected error in estimates of CATE.
2. We show how the risk on the observed distribution is a biased estimate of the desired marginal risk and give sample re-weighting schemes that remove this bias.
3. We give bounds on the expected risk under imperfect re-weighting schemes by placing assumptions on the loss with respect to the true outcome.
4. We derive finite-sample versions of these bounds and combine them to form a single, fully observable bound on the risk in estimates of potential outcomes and CATE.

Our main generalization bound does not depend on treatment group overlap (Assumption 2). This diverges from most theoretical results for treatment effect estimation and provides intuition for when we can expect extrapolation to succeed approximately. Consistent non-parametric estimation, however, still requires overlap. In Section 5 we extend these results to the representation learning setting.
We study prediction of potential outcomes $Y(t)$, for $t \in \{0, 1\}$, using hypotheses $f_t \in \mathcal{H}$, for some class of hypotheses $\mathcal{H}$. These hypotheses are then combined to form estimates
$$\hat{\tau}(x) = f_1(x) - f_0(x).$$
We note that while this is not the only way to estimate $\tau$ (see, e.g., Robins et al., 2000; Künzel et al., 2017; Nie and Wager, 2017 for alternatives), it does allow us to leverage separate bounds on the risk of hypotheses $f_0, f_1$ with respect to the potential outcomes to then bound the risk of $\hat{\tau}$. We define the risk of hypotheses $f_0, f_1$ below.

Definition 2.
Let $L: \mathcal{Y} \times \mathcal{Y} \to \mathbb{R}_+$ be a loss function, such as the squared loss $L(y, y') = (y - y')^2$. The expected pointwise loss of a hypothesis $f_t$ at a point $x$ is
$$\ell_{f_t}(x) := \mathbb{E}_{Y \mid X}[L(Y(t), f_t(X)) \mid X = x]. \tag{3}$$
The marginal risk of a hypothesis $f_t$ w.r.t. a population $p$ is
$$R(f_t) := \mathbb{E}_X[\ell_{f_t}(X)], \tag{4}$$
and the factual risk of $f_t$ w.r.t. treatment group $p(X \mid T = t)$ is
$$R_t(f_t) := \mathbb{E}_{X \mid T}[\ell_{f_t}(X) \mid T = t]. \tag{5}$$
The counterfactual risk is $R_{1-t}(f_t) := \mathbb{E}_{X \mid T}[\ell_{f_t}(X) \mid T = 1 - t]$. The subscript on the risk $R_t$ indicates the treatment group over which it is evaluated. Note that the potential outcome against which the risk is evaluated is implicit in this notation: we only consider evaluating $f_t$ against $Y(t)$ or $\mu_t$, $\hat{\tau}$ against $\tau$, et cetera.

In most of this work, we restrict our attention to the squared error loss but note that our analysis generalizes to other convex loss functions, such as the mean absolute deviation. Similar to potential outcomes, we assess the quality of an estimate $\hat{\tau}$ of $\tau$ based on the expectation of the loss function $L$ over the marginal density of covariates, $p(X)$,
$$R(\hat{\tau}) := \mathbb{E}_X[L(\tau(X), \hat{\tau}(X))]. \tag{6}$$
The marginal risk $R(\hat{\tau})$ is the overall expected error in estimating CATE, taken over the entire population. However, $R(\hat{\tau})$ is not readily computable from data because neither $\tau(X)$ nor $p(X)$ is known. Moreover, we cannot form an empirical average estimate of it because, again, neither $\tau(X_i)$ nor $Y_i(1) - Y_i(0)$ is known. Instead, we will bound $R(\hat{\tau})$ from above. The main challenge of computing the marginal risk for hypotheses of potential outcomes is to quantify the counterfactual risk, and this is the primary concern of this work. Unlike $R(f_t)$, $R(\hat{\tau})$ does not depend on the noise (conditional variance) in $Y(t)$.
We adopt this convention for $R(\hat{\tau})$ as it coincides with the Precision in Estimation of Heterogeneous Effects (PEHE) (Hill, 2011). However, similarly to $L(Y(t), f_t)$, $L(\tau, \hat{\tau})$ is not observed over $p$, as $Y(1)$ is only observed for the treated group $p_1$, and $Y(0)$ only for the control group $p_0$. We return to this issue later, and begin instead by stating the following result relating the risk of $\hat{\tau}$ to those of $f_0$ and $f_1$, in the case where $L$ is the squared loss.

Lemma 1.
Let $L(y, y') = (y - y')^2$ be the squared loss function. For hypotheses $f_0, f_1$ of expected potential outcomes $\mu_0, \mu_1$, with marginal risks $R(f_0), R(f_1)$, and $\hat{\tau} = f_1 - f_0$, there is a constant $\sigma_Y$ (defined in the proof below) such that
$$R(\hat{\tau}) \leq 2\big(R(f_0) + R(f_1)\big) - 4\sigma_Y^2. \tag{7}$$
Similar results hold for metric losses, e.g., the absolute loss, $L(y, y') = |y - y'|$.

Proof. Due to the relaxed triangle inequality for squared differences, $(a - b)^2 \leq 2a^2 + 2b^2$,
$$\mathbb{E}_X[(\tau(X) - \hat{\tau}(X))^2] = \mathbb{E}_X[(\mu_1(X) - \mu_0(X) - f_1(X) + f_0(X))^2] \leq 2\Big(\mathbb{E}_X[(\mu_1(X) - f_1(X))^2] + \mathbb{E}_X[(\mu_0(X) - f_0(X))^2]\Big).$$
Now, by the standard bias-noise decomposition,
$$R(f_t) = \mathbb{E}_{X, Y}\big[(Y(t) - \mu_t(X))^2\big] + \mathbb{E}_X\big[(\mu_t(X) - f_t(X))^2\big].$$
Hence, for $t \in \{0, 1\}$, with $\sigma_{Y(t)}^2 = \mathbb{E}_{X, Y}\big[(Y(t) - \mu_t(X))^2\big]$,
$$\mathbb{E}_X[(\tau(X) - \hat{\tau}(X))^2] \leq 2\big(R(f_0) - \sigma_{Y(0)}^2 + R(f_1) - \sigma_{Y(1)}^2\big),$$
and with $\sigma_Y^2 = \min(\sigma_{Y(0)}^2, \sigma_{Y(1)}^2)$, we have our result.

Footnote (to the statement that $R(\hat{\tau})$ does not depend on the noise in $Y(t)$): this is because $\tau$ is defined in terms of expectations over the potential outcomes.

Lemma 1 implies that small errors in hypotheses of potential outcomes guarantee small errors in CATE. However, it is worth noting that this decomposition need not lead to the best achievable bound in all cases. Even when $Y(0)$ and $Y(1)$ are complex functions, $\tau(x)$ may be a simple function. In this work, we do not address this in our theoretical treatment, but we find that sharing parameters in the estimation of $Y(0)$ and $Y(1)$ leads to better results empirically. We proceed to study estimation of $R(f_0)$ and $R(f_1)$ separately, in terms of observable quantities, to later give a self-contained result for $R(\hat{\tau})$.

We now show how the marginal risk $R$ in potential outcomes and CATE may be computed by re-weighting the factual risk $R_t$. This approach is widely used within machine learning (Shimodaira, 2000; Cortes et al., 2010) and statistics (Rosenbaum and Rubin, 1983).
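Returning briefly to Lemma 1, its bound can be sanity-checked numerically. The sketch below is a hypothetical Monte Carlo example, not part of the paper: it assumes $X \sim \mathrm{Uniform}(-1, 1)$, true expected potential outcomes $\mu_0(x) = \sin(3x)$ and $\mu_1(x) = \sin(3x) + x^2$, Gaussian noise with standard deviations 0.5 and 0.3, and two deliberately misspecified hypotheses $f_0(x) = x$ and $f_1(x) = x + 0.3$. It estimates $R(f_0)$ and $R(f_1)$ under the squared loss from noisy potential outcomes, computes $R(\hat{\tau})$ from the noiseless $\mu_t$, and compares it with $2(R(f_0) + R(f_1)) - 4\sigma_Y^2$, where $\sigma_Y^2 = \min(\sigma_{Y(0)}^2, \sigma_{Y(1)}^2)$. Both potential outcomes are available here only because the data are simulated.

```python
import numpy as np

rng = np.random.default_rng(3)
m = 200_000
X = rng.uniform(-1.0, 1.0, size=m)
sigma0, sigma1 = 0.5, 0.3                  # assumed noise standard deviations of Y(0), Y(1)
mu0 = np.sin(3 * X)                        # assumed true mu_0(x)
mu1 = np.sin(3 * X) + X**2                 # assumed true mu_1(x), so tau(x) = x^2
Y0 = mu0 + rng.normal(scale=sigma0, size=m)
Y1 = mu1 + rng.normal(scale=sigma1, size=m)

f0 = X                                     # deliberately misspecified hypotheses,
f1 = X + 0.3                               # so tau_hat(x) = f1(x) - f0(x) = 0.3

R_f0 = np.mean((Y0 - f0) ** 2)             # Monte Carlo estimate of R(f_0), squared loss
R_f1 = np.mean((Y1 - f1) ** 2)             # Monte Carlo estimate of R(f_1)
R_tau = np.mean(((mu1 - mu0) - (f1 - f0)) ** 2)   # R(tau_hat); involves no outcome noise
sigma_Y2 = min(sigma0**2, sigma1**2)
bound = 2 * (R_f0 + R_f1) - 4 * sigma_Y2   # right-hand side of the Lemma 1 bound
print(f"R(tau_hat) = {R_tau:.3f} <= bound = {bound:.3f}")
```

Under these assumptions, $R(\hat{\tau})$ should be close to $\mathbb{E}[(X^2 - 0.3)^2] = 0.09$, well below the bound; most of the slack comes from the relaxed triangle inequality used in the proof.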
We note in passing that $R_t$ is not observed directly but can be readily estimated from an empirical sample. We return to this issue in later sections.

Due to the fundamental impossibility of observing counterfactual outcomes, each potential outcome $Y(t)$ is only observed for subjects who were given treatment $T = t$, distributed according to $p(X \mid T = t)$. As a result, unless treatment is assigned randomly (independently of $X$), $R_t(f_t)$ is in general a biased estimate of $R(f_t)$. In particular, a minimizer $f_t^*$ of $R_t(f_t)$ can be arbitrarily different from a minimizer of $R(f_t)$, depending on the difference between $p$ and $p_t$. This bias can have a large impact on treatment policies derived from $f_0$, $f_1$ and $\tau$. We illustrate this problem in Figure 1.

To reduce the bias described above, a classical solution is to re-weight the observable risk (Shimodaira, 2000) using a function $w: \mathcal{X} \to \mathbb{R}_+$ such that $\mathbb{E}_{X \mid T}[w(X) \mid T = t] = 1$,
$$R_t^w(f_t) := \mathbb{E}_{X \mid T}[w(X)\, \ell_{f_t}(X) \mid T = t], \tag{8}$$
where $w$ is chosen to skew the sample to mimic the distribution of the full population $p$. Many common choices of weights are based on the family of balancing scores (Rosenbaum and Rubin, 1983), of which the best known is the propensity score $\eta(X)$ with respect to $X$,
$$\eta(x) := p(T = 1 \mid X = x). \tag{9}$$
We can now state the following result.

Lemma 2.
For fixed $t \in \{0, 1\}$, under Assumption 2 (overlap), there exists a weighting function $w: \mathcal{X} \to \mathbb{R}_+$ such that $R_t^w(f_t) = R(f_t)$. In particular, this holds for
$$w(x) := \frac{p(T = t)}{(2t - 1)\eta(x) + 1 - t}. \tag{10}$$
More generally, it holds for (10) with $\eta(\phi(x))$ for any $\phi$ such that $\ell_f \perp\!\!\!\perp X \mid \phi(X)$. We refer to weights that satisfy (10) as balancing weights.

Proof. For any weighting function $w$,
$$R_t^w(f_t) = \int_{x \in \mathcal{X}} w(x)\, \ell_{f_t}(x)\, p_t(x)\, dx = \int_{x \in \mathcal{X}} \frac{p_t(x)}{p(x)}\, w(x)\, \ell_{f_t}(x)\, p(x)\, dx = R^{\frac{p_t}{p} w}(f_t).$$
With $w(x) = p(x)/p_t(x)$, the special case in (10) follows from Bayes' theorem, $p_t(x) = p(T = t \mid X = x)\, p(x) / p(T = t)$, and the definition of $\eta(x)$, which gives $p(T = t \mid X = x) = (2t - 1)\eta(x) + 1 - t$. Assumption 2 ensures that $p_t(x) > 0$ wherever $p(x) > 0$, so that this choice of $w$ is well defined. The more general statement follows from integration over $\phi$ and a change of variables.

Remark 2 (Violation of Assumptions 1 & 2). If overlap is only partially satisfied, Lemma 2 may still be applied to the expected risk over the subset of $\mathcal{X}$ for which overlap holds. More generally, weights may be chosen to emphasize regions where treatment groups are more similar (Li et al., 2018). If ignorability is violated, such as when unobserved confounders exist, a consistent estimator could in theory be obtained by letting the weights $w$ depend also on $Y$. However, such weights are not identified from observed data. Instead, a worst-case bound may be obtained by searching over a family of weighting functions of which these optimal weights are members. This is the topic of sensitivity analysis (see, e.g., Rosenbaum (2002) for a comprehensive overview).

In practice, $\eta(X)$ and balancing weights $w$ are typically unknown and have to be estimated from data. Moreover, even though weights based on $\eta$ are optimal in expectation, they can lead to poor finite-sample behavior (Swaminathan and Joachims, 2015). For these reasons, even if we had knowledge of $\eta(X)$, we are often interested in weighting functions that do not satisfy Lemma 2. Next, we give bounds on the difference between the re-weighted (factual) empirical risk, under arbitrary weightings, and the marginal risk.
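The balancing weights of Lemma 2 can be illustrated with a short simulation. This is a hypothetical sketch, not the paper's procedure: it assumes $X \sim \mathcal{N}(0, 1)$, a known propensity score $\eta(x) = 1/(1 + e^{-x})$, and an arbitrary function `pointwise_loss` (a hypothetical name) as a stand-in for $\ell_{f_t}$. It compares the marginal risk $R(f_t)$, the unweighted factual risk $R_t(f_t)$, and the re-weighted factual risk $R_t^w(f_t)$ from (8), with $w$ as in (10) and $p(T = t)$ replaced by its sample frequency.

```python
import numpy as np

rng = np.random.default_rng(4)
m = 400_000
X = rng.normal(size=m)
eta = 1.0 / (1.0 + np.exp(-X))            # assumed known propensity p(T = 1 | x)
T = rng.binomial(1, eta)

def pointwise_loss(x):
    """Hypothetical stand-in for ell_{f_t}(x); any non-negative function of x works."""
    return (x - 0.5) ** 2

t = 1
p_t = np.mean(T == t)                      # empirical estimate of p(T = t)
prop_t = (2 * t - 1) * eta + 1 - t         # p(T = t | x), the denominator in (10)
w = p_t / prop_t                           # balancing weights w(x)
in_t = T == t

marginal = np.mean(pointwise_loss(X))                   # R(f_t): average over all units
factual = np.mean(pointwise_loss(X[in_t]))              # R_t(f_t): biased for R(f_t)
weighted = np.mean(w[in_t] * pointwise_loss(X[in_t]))   # R_t^w(f_t), as in (8)
print(f"marginal={marginal:.3f}  factual={factual:.3f}  weighted={weighted:.3f}")
```

Under these assumptions the marginal risk is $\mathbb{E}[(X - 0.5)^2] = 1.25$. The weighted estimate should be close to it, while the unweighted factual risk is noticeably smaller, because treated units are concentrated at larger $x$, nearer the minimum of the stand-in loss. The weighted estimate is also noisier, since units with small $\eta(x)$ receive large weights, which is the finite-sample issue mentioned above.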
When overlap is not satisfied everywhere or the chosen weighting function $w$ is not perfectly balancing, the difference between the weighted factual risk $R_t^w(f_t)$ and the marginal risk $R(f_t)$ may be arbitrarily large without further assumptions on the potential outcomes or the hypothesis class $\mathcal{H}$. However, in many cases we have reason to make assumptions about the worst-case loss in generalization, as is typical in statistical learning theory. In this section, we give bounds on $R(f_t)$ under such assumptions.

[Figure 2 panels: (a) Misspecified hypothesis $f_t$; (b) Lipschitz loss $\ell_{f_t}$; (c) Bound on loss difference]
Figure 2: Example illustrating assumptions on the pointwise loss $\ell_{f_t}$. In (a), we see the true potential outcome $\mu_t$ and a hypothesis $f_t$. The pointwise loss between them is plotted in (b). In (c), we illustrate the difference between two densities. The bottom panel shows the worst-case contribution of any loss function in an RBF-kernel RKHS $\mathcal{L}$ to the difference in risk under the two densities. The more similar the densities, or the smoother the functions in $\mathcal{L}$, the smaller the overall contribution.

Let $\mathcal{L} \subset \{\mathcal{X} \to \mathbb{R}_+\}$ be a space of pointwise loss functions with respect to the covariates $X$, endowed with a norm $\|\cdot\|_{\mathcal{L}}$. In this work, we assume that the expected conditional loss $\ell_{f_t}$ for each potential outcome belongs to such a family, i.e., that $\ell_{f_t} \in \mathcal{L}$. A simple example of such a family is the set of loss functions with bounded maximum value, $\mathcal{L}_M = \{\ell: \mathcal{X} \to \mathbb{R}_+ \,;\, \sup_{x \in \mathcal{X}} \ell(x) \leq M\}$. This assumption is satisfied without loss of generality as long as the outcome $Y$ is bounded. However, it is not very informative and will lead to loose bounds in general. Instead, we may make assumptions about the functional properties of $\ell_{f_t}$.
Such assumptions include that $\ell_{f_t}$ is $C$-Lipschitz or belongs to a reproducing-kernel Hilbert space (RKHS). We illustrate the former with an example in Figures 2a–2b.

Now, consider the marginal distribution $p$ and a re-weighted treatment group $p^w_t$ on $\mathcal X$. Let $\ell \in \mathcal L$ be a pointwise loss on $\mathcal X$. Recall that $R(f_t)$ and $R^w_t(f_t)$ denote the marginal and re-weighted factual risks, respectively; below we write $R(\ell)$ and $R^w_t(\ell)$ for the same quantities expressed through the pointwise loss $\ell$. By definition,
$$R(\ell) = R^w_t(\ell) + \int_{x\in\mathcal X} \ell(x)\big(p(x) - p^w_t(x)\big)\,dx \le R^w_t(\ell) + \sup_{\ell'\in\mathcal L}\left|\int_{x\in\mathcal X} \ell'(x)\big(p(x) - p^w_t(x)\big)\,dx\right|. \quad (11)$$
The second term on the right-hand side of (11) is known as the integral probability metric (IPM) distance between $p$ and $p^w_t$ w.r.t. $\mathcal L$, defined as follows (Müller, 1997):
$$\mathrm{IPM}_{\mathcal L}(p,q) := \sup_{\ell\in\mathcal L}\left|\mathbb E_p[\ell(x)] - \mathbb E_q[\ell(x)]\right|. \quad (12)$$
Particular choices of $\mathcal L$ make the IPM equivalent to different well-known distances on distributions. With $\mathcal L$ the family of functions in the norm-1 ball in a reproducing kernel Hilbert space (RKHS), $\mathrm{IPM}_{\mathcal L}$ is the Maximum Mean Discrepancy (MMD) (Gretton et al., 2012). When $\mathcal L$ is the family of functions with Lipschitz constant at most 1, we obtain the Wasserstein distance (Villani, 2008). As pointed out by, e.g., Long et al. (2015), the IPM may be viewed as the loss of a treatment group classifier, and adversarial losses may be considered in its place (Ganin et al., 2016). In Figure 2c, we illustrate the maximizer $\ell^*$ of the supremum, in terms of its contribution to the expected difference in risk in the MMD case.

Before stating the final form of our bounds on $R(f_t) - R^w_t(f_t)$, we note that for $t \in \{0,1\}$, with $\pi_t = p(T=t)$, we may decompose the population risk $R$ as follows.
$$R(f_t) = \underbrace{\pi_t R_t(f_t)}_{\text{Observable}} + \underbrace{(1-\pi_t) R_{1-t}(f_t)}_{\text{Unobserved}}. \quad (13)$$
The factual risk $R_t(f_t)$ is identifiable under ignorability, as
$$\ell_{f_t}(X) = \mathbb E_{Y(t)\mid X}\big[L(f_t(X), Y(t)) \mid X\big] = \mathbb E_{Y\mid X}\big[L(f_t(X), Y) \mid X, T=t\big].$$
For this reason, to bound the risk of $f_t$ on the whole population $p$, it is sufficient for us to bound the counterfactual risk $R_{1-t}(f_t)$ and estimate $R_t(f_t)$ empirically.

Lemma 3.
For a hypothesis $f$ with expected pointwise loss $\ell_f(x)$ such that $\ell_f/\|\ell_f\|_{\mathcal L} \in \mathcal L$, and treatment groups $p_0, p_1$, there exists a re-weighting $w$ such that
$$R_1(f) - R^w_0(f) \le \|\ell_f\|_{\mathcal L}\,\mathrm{IPM}_{\mathcal L}(p_1, p^w_0) \le \|\ell_f\|_{\mathcal L}\,\mathrm{IPM}_{\mathcal L}(p_1, p_0). \quad (14)$$
The first inequality is tight under Assumption 2 for weights $w(x) = p_1(x)/p_0(x)$. The second is not tight for general $f$ unless $p_0 = p_1$. An equivalent result holds for $R_0(f)$.

Proof. The result follows from the definition of IPMs:
$$\begin{aligned}
R_1(f) - R^w_0(f) &= \mathbb E_{X\mid T}[\ell_f(X) \mid T=1] - \mathbb E_{X\mid T}[w(X)\ell_f(X) \mid T=0] \\
&\le \left|\mathbb E_{X\mid T}[\ell_f(X) \mid T=1] - \mathbb E_{X\mid T}[w(X)\ell_f(X) \mid T=0]\right| \\
&\le \|\ell_f\|_{\mathcal L} \sup_{h\in\mathcal L}\left|\mathbb E_{X\mid T}[h(X) \mid T=1] - \mathbb E_{X\mid T}[w(X)h(X) \mid T=0]\right| \quad (15) \\
&= \|\ell_f\|_{\mathcal L}\,\mathrm{IPM}_{\mathcal L}(p_1, p^w_0).
\end{aligned}$$
Step (15) relies on $\ell_f/\|\ell_f\|_{\mathcal L} \in \mathcal L$. Further, for importance weights $w_{IS}(x) = p_1(x)/p_0(x)$ and any $h \in \mathcal L$, under Assumption 2 (overlap),
$$\mathbb E_{X\mid T}[h(X) \mid T=1] - \mathbb E_{X\mid T}[w_{IS}(X)h(X) \mid T=0] = \mathbb E_{X\mid T}[h(X) \mid T=1] - \mathbb E_{X\mid T}\!\left[\frac{p_1(X)}{p_0(X)}h(X) \,\middle|\, T=0\right] = 0,$$
and the first inequality in (14) is tight, as $\mathrm{IPM}_{\mathcal L}(p_1, p^{w_{IS}}_0) = 0$. Given that $\mathrm{IPM} \ge 0$, … $p_1(x)/p_0(x)$ is not defined on the support of $p_0(x)$. The result for $R_0(f)$ follows analogously.

Corollary 1.
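Under the conditions of Lemma 3, with $\tilde w(x) := \pi_t + (1-\pi_t)\,w(x)$,
$$R(f_t) \le R^{\tilde w}_t(f_t) + (1-\pi_t)\,\|\ell_{f_t}\|_{\mathcal L}\,\mathrm{IPM}_{\mathcal L}(p_{1-t}, p^w_t). \quad (16)$$
Proof. The result follows immediately from Lemma 3 applied to the unobserved term in (13): $R_{1-t}(f_t) \le R^w_t(f_t) + \|\ell_{f_t}\|_{\mathcal L}\,\mathrm{IPM}_{\mathcal L}(p_{1-t}, p^w_t)$, and $\pi_t R_t(f_t) + (1-\pi_t) R^w_t(f_t) = R^{\tilde w}_t(f_t)$.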
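To see Lemma 3 at work on a toy problem, the sketch below assumes (our choice, for illustration only) a finite covariate space, the family $\mathcal L$ of losses with values in $[0,1]$ (the class $\mathcal L_M$ with $M=1$) and the sup-norm as $\|\cdot\|_{\mathcal L}$. For this family the supremum in (12) is attained by an indicator function, so the IPM reduces to a sum of positive (or negative) density differences. The distributions and losses are hypothetical.

```python
# Discrete illustration of Lemma 3. Assumption (ours): L is the set of losses with
# values in [0, 1] and ||.||_L is the sup-norm, i.e. the class L_M with M = 1.
# All distributions and losses are hypothetical.

p1 = [0.6, 0.3, 0.1]    # treated covariate distribution p_1(x)
p0 = [0.2, 0.3, 0.5]    # control covariate distribution p_0(x)
loss = [2.0, 1.0, 0.5]  # pointwise loss ell_f(x) of a fixed hypothesis f


def ipm_bounded(p, q):
    """IPM over {ell : 0 <= ell <= 1}; the sup is attained by an indicator function."""
    pos = sum(max(a - b, 0.0) for a, b in zip(p, q))
    neg = sum(max(b - a, 0.0) for a, b in zip(p, q))
    return max(pos, neg)


def expectation(p, ell):
    return sum(pi * li for pi, li in zip(p, ell))


def lemma3_terms(w):
    """Return (R_1(f) - R^w_0(f), ||ell_f||_L * IPM_L(p_1, p^w_0))."""
    p0w = [pi * wi for pi, wi in zip(p0, w)]     # re-weighted control group
    gap = expectation(p1, loss) - expectation(p0w, loss)
    bound = max(loss) * ipm_bounded(p1, p0w)      # sup-norm of ell_f is max(loss)
    return gap, bound


uniform = [1.0, 1.0, 1.0]
importance = [a / b for a, b in zip(p1, p0)]      # w_IS(x) = p_1(x) / p_0(x)
print(lemma3_terms(uniform))     # approximately (0.6, 0.8): the bound holds
print(lemma3_terms(importance))  # approximately (0.0, 0.0): the first inequality is tight
```

With uniform weights the bound is valid but loose; with importance weights the re-weighted control distribution coincides with $p_1$, so both the risk gap and the IPM vanish, as stated in the lemma.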
Remark 3 (Necessity of assumptions). Lemma 3 and Corollary 1 do not, strictly speaking, depend on Assumption 1 (ignorability), since $R(f_t)$ and $\ell_{f_t}$ are defined w.r.t. the potential outcomes $Y(t)$. However, to estimate the right-hand side of (16) from observational data, ignorability is required. Moreover, neither result depends on Assumption 2 (overlap) as long as $w(x)$ is defined everywhere on $p_t(x)$. For particular losses, we can avoid making assumptions about $\ell_{f_t}$ by making assumptions on $f_t$ and the hypothesis class $\mathcal H$ instead. This approach was taken by Ben-David et al. (2007), who used the so-called triangle inequality for loss functions to give bounds on the risk in unsupervised domain adaptation under assumptions on $\mathcal H$. However, this leads to the rather unattractive property that the resulting bounds are not tight even in the special case that $p_0 = p_1$.

Adopting results from statistical learning theory, we bound the difference between empirical estimates of $R^w_t(f_t)$ and $\mathrm{IPM}_{\mathcal L}(p_{1-t}, p^w_t)$ and their expected counterparts. These results are then combined to form a bound on $R(f_t)$.

The re-weighted risk $R^w_t$ may be estimated, for a fixed weighting function $w$, by the standard Monte-Carlo method. Consider a sample $D = \{(x_1,t_1,y_1),\dots,(x_n,t_n,y_n)\} \sim p^n(X,T,Y)$, with $n_t = \sum_{i=1}^n \mathbb 1[t_i = t]$, and define the empirical weighted factual risk
$$\hat R^w_t(f_t) := \frac{1}{n_t}\sum_{i:\,t_i = t} w(x_i)\,L(f_t(x_i), y_i).$$
We aim to bound the difference
$$\Delta(f_t) := R^w_t(f_t) - \hat R^w_t(f_t).$$
To achieve this, we use a result from the literature which builds on the concept of the pseudo-dimension $\mathrm{Pdim}(\mathcal H)$ of a function class $\mathcal H$. For brevity, we refrain from stating its full definition here and refer to Pollard (2012).

Lemma 4 (Cortes et al. (2010)).
Let $\ell_h(x) = \mathbb E_{Y\mid X}[L(h(x), Y) \mid X = x]$ be the expectation of the squared loss $L(y,y') = (y-y')^2$ of a hypothesis $h \in \mathcal H \subseteq \{h' : \mathcal X \to \mathbb R\}$, let $d = \mathrm{Pdim}(\{\ell_h : h \in \mathcal H\})$ and let $\sigma^2_Y = \mathbb E_{X,Y}[L(Y, \mathbb E_{Y\mid X}[Y \mid X])]$. Then, for a weighting function $w(x)$ such that $\mathbb E_X[w(X)] = 1$, with probability at least $1-\delta$ over a sample $((x_1,y_1),\dots,(x_n,y_n))$ with empirical distribution $\hat p$,
$$R^w(h) \le \hat R^w(h) + V_{p,\hat p}[w(x)\ell_h(x)]\,\frac{C^{\mathcal H}_n}{n^{3/8}} + \sigma^2_Y,$$
where
$$C^{\mathcal H}_n = 2^{5/4}\left(d\log\frac{2ne}{d} + \log\frac{4}{\delta}\right)^{3/8} \quad\text{and}\quad V_{p,\hat p}[w(x)\ell_h(x)] = \max\left(\sqrt{\mathbb E_X[w(X)^2\ell_h(X)^2]},\ \sqrt{\mathbb E_{X\sim\hat p}[w(X)^2\ell_h(X)^2]}\right).$$

Lemma 4 applies to any valid weighting function $w$, not only importance weights or weights based on balancing scores. Used in conjunction with Corollary 1, Lemma 4 allows us to separate the bias (the IPM term) and the variance (the $V$ term above) introduced by $w$.

The efficiency with which a sample may be used to estimate $\mathrm{IPM}_{\mathcal L}$ depends on the chosen function family $\mathcal L$. In particular, the sample complexity of learning the Wasserstein distance between two densities on $\mathcal X$ scales as $O(d)$ with the dimension $d$ of $\mathcal X$, whereas the kernel-based MMD has $O(1)$ dependence. Below, we state a result bounding the sample complexity for the MMD with universal kernels.

Lemma 5 (Sriperumbudur et al. (2009)). Let $\mathcal X$ be a measurable space. Suppose $k$ is a universal, measurable kernel such that $\sup_{x\in\mathcal X} k(x,x) \le C < \infty$ and $\mathcal H$ the reproducing kernel Hilbert space induced by $k$, with $\nu := \sup_{x\in\mathcal X, f\in\mathcal H} f(x) < \infty$. Then, with $\hat p, \hat q$ the empirical distributions of $p, q$ from $m$ and $n$ samples, and with probability at least $1-\delta$,
$$\left|\mathrm{IPM}_{\mathcal H}(p,q) - \mathrm{IPM}_{\mathcal H}(\hat p,\hat q)\right| \le \sqrt{\nu \log\frac{4}{\delta}\, C}\left(\frac{1}{\sqrt m} + \frac{1}{\sqrt n}\right).$$
The Gaussian RBF kernel $k(x,x') = e^{-\|x-x'\|^2/(2\sigma^2)}$, with bandwidth $\sigma > 0$, is an important class of universal kernels to which Lemma 5 applies.

With Lemmas 3–5 in place, we can now state our main result.
Theorem 1.
Assume that ignorability (Assumption 1) holds w.r.t. $X$. Given is a sample $(x_1,t_1,y_1),\dots,(x_n,t_n,y_n) \overset{\text{i.i.d.}}{\sim} p(X,T,Y)$ with empirical measure $\hat p$ and $n_t := \sum_{i=1}^n \mathbb 1[t_i = t]$ for $t \in \{0,1\}$. Let $f_t(x) \in \mathcal H$ be a hypothesis of $\mathbb E[Y(t) \mid X = x]$ and $\ell_{f_t}(x) := \mathbb E_Y[L(f_t(x), Y(t)) \mid X = x]$, where $L(y,y') = (y-y')^2$. Assume that there exists a constant $B > 0$ such that $\ell_{f_t}(x)/B \in \mathcal L$, where $\mathcal L$ is a reproducing kernel Hilbert space of a kernel $k$ such that $k(x,x) < \infty$. Finally, let $w : \mathcal X \to \mathbb R_+$ be a valid re-weighting of $p_t$, $\mathbb E[w(X) \mid T=t] = 1$, and let $\tilde w(x) = \pi_t + (1-\pi_t)w(x)$, where $\pi_t = p(T=t)$. With probability at least $1-\delta$,
$$R(f_t) \le \hat R^{\tilde w}_{p_t}(f_t) + B(1-\pi_t)\,\mathrm{IPM}_{\mathcal L}(\hat p_{1-t}, \hat p^{w_t}_t) + V_{p_t}(\tilde w, \ell_{f_t})\frac{C^{\mathcal H}_{n_t,\delta}}{n_t^{3/8}} + D^{\mathcal L}_{n_0,n_1,\delta}\left(\frac{1}{\sqrt{n_0}} + \frac{1}{\sqrt{n_1}}\right) + \sigma^2_{Y(t)},$$
where $C^{\mathcal H}_{n,\delta}$ is a function of the pseudo-dimension of $\mathcal H$, $D^{\mathcal L}_{n_0,n_1,\delta}$ a function of the kernel norm of $\mathcal L$, both with only logarithmic dependence on $n_0$ and $n_1$, $\sigma^2_{Y(t)}$ is the expected variance in $Y(t)$, and $V_p(\tilde w, \ell_{f_t}) = \max\left(\sqrt{\mathbb E_p[\tilde w^2 \ell_{f_t}^2]},\ \sqrt{\mathbb E_{\hat p}[\tilde w^2 \ell_{f_t}^2]}\right)$. A similar bound exists where $\mathcal L$ is the family of functions with Lipschitz constant at most 1 and $\mathrm{IPM}_{\mathcal L}$ the Wasserstein distance, but with worse sample complexity.

Proof. The result follows from applying Lemmas 4–5 to Lemma 3 and is given in larger generality in Theorem 2.

We can also immediately state the following corollary.

Corollary 2.
Assume that the conditions of Theorem 1 hold. Let $f(x,t) := f_t(x)$ and let $\hat R^w_p(f) := \frac{1}{n}\sum_{i=1}^n w(x_i,t_i)\,L(f(x_i,t_i), y_i)$ represent the weighted empirical factual risk. Then, with $\tilde w(x,t) := w_t(x)/\pi_t$, $n_{\min} = \min(n_0, n_1)$ and $\sigma^2 = \max(\sigma^2_{Y(0)}, \sigma^2_{Y(1)})$, there is a constant $K_{\mathcal L,\mathcal H,w,\delta,n_0,n_1}$ with at most logarithmic dependence on $n_0, n_1$, such that
$$\frac{R(\hat\tau)}{2} \le \hat R^{\tilde w}_p(f) + B\left[\pi_1\,\mathrm{IPM}_{\mathcal L}(\hat p_1, \hat p^{w_0}_0) + \pi_0\,\mathrm{IPM}_{\mathcal L}(\hat p_0, \hat p^{w_1}_1)\right] + \frac{K_{\mathcal L,\mathcal H,w,\delta,n_0,n_1}}{n_{\min}^{3/8}} + 2\sigma^2.$$
A tighter result may be obtained by decomposing the constant $K$.

Theorem 1 and Corollary 2 hint at several interesting dependencies between generalization error, treatment group imbalance, sample re-weighting schemes and the choice of hypothesis class. We comment on these below.
Bounds with overlap and known propensity scores.
If overlap is satisfied and propensity scores are known, applying importance weights $w_t(x) = p_{1-t}(x)/p_t(x) = \frac{\pi_t}{\pi_{1-t}}\cdot\frac{p(T=1-t \mid X=x)}{p(T=t \mid X=x)}$ in Theorem 1 leads to a tight bound in the limit of infinite samples (the IPM and variance terms approach zero, and the re-weighted risk approaches the desired population risk). A special case of this is the randomized controlled trial (RCT), in which $T$ has no dependence on $X$. In this setting, the IPM terms depend only on the finite-sample differences between treatment groups, which may still be useful to characterize. It has been shown that under overlap, in the asymptotic limit, the best achievable sample complexity is unrelated to the imbalance of $p_0$ and $p_1$ (Alaa and Schaar, 2018). However, this setting is not our main concern, as we are specification-agnostic and focus on the finite-sample case.

Bounds without overlap.
Theorem 1 does not rely on treatment group overlap. Instead, it relies on an assumption that the true loss (w.r.t. features $X$ and potential outcome $Y(t)$) is a function in the given family $\mathcal L$. Additionally, the bound requires that $w_t(x)$ is defined everywhere on the support of $p_t(x)$ for $t \in \{0,1\}$. In particular, if for some $x \in \mathcal X$, $p_{1-t}(x) = 0$ and $p_t(x) > 0$, the ratio $p_t(x)/p_{1-t}(x)$ is undefined at $x$, so importance weights are not defined everywhere. We return to the question of overlap in the next section, following Theorem 2.
Bias and variance.
The term $V_{p_t}(\tilde w, \ell_{f_t})$ in Theorem 1 shows that a less uniform re-weighting $w$ leads to larger variance (dependence on $n$). However, if $p_0$ and $p_1$ are very different, a non-uniform (balancing) $w$ is required to ensure unbiasedness, e.g., by making $p^{w_t}_t = p_{1-t}$. This indicates that $w$ introduces a bias-variance trade-off on top of the one typical for supervised learning. In particular, even if the true treatment propensity $\eta$ is known, a biased weighting scheme may lead to a smaller bound on the population risk when $p_0$ and $p_1$ are far apart.

Imbalance in non-confounders.
The size of the bounds in Theorem 1 and Corollary 2 clearly depends on the quality of the hypothesis $f$ and the choice of re-weighting $w$. In addition, the IPM terms depend heavily on the input space $\mathcal X$. In particular, if variables included in $X$ are predictive of $T$ but not of $Y$, e.g., if they are instrumental variables (Ding et al., 2017), they will contribute to the IPM term but not to the expected risk, loosening the bound needlessly. If we can learn to ignore such information, we may obtain a tighter bound. To this end, in the next section, we derive bounds for representations $\Phi(X)$ of the original feature space.

Generalization under policy and domain shift.
Predicting the conditional treatment effect for an individual may be viewed as predicting the effect of a change in treatment policy from one alternative to another. This notion may be generalized further by considering the estimation of treatment effects for a change in policy on a population that differs from the one learned from. Specifically, this would involve a change not only in $p(T \mid X)$ but in $p(X)$ as well. We studied this extended problem in Johansson et al. (2018), and referred to changes in both policy $p(T \mid X)$ and domain $p(X)$ as a change in design. We do not cover this setting in detail here.

When the input space $\mathcal X$ increases in dimension, treatment groups $p_t(X)$ tend to grow increasingly different (D’Amour et al., 2017), which, in general, leads to a looser bound in Theorem 1. To some extent, this can be mitigated by appropriately chosen weights $w$, but the additional finite-sample variance introduced by highly non-uniform weights may prevent tightening of the bound. In this section, we introduce another tool for minimizing bounds on the marginal risk: hypotheses that act on learned (potentially lower-dimensional) representations of the covariates $X$. This allows hypotheses to focus their attention on particular aspects of the covariate space, ignoring others.

In many applications, the input distribution $p(X)$ is believed to concentrate near a low-dimensional manifold embedded in a high-dimensional space $\mathcal X$, for example, the space of portraits embedded in the pixels of a photograph. In such settings, the best hypotheses are often simple functions of low-dimensional representations $\Phi(X)$ of the input (Bengio et al., 2013).
The most famous examples of this are image and speech recognition, for which representation learning using convolutional and recurrent neural networks advanced each field tremendously in only a few years (LeCun et al., 2015).

Let $\mathcal E \subset \{\mathcal X \to \mathcal Z\}$ denote a family of representation functions (embeddings) of the input space $\mathcal X$ in $\mathcal Z$, and let $\Phi \in \mathcal E$ denote such an embedding function. Further, let $\mathcal G \subseteq \{h : \mathcal Z \to \mathcal Y\}$ denote a set of hypotheses $h(\Phi)$ operating on the representation $\Phi$, and let $\mathcal H$ be the space of all such compositions, $\mathcal H = \{f = h \circ \Phi : h \in \mathcal G, \Phi \in \mathcal E\}$. Generalizing our discussion up to this point, we consider learning $\Phi(X)$ from data with the goal of minimizing the marginal risk of hypotheses $h \circ \Phi$.

For CATE to be identifiable from observations of $p(\Phi(X), T, Y)$, we need precisely the same requirements on $\Phi$ as previously on $X$, ignorability and overlap:
$$\forall t \in \{0,1\}: \quad \underbrace{Y(t) \perp\!\!\!\perp T \mid \Phi(X)}_{\text{Ignorability}} \quad\text{and}\quad \forall z \in \mathcal Z: \underbrace{p(T=t \mid \Phi(X) = z) > 0}_{\text{Overlap}}. \quad (17)$$
Verifying the assumptions in (17) for a given $\Phi$ based on observational data alone is impossible, just as for $X$. To address this, we consider learning twice-differentiable, invertible representations $\Phi : \mathcal X \to \mathcal Z$, where $\Psi : \mathcal Z \to \mathcal X$ is the inverse representation, such that $\Psi(\Phi(x)) = x$ for all $x$. For treatment groups $t \in \{0,1\}$, we let $p_{\Phi,t}(z)$ be the distribution induced by $\Phi$ over $\mathcal Z$, with $p^w_{\Phi,t}(z) := p_{\Phi,t}(z)\,w(\Psi(z))$ its re-weighted form and $\hat p^w_{\Phi,t}$ its re-weighted empirical form, following our previous notation. If $\Phi$ is invertible, ignorability and overlap in $X$ imply ignorability and overlap in $\Phi(X)$, as $p(\Phi(X) = z) = p(X = \Psi(z))$. Building on Section 4, we can now relate the expected marginal risk $R(h \circ \Phi)$ to the (expected) re-weighted factual risk $R^w(h \circ \Phi)$.
Lemma 6.
Suppose that $\Phi$ is a twice-differentiable, invertible representation, that $h_t(\Phi) \in \mathcal G$ is a hypothesis, and that $f_t(x) = h_t(\Phi(x)) \in \mathcal H$ for $t \in \{0,1\}$. Let $\ell_{\Phi,h_t}(z) := \mathbb E_Y[L(h_t(z), Y(t)) \mid X = \Psi(z)]$ be the expected pointwise loss given a representation $z$, where $L(y,y') = (y-y')^2$. Let $A_\Phi$ be a constant such that $\forall z \in \mathcal Z: A_\Phi \ge |J_\Psi(z)|$, where $|J_\Psi(z)|$ is the absolute determinant of the Jacobian of the representation inverse $\Psi$, and assume that there exists a constant $B_\Phi > 0$ such that, with $C_\Phi := A_\Phi B_\Phi$, $\ell_{\Phi,h_t}/C_\Phi \in \mathcal L \subseteq \{\ell : \mathcal Z \to \mathbb R_+\}$. Finally, let $\pi_t = p(T=t)$ and let $w$ be a valid re-weighting of $p_{\Phi,t}$. Then,
$$R(f_t) \le \pi_t R_t(f_t) + (1-\pi_t)\left[R^w_t(f_t) + C_\Phi \cdot \mathrm{IPM}_{\mathcal L}(p_{\Phi,1-t}, p^w_{\Phi,t})\right]. \quad (18)$$

Proof. By (13), with $\ell_{f_t}(x) := \mathbb E_Y[L(f_t(x), Y(t)) \mid X = x]$, we have that
$$R(f_t) = \pi_t R_t(f_t) + (1-\pi_t) R^w_t(f_t) + (1-\pi_t)\int_{x\in\mathcal X} \ell_{f_t}(x)\big(p_{1-t}(x) - p^w_t(x)\big)\,dx.$$
Then, by the standard change of variables, assuming that $\Phi$ is invertible, we have
$$\begin{aligned}
\int_{x\in\mathcal X} \ell_{f_t}(x)\big(p_{1-t}(x) - p^w_t(x)\big)\,dx &= \int_{z\in\mathcal Z} \ell_{f_t}(\Psi(z))\big(p_{1-t}(z) - p^w_t(z)\big)\,|J_\Psi(z)|\,dz \\
&\le A_\Phi \int_{z\in\mathcal Z} \ell_{\Phi,h_t}(z)\big(p_{1-t}(z) - p^w_t(z)\big)\,dz \\
&\le C_\Phi \cdot \mathrm{IPM}_{\mathcal L}(p_{\Phi,1-t}, p^w_{\Phi,t}).
\end{aligned}$$
Here, we have used the fact that, for invertible $\Phi$, $p(Z = \Phi(x)) = p(X = x)$.

The scale of Φ and the factor C_Φ. Comparing Lemma 6 (bound in representation) to Lemma 3 (original space), we notice two immediate differences: the additional factor $C_\Phi$, and the change from measuring distributional distance in $\mathcal X$ to doing so in $\mathcal Z$, via $\Phi$. The most illustrative example of why $C_\Phi$ is necessary is when $\Phi$ simply reduces the scale of $X$, i.e., when $\Phi(x) = x/a$ for $a > 1$. IPMs often vary with the scale of the space in which they are applied, and we could reduce the right-hand side of the bound simply by scaling down $X$ were it not for $C_\Phi$, which counteracts this reduction.

The influence of Φ on the IPM. Measuring distributional distance in $\Phi$ with a fixed IPM family $\mathcal L$ means that we may emphasize or de-emphasize parts of the covariate space, even when $\Phi$ is invertible. For example, if $\Phi$ is a linear function that scales down a component $X^{(d)}$ of $X$ significantly, and $\mathcal L$ is a family of linear functions with bounded norm, the influence of distributional differences in $X^{(d)}$ on the IPM is reduced.

With Lemma 6 in place, we can now state a result for the finite-sample case by following the same steps as in Section 4.

Theorem 2. Given is a sample $(x_1,t_1,y_1),\dots,(x_n,t_n,y_n) \overset{\text{i.i.d.}}{\sim} p(X,T,Y)$ with empirical measure $\hat p$. Assume that ignorability (Assumption 1) holds w.r.t. $X$. Suppose that $\Phi$ is a twice-differentiable, invertible representation, that $h_t(\Phi)$ is a hypothesis on $\mathcal Z$, and $f_t = h_t(\Phi(x)) \in \mathcal H$. Let $\ell_{\Phi,h_t}(z) := \mathbb E_Y[L(h_t(z), Y(t)) \mid X = \Psi(z)]$, where $L(y,y') = (y-y')^2$. Further, let $A_\Phi$ be a constant such that $\forall z \in \mathcal Z: A_\Phi \ge |J_\Psi(z)|$, where $|J_\Psi(z)|$ is the absolute determinant of the Jacobian of the representation inverse $\Psi$, and assume that there exists a constant $B_\Phi > 0$ such that, with $C_\Phi := A_\Phi B_\Phi$, $\ell_{\Phi,h_t}/C_\Phi \in \mathcal L$, where $\mathcal L$ is a reproducing kernel Hilbert space of a kernel $k$ such that $k(x,x) < \infty$. Finally, let $w$ be a valid re-weighting of $p_{\Phi,t}$.
Then, with probability at least $1-\delta$,
$$R_{1-t}(f_t) \le \hat R^w_t(f_t) + C_\Phi \cdot \mathrm{IPM}_{\mathcal L}(\hat p_{\Phi,1-t}, \hat p^w_{\Phi,t}) + V_{p_t}(w, \ell_{f_t})\frac{C^{\mathcal H}_{n_t,\delta}}{n_t^{3/8}} + D^{\Phi,\mathcal L}_{n_0,n_1,\delta}\left(\frac{1}{\sqrt{n_0}} + \frac{1}{\sqrt{n_1}}\right) + \sigma^2_{Y(t)},$$
where $C^{\mathcal H}_{n,\delta}$ is a function of the pseudo-dimension of $\mathcal H$, $D^{\Phi,\mathcal L}_{n_0,n_1,\delta}$ a function of the kernel norm of $\mathcal L$, both with only logarithmic dependence on $n_0$ and $n_1$, $\sigma^2_{Y(t)}$ is the expected variance in $Y(t)$, and $V_p(w, \ell_f) = \max\left(\sqrt{\mathbb E_p[w^2\ell_f^2]},\ \sqrt{\mathbb E_{\hat p}[w^2\ell_f^2]}\right)$. A similar bound exists where $\mathcal L$ is the family of functions with Lipschitz constant at most 1 and $\mathrm{IPM}_{\mathcal L}$ the Wasserstein distance, but with worse sample complexity.

Proof. The result follows from applying Lemmas 4–5 to Lemma 6.
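The remark on the scale of $\Phi$ preceding Theorem 2 can be checked numerically. The sketch below assumes (our choices, for illustration) one-dimensional covariates, the IPM family of unit-norm linear functions $\{z \mapsto \theta z : |\theta| \le 1\}$, for which the IPM is the distance between group means, and the hypothetical representation $\Phi(x) = x/a$ with inverse $\Psi(z) = az$, so that $|J_\Psi(z)| = a$ and we may take $A_\Phi = a$.

```python
# Illustration of why C_Phi is needed (remark on the scale of Phi).
# Assumptions (ours): one-dimensional covariates, the linear IPM family
# {z -> theta * z : |theta| <= 1}, and the hypothetical representation Phi(x) = x / a.


def mean(values):
    return sum(values) / len(values)


def linear_ipm(xs, ys):
    """IPM over unit-norm linear functions: the distance between the two means."""
    return abs(mean(xs) - mean(ys))


def phi(xs, a):
    """Phi(x) = x / a, with inverse Psi(z) = a * z."""
    return [x / a for x in xs]


treated = [1.0, 2.0, 3.0]
control = [0.0, 0.5, 1.0]
a = 4.0
A_phi = a  # |det J_Psi(z)| = a for Psi(z) = a * z in one dimension

ipm_x = linear_ipm(treated, control)                  # distance measured in X
ipm_z = linear_ipm(phi(treated, a), phi(control, a))  # shrinks by the factor a
print(ipm_x, ipm_z, A_phi * ipm_z)                    # 1.5 0.375 1.5
```

Shrinking the representation divides the IPM by $a$, while the Jacobian factor multiplies it back, so the product entering (18) is unchanged in this one-dimensional case.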
Overlap, ignorability and invertibility.
Theorem 2 holds both with and without treatment group overlap in $X$. It is important to note, however, that when we change the covariate space from $X$ to $\Phi$, the assumption that $\ell_{\Phi,h_t}/C_\Phi \in \mathcal L$ is not guaranteed, even for large $C_\Phi$, since information on which $\ell$ depends may have been (approximately) removed. In the context of risk minimization, information is only excluded from $\Phi$ if it is not predictive of the outcome $Y$, in which case it is also independent of $\ell$. Thus, under the additional assumption of overlap, the assumption that $\ell_{\Phi,h_t}/C_\Phi \in \mathcal L$ is verifiable in the limit of infinite data. In Johansson et al. (2019), we expand on the effects of non-invertibility on identifiability of the marginal risk in much greater detail. In particular, we show that for non-invertible $\Phi$, without overlap, the marginal risk may be bounded under the assumption that information removed in $\Phi$ is as important to the risk of the factual outcome as to that of the counterfactual. This assumption, however, is also unverifiable in general.

Connections between the problem of estimating causal effects and learning under distributional shift have been pointed out in several contexts (Tian and Pearl, 2001; Zhang et al., 2013). In particular, Johansson et al. (2016) showed that estimating counterfactual outcomes under ignorability is mathematically equivalent to unsupervised domain adaptation between domains $D \in \{0,1\}$ under covariate shift. We make this connection precise below.

| Task | Data | Goal | Assumption |
|---|---|---|---|
| Causal estimation | Factual: $(x,t,y) \sim p(X,T,Y)$ | Counterfactual: $p(Y(1-T) \mid X, T)$ | Ignorability: $Y(t) \perp\!\!\!\perp T \mid X$ |
| Domain adaptation | Source domain: $(x,y) \sim p(X,Y \mid D=0)$ | Target label: $p(Y \mid X, D=1)$ | Covariate shift: $Y \perp\!\!\!\perp D \mid X$ |

The bounds we present in this work are related to a line of work on generalization theory for unsupervised domain adaptation (Ben-David et al., 2007; Mansour et al., 2009; Long et al., 2015), but differ in significant ways.
Superficially, the bounds given in these papers have a similar form, using the sum of the observed risk in the source domain and a distributional distance w.r.t. a function class $\mathcal H$ to bound the risk in the target domain:
$$R_{D=1}(f) \le R_{D=0}(f) + d_{\mathcal H}\big(p(X \mid D=1), p(X \mid D=0)\big) + \lambda_{\mathcal H}.$$
Similarly, these bounds do not rely on overlap but cannot guarantee consistent estimation in the general case. In fact, because they do not allow for re-weighting of domains, these bounds are often unnecessarily loose even when source and target domains completely overlap. Furthermore, while they are used to motivate representation learning algorithms, these bounds do not apply to learned representations without modification (Johansson et al., 2019). In this work, we overcome this issue by requiring that representations $\Phi$ are invertible.
In this section, we derive learning objectives for estimating potential outcomes and CATE, grounded in the theoretical results of Sections 4–5. A downside of separately estimating the two potential outcomes and subtracting these to obtain an estimate of the treatment effect (a so-called T-learner; Künzel et al., 2017) is that the two estimators share no information and may suffer compounding errors if their biases point in opposite directions. In this work, we use representation learning to improve on T-learner estimators in two ways: a) by allowing estimators to share information through representation functions learned from both treatment groups, and b) by regularizing treatment group distance in representations to enable better counterfactual generalization, as motivated by Section 5.
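For reference, the T-learner baseline described above can be sketched in a few lines. This is a minimal illustration, assuming a single covariate and an ordinary-least-squares line per treatment group (our choice for brevity; any regressor could be substituted). The two fits never see each other's data, which is exactly the lack of information sharing discussed above.

```python
# Minimal T-learner sketch: one independent regressor per treatment group,
# CATE estimate = f_1(x) - f_0(x). Assumption (ours): one covariate, OLS lines.


def fit_ols_1d(xs, ys):
    """Least-squares fit y ~ b0 + b1 * x; returns the fitted predictor."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    sxx = sum((x - mx) ** 2 for x in xs)
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    b1 = sxy / sxx
    b0 = my - b1 * mx
    return lambda x: b0 + b1 * x


def t_learner(xs, ts, ys):
    """Fit f_0 on control units and f_1 on treated units separately."""
    models = {}
    for t in (0, 1):
        xt = [x for x, ti in zip(xs, ts) if ti == t]
        yt = [y for y, ti in zip(ys, ts) if ti == t]
        models[t] = fit_ols_1d(xt, yt)
    return lambda x: models[1](x) - models[0](x)


# Noise-free toy data: mu_0(x) = 1 + 2x, mu_1(x) = 3 + 0.5x, so tau(x) = 2 - 1.5x.
xs = [0.0, 1.0, 2.0, 3.0, 0.0, 1.0, 2.0, 3.0]
ts = [0, 0, 0, 0, 1, 1, 1, 1]
ys = [1.0, 3.0, 5.0, 7.0, 3.0, 3.5, 4.0, 4.5]
tau_hat = t_learner(xs, ts, ys)
print(tau_hat(2.0))  # close to -1.0
```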
Let $D = \{(x_1,t_1,y_1),\dots,(x_n,t_n,y_n)\}$ be a set of samples drawn i.i.d. from $p(X,T,Y)$ and let each sample $i$ be endowed with a weight $w_i = w(x_i,t_i)$ for some function $w : \mathcal X \times \{0,1\} \to \mathbb R_+$. Further, let $\lambda, \alpha > 0$ be hyperparameters controlling the strength of the regularization $\mathcal R$ and of the distributional distance $\mathrm{IPM}_{\mathcal L}$, respectively. Recall that $\hat p^w_{\Phi,t}$ is the re-weighted factual distribution of representations $\Phi$ under $p_t$. Now, we consider compositions $(h \circ \Phi) \in \mathcal F$ of hypotheses $h(x,t)$ of potential outcomes $Y(t)$, such that $h \in \mathcal G \subseteq \{\mathcal Z \times \{0,1\} \to \mathcal Y\}$, and representations $\Phi \in \mathcal E \subseteq \{\mathcal X \to \mathcal Z\}$. Then, directly motivated by Theorem 2, we propose to minimize the following learning objective:
$$O(h,\Phi,\lambda,\alpha) = \underbrace{\sum_{i=1}^n \frac{w_i}{n}\,L\big(h(\Phi(x_i), t_i), y_i\big)}_{\text{Empirical (weighted) risk}} + \underbrace{\frac{\lambda}{\sqrt n}\,\mathcal R(h)}_{\text{Regularization}} + \underbrace{\alpha\,\mathrm{IPM}_{\mathcal L}(\hat p^w_{\Phi,0}, \hat p^w_{\Phi,1})}_{\text{Distributional distance}}. \quad (19)$$
Under Assumptions 1 (ignorability) and 2 (overlap), for balancing weights $w_i = p(T=t_i)/p(T=t_i \mid X=x_i)$, objective (19) reduces to inverse propensity-weighted regression in the limit of infinite samples (Freedman and Berk, 2008), since the re-weighted treatment groups then coincide and the IPM term vanishes. In the finite-sample regime, the IPM does not vanish even if $p_0 = p_1$, because of sample variance. As pointed out previously, the objective remains an upper bound on the CATE risk for other weights. Thus, in addition to learning representations and hypotheses, we may consider learning the sample weights $w$ jointly, controlling the variance introduced by non-uniform weights by regularizing the norm of $w$ (Johansson et al., 2018).
With $\beta = (\alpha, \lambda_h, \lambda_w)$ a set of hyperparameters,
$$O(h,\Phi,w;\beta) = \underbrace{\sum_{i=1}^n \frac{w_i}{n}\,L\big(h(\Phi(x_i), t_i), y_i\big) + \frac{\lambda_h}{\sqrt n}\,\mathcal R(h)}_{O_h(h,\Phi,w;\,D,\alpha,\lambda_h)} + \underbrace{\alpha\,\mathrm{IPM}_{\mathcal L}(\hat p^w_{\Phi,0}, \hat p^w_{\Phi,1}) + \lambda_w\frac{\|w\|}{n}}_{O_w(\Phi,w;\,D,\alpha,\lambda_w)}. \quad (20)$$
A theoretical advantage of also learning the sample weights is that it allows for an explicit trade-off between the bias and variance induced by the sample weights. In addition, it allows the weights to be defined in terms of the learned representation. We proceed to give conditions under which minimization of (20) leads to consistent estimation of CATE.

Theorem 3.
Suppose $\mathcal F$ is a reproducing kernel Hilbert space given by a bounded kernel $k$. Suppose weak overlap holds in that $\forall t \in \{0,1\}: \mathbb E_X\big[(p_t(X)/p_{1-t}(X))^2\big] < \infty$. Then, with $O$ the objective defined in (20), and $n_t = \sum_{i=1}^n \mathbb 1[t_i = t]$ for $t \in \{0,1\}$,
$$\min_{h,\Phi,w} O(h,\Phi,w;\beta) \le \min_{f\in\mathcal F} R(f) + O_p\!\left(1/\sqrt{n_0} + 1/\sqrt{n_1}\right).$$
Consequently, under the assumptions of Thm. 2, for sufficiently large $\alpha$ and $\lambda_w$, with $\hat f_n$ the minimizer of (20) for $n$ samples,
$$R(\hat f_n) \le \min_{f\in\mathcal F} R(f) + O_p\!\left(1/n_0^{3/8} + 1/n_1^{3/8}\right).$$
In words, the minimizers of (20) converge to the representation and hypothesis that minimize the counterfactual risk, in the limit of infinite samples.
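To make objective (20) concrete, the sketch below evaluates it for fixed model outputs; it does not train anything. Several pieces are our assumptions, not specified by the text: the IPM is taken over unit-norm linear functions (so it equals the Euclidean distance between the weighted group means of the representations), each weighted group distribution is normalized by its total weight, $\|w\|$ is the Euclidean norm, and the complexity term $\mathcal R(h)$ is passed in as a precomputed number.

```python
import math

# Sketch of evaluating objective (20) for fixed model outputs (no training).
# Assumptions (ours): linear-function IPM (distance between weighted group means),
# weighted groups normalized by their total weight, Euclidean ||w||, and the
# hypothesis complexity R(h) supplied as a number.


def weighted_mean(zs, ws):
    total = sum(ws)
    dim = len(zs[0])
    return [sum(w * z[j] for z, w in zip(zs, ws)) / total for j in range(dim)]


def linear_ipm(z0, w0, z1, w1):
    m0, m1 = weighted_mean(z0, w0), weighted_mean(z1, w1)
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(m0, m1)))


def objective_20(reps, ts, losses, ws, complexity, alpha, lam_h, lam_w):
    """reps: Phi(x_i); losses: L(h(Phi(x_i), t_i), y_i); complexity: R(h)."""
    n = len(ts)
    o_h = sum(w * l for w, l in zip(ws, losses)) / n + lam_h / math.sqrt(n) * complexity
    groups = {t: ([z for z, ti in zip(reps, ts) if ti == t],
                  [w for w, ti in zip(ws, ts) if ti == t]) for t in (0, 1)}
    ipm = linear_ipm(*groups[0], *groups[1])
    o_w = alpha * ipm + lam_w * math.sqrt(sum(w * w for w in ws)) / n
    return o_h + o_w


reps = [[0.0], [1.0], [2.0], [3.0]]
ts = [0, 0, 1, 1]
losses = [1.0, 1.0, 1.0, 1.0]
ws = [1.0, 1.0, 1.0, 1.0]
print(objective_20(reps, ts, losses, ws, complexity=0.0, alpha=1.0, lam_h=0.0, lam_w=1.0))  # 3.5
```

Setting `alpha=0` and `lam_w=0` with uniform weights leaves only the empirical risk and complexity terms, i.e., the objective without any balancing penalty.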
Theorem 3 may be generalized further to the case of estimating the value of an arbitrary treatment policy under a shift in the marginal distribution $p(X)$. See Appendix A for a proof of Theorem 3 in this general setting.

Figure 3 panels: (a) T-learner, with separate estimators for the control group and for the treated group ($f_0$, $f_1$) mapping covariates $X$ to predicted potential outcomes, each with its group-conditional risk ($R_0$, $R_1$); (b) TARNet (Shalit et al., 2017), with neural network layers mapping covariates $X$ to a shared representation $\Phi$, followed by the two potential-outcome estimators, whose empirical risk $R$ is computed from the intervention $T$ and the outcome $Y$.
Figure 3: Estimator architectures for potential outcomes and conditional average treatment effects. Green boxes indicate inputs, white boxes outputs and loss terms, yellow boxes shared representations, and blue/red boxes estimators of potential outcomes. Solid lines indicate transformations that are part of the prediction function, and dashed lines indicate computations that are part of the learning procedure.
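The TARNet layout in Figure 3b (a representation $\Phi$ shared by both treatment groups, followed by one head per potential outcome) can be sketched as a plain forward pass. This is a minimal illustration under our own assumptions: one exponential-linear (ELU) layer for $\Phi$, one linear layer per head, random initialization, and no training loop; the networks used in the paper are deeper.

```python
import math
import random

# Minimal forward pass of a TARNet-style network (no training). Assumptions (ours):
# one ELU layer for the shared representation Phi and one linear layer per
# potential-outcome head; the architectures in the paper are deeper.


def elu(v):
    return v if v > 0 else math.exp(v) - 1.0


def dense(x, layer, act=None):
    W, b = layer
    out = [sum(wij * xj for wij, xj in zip(row, x)) + bj for row, bj in zip(W, b)]
    return [act(o) for o in out] if act else out


def init_layer(n_in, n_out, rng):
    W = [[rng.gauss(0.0, 1.0 / math.sqrt(n_in)) for _ in range(n_in)] for _ in range(n_out)]
    return W, [0.0] * n_out


class TARNetSketch:
    def __init__(self, d_in, d_rep, seed=0):
        rng = random.Random(seed)
        self.phi = init_layer(d_in, d_rep, rng)                      # shared Phi
        self.heads = {t: init_layer(d_rep, 1, rng) for t in (0, 1)}  # f_0 and f_1

    def represent(self, x):
        return dense(x, self.phi, act=elu)

    def predict(self, x, t):
        """Potential-outcome prediction h(Phi(x), t): the treatment selects the head."""
        return dense(self.represent(x), self.heads[t])[0]

    def cate(self, x):
        return self.predict(x, 1) - self.predict(x, 0)


model = TARNetSketch(d_in=3, d_rep=4)
x = [0.1, -0.2, 0.3]
print(model.represent(x), model.cate(x))
```

Because both heads read the same representation, gradients from treated and control samples would both update $\Phi$ during training, which is the information sharing a T-learner lacks.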
Objectives (19) and (20) may be used to learn or select a representation $\Phi$ that trades off treatment-group invariance and empirical risk. The two arguably most prominent approaches to representation learning in the literature are based on a) neural networks (Bengio et al., 2013) or b) variable selection (Schneeweiss et al., 2009). As the latter does not satisfy our assumption of invertibility, and may be viewed as a special case of the former, we restrict our attention to parameterizations of $\Phi$ as neural networks. As this choice leaves a lot of freedom in the design of estimators, we discuss alternatives from the literature below.

As described in Section 3, T-learner estimators fit potential outcomes entirely independently. These may be viewed as operating in the representation space of the identity transform, $\Phi(X) = X$. While this does not allow for minimization of treatment group distance other than through re-weighting, T-learners serve as natural baselines for other architectures. A natural extension was proposed in the Treatment-Agnostic Representation Network (TARNet) by Shalit et al. (2017). In TARNet, a T-learner architecture is appended to a representation $\Phi$ shared between treatment groups (see Figure 3 for a comparison). TARNet has the advantage of sharing samples between treatment groups in learning the representation, which may be useful when $\tau$ is a simpler function of $X$ than $Y(0)$ and $Y(1)$ are.

In Section 4, we bound the generalization error in CATE using integral probability metrics (IPMs), a family of distances between distributions $p, q$ based on the density difference $p - q$. The idea of regularizing models to be invariant to changes in a variable, e.g., the treatment indicator, is prevalent throughout machine learning (Ganin et al., 2016; Goodfellow et al., 2014; Long et al., 2015). As a result, several families of distance metrics between distributions have been used to impose such constraints. The most common of these are f-divergences (e.g., the KL-divergence) (Nowozin et al., 2016), integral probability metrics (e.g., the maximum mean discrepancy) and adversarial discriminators (Ganin et al., 2016).
Figure 4: Illustration of the Counterfactual Regression (CFR) estimator. Here, $d(p_{\Phi,0}, p_{\Phi,1})$ represents the treatment group distance, a distributional distance such as an IPM. The visual elements are described in Figure 3.

f-divergences are often ill-suited for comparing two empirical densities, as they are based on the density ratio, which is undefined at any point outside the support of either density. In contrast, IPMs are based on the density difference, which is defined everywhere. Adversarial methods are based on the metric implied by a learned discriminator function, which is trained to distinguish samples from the two densities. The flexibility of this approach (it tailors the metric to the observed data) is also its weakness, since optimization of adversarial discriminators is fraught with difficulty. The TARNet architecture described above is well-suited for incorporating regularization on the distributional distance in $\Phi$ according to Objective (19). In particular, we use the empirical kernel MMD (Gretton et al., 2012) and the Wasserstein distance (Villani, 2008) for this purpose. We dubbed the resulting estimator Counterfactual Regression (CFR) in Shalit et al. (2017) (see Figure 4). In Johansson et al. (2018), we derived a further extension, incorporating a learned sample re-weighting function minimizing (20), called Re-weighted CFR (RCFR) and illustrated in Appendix C.

Minimizing the maximum mean discrepancy.
The maximum mean discrepancy (MMD) was popularized in machine learning through its kernel-based incarnation, in which the associated function family is a reproducing kernel Hilbert space (RKHS) (Gretton et al., 2012). We restrict our attention to this family here. An unbiased estimator of the squared MMD between densities $p, q$ on $\mathcal X$, with respect to a kernel $k$, may be obtained from samples $x_1,\dots,x_m \sim p$ and $x'_1,\dots,x'_n \sim q$ as follows:
$$\widehat{\mathrm{MMD}}^2_k(p,q) := \frac{1}{m(m-1)}\sum_{i=1}^m\sum_{j\ne i} k(x_i,x_j) - \frac{2}{mn}\sum_{i=1}^m\sum_{j=1}^n k(x_i,x'_j) + \frac{1}{n(n-1)}\sum_{i=1}^n\sum_{j\ne i} k(x'_i,x'_j).$$
By choosing a differentiable kernel $k$, such as the Gaussian RBF kernel, we can ensure that the MMD is amenable to gradient-based learning. In applications where the quadratic time complexity w.r.t. sample size is prohibitively large, another unbiased estimator (but with larger variance) may be obtained by sampling pairs of points $(x_1,x'_1),\dots,(x_n,x'_n) \sim p \times q$:
$$\widehat{\mathrm{MMD}}^2_{k,\mathrm{lin}}(p,q) := \frac{2}{n}\sum_{i=1}^{n/2}\left[k(x_{2i-1},x_{2i}) + k(x'_{2i-1},x'_{2i}) - k(x_{2i-1},x'_{2i}) - k(x_{2i},x'_{2i-1})\right].$$

Minimizing the Wasserstein distance.
The Wasserstein distance is typically computed as the solution to a linear program (LP). The gradient of this solution with respect to the learned representation may be obtained through the KKT conditions of the problem and the solution for the current representation (Amos and Kolter, 2017). However, solving the LP at each gradient update is prohibitively expensive for many applications. Instead, we minimize an approximation of the distance known as Sinkhorn distances (Cuturi, 2013), computed using fixed-point iteration. In previous work (Shalit et al., 2017), we computed the distance and its gradient by forward and backpropagation through a recurrent neural network with a transition matrix corresponding to the fixed-point update. For a full description, see Appendix D. Alternative methods for minimizing Wasserstein distances have been developed in the context of generative adversarial networks (GANs) (Arjovsky et al., 2017).
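The MMD estimators and the Sinkhorn approximation described above can be sketched in NumPy as follows. This is a minimal illustration under our own assumptions, not the cfrnet implementation: it uses a Gaussian RBF kernel with a bandwidth `sigma` of our choosing, a Euclidean ground cost with uniform marginals for the Sinkhorn iteration, and a fixed regularization strength `lam` and iteration count `n_iter`; all function names are hypothetical. It only evaluates the distances, whereas CFR differentiates them with respect to the representation Φ.

```python
import numpy as np

def rbf_kernel(a, b, sigma=1.0):
    # Gaussian RBF kernel matrix between the rows of a (m x d) and b (n x d)
    sq = np.sum(a**2, axis=1)[:, None] + np.sum(b**2, axis=1)[None, :] - 2.0 * a @ b.T
    return np.exp(-sq / (2.0 * sigma**2))

def mmd2_unbiased(x, xp, sigma=1.0):
    # Quadratic-time unbiased estimate of MMD^2; within-sample terms skip i == j
    m, n = len(x), len(xp)
    kxx = rbf_kernel(x, x, sigma)
    kpp = rbf_kernel(xp, xp, sigma)
    kxp = rbf_kernel(x, xp, sigma)
    term_xx = (kxx.sum() - np.trace(kxx)) / (m * (m - 1))
    term_pp = (kpp.sum() - np.trace(kpp)) / (n * (n - 1))
    return term_xx + term_pp - 2.0 * kxp.mean()

def mmd2_linear(x, xp, sigma=1.0):
    # Linear-time estimate from n paired draws (x_i, x'_i), using consecutive pairs (i-1, i)
    def k(a, b):  # row-wise kernel evaluations k(a_i, b_i)
        return np.exp(-np.sum((a - b) ** 2, axis=1) / (2.0 * sigma**2))
    h = k(x[:-1], x[1:]) + k(xp[:-1], xp[1:]) - k(x[:-1], xp[1:]) - k(x[1:], xp[:-1])
    return h.mean()

def sinkhorn_distance(x, xp, lam=1.0, n_iter=500):
    # Entropy-regularized transport cost between uniform empirical measures;
    # a larger lam approaches exact optimal transport but converges more slowly
    m, n = len(x), len(xp)
    sq = np.sum(x**2, axis=1)[:, None] + np.sum(xp**2, axis=1)[None, :] - 2.0 * x @ xp.T
    cost = np.sqrt(np.maximum(sq, 0.0))      # Euclidean ground cost
    a, b = np.full(m, 1.0 / m), np.full(n, 1.0 / n)
    K = np.exp(-lam * cost)
    v = np.ones(n)
    for _ in range(n_iter):                  # Sinkhorn fixed-point updates
        u = a / (K @ v)
        v = b / (K.T @ u)
    plan = u[:, None] * K * v[None, :]       # approximate transport plan
    return float(np.sum(plan * cost))
```

For two samples drawn from the same distribution, both MMD estimates hover around zero (the unbiased ones can be slightly negative), while a shifted sample yields a clearly positive value and a larger Sinkhorn cost.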
Evaluating estimates of potential outcomes and causal effects from observational data is notoriously difficult, as ground-truth labels are hard or impossible to come by. For this reason, cross-validation and other sample-splitting schemes frequently used to evaluate supervised learning are not immediately applicable to our setting. Moreover, producing the labels themselves is exactly the task we are attempting to solve. As a result, estimation methods are often evaluated on synthetic or semi-synthetic data, where consistent estimation or computation of the labels is guaranteed. Another alternative is to use real-world data where the treatment-assignment randomization is known, e.g., data from an RCT. In this section, we give a suite of experimental results on synthetic, semi-synthetic and real-world data. Our experiments are designed to separately highlight the impact of architecture choice and of the balancing regularization scheme.
Recall that we refer to our algorithms, minimizing the objectives in (19) or (20), as Counterfactual Regression (CFR). The version of CFR with penalty α = 0 is referred to as Treatment Agnostic Representation Network (TARNet). We specify the function family used in the IPM by a subscript, e.g., CFR_MMD, and point out for which experiments the weighting function is learned and for which it is set to the uniform weighting. All variants of CFR were implemented as feed-forward neural networks with exponential-linear units and architectures as described in Section 6. Ranges for hyperparameters, such as layer sizes and learning rates, are described in Appendix B, and specific values were selected according to the procedure described below. An implementation of CFR with uniform sample weights may be found at https://github.com/clinicalml/cfrnet.

As our primary baseline, we use two variants of Ordinary Least Squares (linear regression). The first (OLS-S) adopts the S-learner paradigm and includes the treatment variable as a feature in the regression. The second (OLS-T) is a T-learner where the outcome in each treatment arm is modeled using a separate linear regression. Our other simple baseline is a k-nearest-neighbor regression (k-NN), which imputes the counterfactual outcome of a unit by the average outcome of its k nearest neighbors with the opposite treatment assignment.

For a more challenging comparison, we use Targeted Maximum Likelihood Estimation (TMLE) (Gruber and van der Laan, 2011), a doubly robust method which uses an ensemble of machine-learning methods. We also compare with a suite of tree-based estimators. First, we use a Random Forest (Rand. For.) (Breiman, 2001) in the S-learner paradigm by including T as a feature. Second, we include tree-based methods specifically designed or adapted for causal effect estimation: Bayesian Additive Regression Trees (BART) (Chipman et al., 2010; Chipman and McCulloch, 2016) and Causal Forests (Caus. For.) (Wager and Athey, 2015; Athey, 2016). Finally, we also compare with our earlier work on Balancing Linear Regression (BLR) and Balancing Neural Network (BNN) (Johansson et al., 2016).
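The three simple baselines (OLS-S, OLS-T and k-NN) can be sketched as follows. This is an illustration under our own assumptions (plain least squares with an intercept via `numpy.linalg.lstsq`, Euclidean distance, and a neighborhood size `k` of our choosing), not the code used in the experiments; the function names are hypothetical. Each function returns estimated CATE values for the given units; in the k-NN variant the observed factual outcome is paired with the imputed counterfactual, which is one reading of the description above.

```python
import numpy as np

def _fit_ols(X, y):
    # Least-squares fit with an intercept column; returns a prediction function
    A = np.column_stack([np.ones(len(X)), X])
    coef, *_ = np.linalg.lstsq(A, y, rcond=None)
    return lambda Z: np.column_stack([np.ones(len(Z)), Z]) @ coef

def ols_s_cate(X, t, y):
    # S-learner: a single regression with the treatment indicator as an extra feature
    f = _fit_ols(np.column_stack([X, t]), y)
    ones, zeros = np.ones(len(X)), np.zeros(len(X))
    return f(np.column_stack([X, ones])) - f(np.column_stack([X, zeros]))

def ols_t_cate(X, t, y):
    # T-learner: a separate regression fitted in each treatment arm
    f1 = _fit_ols(X[t == 1], y[t == 1])
    f0 = _fit_ols(X[t == 0], y[t == 0])
    return f1(X) - f0(X)

def knn_cate(X, t, y, k=5):
    # Impute each unit's counterfactual as the mean outcome of its k nearest
    # neighbors (Euclidean distance) in the opposite treatment group
    tau = np.empty(len(X))
    for i in range(len(X)):
        other = np.flatnonzero(t != t[i])
        d = np.linalg.norm(X[other] - X[i], axis=1)
        y_cf = y[other[np.argsort(d)[:k]]].mean()
        tau[i] = (y[i] - y_cf) if t[i] == 1 else (y_cf - y[i])
    return tau
```

Note that OLS-S, as a linear S-learner, can only produce a constant effect (the coefficient of the treatment indicator), whereas OLS-T allows the effect to vary linearly with x.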
Evaluation criteria & hyperparameter selection
To assess the quality of CATE estimates, knowledge of either the propensity score or the outcome function is required. Where labels are available, our primary criterion for evaluation is the mean squared error in the imputed CATE as defined in (6). When only the propensity score is available, such as in a randomized controlled trial or other experiments, we instead estimate the policy risk as defined below.

A policy π is any (possibly stochastic) function that maps from covariates x to a treatment decision t ∈ {0, 1}; we will only consider deterministic policies. The risk of a policy π for outcomes Y ∈ [0, 1], where a larger value of Y is considered beneficial, is
$$R_{\mathrm{Pol}}(\pi) := 1 - \mathbb{E}_X\Big[\mathbb{E}_{Y(0),Y(1)}\big[Y(\pi(x)) \mid X = x\big]\Big].$$
A good policy is one that, for a given x, chooses the potential outcome with the higher conditional expectation given x. If we know the true propensity scores p*(t_i = 1 | x_i) used in generating the dataset, then the risk of a deterministic policy R_Pol(π) may be estimated using rejection sampling, based on a sample (x_1, t_1, y_1), ..., (x_m, t_m, y_m) and propensity scores p*(t_1 = 1 | x_1), ..., p*(t_m = 1 | x_m), by considering only the propensity-re-weighted effective sample on which the proposed policy agrees with the observed one:
$$\hat{R}_{\mathrm{Pol}}(\pi) := 1 - \frac{\sum_{i=1}^{m} y_i\,\mathbb{1}[\pi(x_i) = t_i] \,/\, p^*(t_i = \pi(x_i) \mid x_i)}{\sum_{i=1}^{m} \mathbb{1}[\pi(x_i) = t_i] \,/\, p^*(t_i = \pi(x_i) \mid x_i)}. \tag{21}$$
A downside of this estimator is that it has very high variance for policies that are very different from the observed policy. Note that in the case where the data was generated by an RCT with equal probability of treatment and control, the propensity scores have a particularly simple form: p(t = 1 | x) = 0.5 for all x.

In our experiments, we evaluate the policy π_f : X → T induced by an estimator f(x, t) of potential outcomes and a threshold λ such that
$$\pi_f(x) := \begin{cases} 1, & \text{if } f(x,1) - f(x,0) > \lambda, \\ 0, & \text{otherwise.} \end{cases} \tag{22}$$
By varying λ from low to high, we obtain a curve that interpolates between liberal and conservative allocation of treatment.

In all experiments, we fit a model on a training set and then evaluate it on a held-out set. We always report results both within-sample and out-of-sample. We wish to emphasize that the within-sample results should not be thought of as training loss in standard ML problems: even within-sample results include the challenging task of inferring unobserved counterfactuals for the training samples.

Hyperparameter selection.
We choose hyperparameters for all estimators in the same way. As the ground-truth potential outcomes are unavailable to us, we use pseudo-labels for the true CATE, imputed using a nearest-neighbor estimator. With j(i) the nearest "counterfactual" neighbor of sample i in Euclidean distance, such that t_{j(i)} ≠ t_i, we define
$$\widehat{\mathrm{MSE}}_{nn}(f) := \frac{1}{n}\sum_{i=1}^{n}\Big((1-2t_i)\big(y_{j(i)} - y_i\big) - \big(f(x_i,1) - f(x_i,0)\big)\Big)^2$$
and use its value on a held-out validation set as a surrogate for the true MSE in τ̂ in hyperparameter selection. This choice may bias selection of hyperparameters towards models close to a nearest-neighbor estimator, but we anticipate this effect to be mild, as MSE_nn is not used as a training objective. For neural network estimators, we perform early stopping based on the training objective evaluated on a held-out validation set in the IHDP study, and based on held-out policy risk in the Jobs study (both described below). Ranges for hyperparameters for CFR are presented in Appendix B.

The Infant Health and Development Program (IHDP) dataset has been frequently used to evaluate machine learning approaches to causal effect estimation in recent years (Hill, 2011). The original data comes from a randomized study of the impact of educational and follow-up interventions on child cognitive development (Brooks-Gunn et al., 1992). Each observation represents a single child in terms of 25 features of their birth and their mother. To introduce confounding, Hill (2011) removed a biased subset of the treatment group (all treated children with nonwhite mothers), leaving 747 subjects in total. This induces not only confounding, but also lack of overlap in variables strongly correlated with race (race itself was removed from the feature set following the biased selection). To enable consistent evaluation, the outcome of the IHDP dataset was synthesized according to several different stochastic models on the observed feature set.
In this way, ignorability is guaranteed, but overlap is violated by design. Depending on the specific sample of the outcome model, i.e., whether variables correlated with race have strong influence or not, the lack of overlap varies in its impact on the results.

In our experiments, we use observations generated using setting "A" in the NPCI package (Dorie, 2016), corresponding to response surface (outcome function) "B" in Hill (2011). This model follows an exponential-linear form for the outcome under treatment and a linear form for the controls, ensuring that their difference, the CATE, is a nonlinear function. Sparsity in the coefficients is introduced through random sampling, with probability 0.6 that a coefficient is exactly equal to 0. The full description of the model may be found in Hill (2011), Section 4.1. The specific realizations (draws) used in our evaluation can be accessed at .

Table 1: Mean squared error, and standard error over 1000 random draws of the outcome model, in estimates of CATE and ATE on IHDP within-sample (left) and out-of-sample (right). Lower is better. † Not applicable.
Within sample: mse_CATE, mse_ATE | Out of sample: mse_CATE, mse_ATE
OLS-S 5 . ± . . ± .
04 5 . ± . . ± . . ± . . ± .
01 2 . ± . . ± . . ± . . ± .
04 5 . ± . . ± . k -NN 2 . ± . . ± .
01 4 . ± . . ± . . ± . . ± . † † BART 2 . ± . . ± .
01 2 . ± . . ± . . ± . . ± .
05 6 . ± . . ± . . ± . . ± .
01 3 . ± . . ± . . ± . . ± .
03 2 . ± . . ± . . ± .
02 0 . ± .
01 0 . ± .
02 0 . ± . . ± .
01 0 . ± .
01 0 . ± .
02 0 . ± . . ± .
02 0 . ± .
01 0 . ± .
02 0 . ± .
01

Results.
The error in estimates of CATE on IHDP can be seen in Table 1. Here, we present only the variants of CFR with uniform weighting, and refer to Figure 5b for a comparison between learned and uniform weights. First, we note that all of the proposed neural network estimators (TARNet and CFR variants) outperform the selected baselines. We attribute this, to a large extent, to multi-layer neural networks being a suitable function class for this dataset. CFR improves marginally over TARNet, suggesting that regularizing distributional invariance is beneficial for prediction of CATE. We note also that, in general, the S-learner estimators (OLS-S and BNN) perform worse than separate or partially separate estimators (OLS-T, TARNet). The biggest differences between in-sample and out-of-sample performance are attained by the k-NN and random forest estimators.

Increasing imbalance.
In Figure 5a, we study the effect of increasing the imbalance between treatment groups through biased subsampling. To do this, we fit a logistic regression propensity score model p̂(T = 1 | X = x) and, for a parameter q ≥ 0, we repeatedly remove the control sample with the largest estimated propensity with probability q, and a random control observation with probability 1 − q, until 400 samples remain. For three values of q, we estimate CATE using CFR with uniform sample weights for different values of the penalty α on the treatment group distance in the learned representation Φ. We see that for small α, as expected, the relative error is comparable to that of TARNet (α = 0), but that it decreases until α ≈ 1. For α > 2, the performance deteriorates as the influence of the input on the representation is constrained too heavily. As we'll see below, this may be partially remedied by sample weighting.

(Plot (a): RMSE(τ̂) relative to α = 0 versus imbalance penalty α, one curve per value of q. Plot (b): CATE error RMSE(τ̂) versus re-weighting regularization λ_w (uniformity), one curve per α ∈ {0.0001, 0.01, 1, 10, 100, 1000}.)
(a) Ratio of the root mean squared error in estimated CATE relative to α = 0, as a function of the imbalance regularization α with uniform sample weights w, for different levels of introduced additional treatment group imbalance q. Uncertainty bands show standard errors over 500 realizations. (b) RMSE for CFR estimates combining representation learning and re-weighting, minimizing (20), for varying weight regularization λ_w and imbalance penalty α. Higher λ_w leads to more uniform weights.
Figure 5: Results for estimating CATE on IHDP with different variants of the CFR model. In (a), we show results for the best performing architecture with uniform sample weights, varying the imbalance regularization α. In (b), we show the results for a smaller architecture and their dependence on the uniformity of learned weights.
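The biased subsampling used to increase imbalance can be sketched as follows. This sketch assumes the logistic-regression propensity estimates p̂(T = 1 | X = x) have already been computed and are passed in as an array, and it reads "until 400 samples remain" as 400 samples in total (treated plus control); the function name `biased_subsample` and its signature are hypothetical.

```python
import numpy as np

def biased_subsample(t, propensity, q, n_remaining=400, rng=None):
    """Return indices of the units kept after biased removal of controls.

    t: binary treatment array; propensity: estimated p(T=1 | X=x) per unit;
    q: probability of removing the most treatment-like remaining control.
    """
    rng = np.random.default_rng(rng)
    keep = np.ones(len(t), dtype=bool)
    while keep.sum() > n_remaining:
        controls = np.flatnonzero(keep & (t == 0))
        if controls.size == 0:          # no controls left to remove
            break
        if rng.random() < q:
            # remove the remaining control that looks most like a treated unit
            drop = controls[np.argmax(propensity[controls])]
        else:
            # remove a remaining control chosen uniformly at random
            drop = rng.choice(controls)
        keep[drop] = False
    return np.flatnonzero(keep)
```

With q = 0 the removal is purely random and leaves the group imbalance roughly unchanged; with q = 1 it deterministically strips the controls with the highest estimated propensity, i.e., those most similar to the treated, which maximizes the induced imbalance.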
Learning the sample weights.
In Figure 5b, we study the quality of CFR estimates when sample weights are learned by minimizing objective (20). In this setting, the chosen model is intentionally restricted to have representations Φ of two layers with 32 and 16 hidden units, respectively, and hypotheses h(Φ) of a single layer with 16 units. This choice was made to emphasize the value of re-weighting under model misspecification. The weighting function was modeled using two layers of 32 units each. We see in Figure 5b that a model using non-uniform sample weights (λ_w small) is less sensitive to excessively large penalties α. This is because the IPM term may now be minimized also by learning the weights, rather than only by constraining the capacity of Φ. In the small-α regime, the non-uniformity of weights has almost no impact, as the incentive to reduce the IPM using the weights is too small. In this experiment, the best results are attained for a combination of a considerably larger value of α and a small penalty on the non-uniformity of weights. In general, we do not observe any adverse effects of having a small value of λ_w. This is likely due to the choice of architecture for the learned weighting function already constraining the weights.

(Plot: out-of-sample policy risk R_Pol(π) versus treatment inclusion rate, from 0.0 to 1.0, for BART, Causal Forest, CFR_MMD, TARNet and a random policy.)
Figure 6: Policy risk as a function of treatment inclusion rate on Jobs. Lower is better. Subjects are included in treatment in order of their estimated treatment effect given by the various methods. CFR_Wass is similar to CFR_MMD and is omitted to avoid clutter.
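The evaluation quantities used in this section (the policy-risk estimate (21), the threshold policy (22), the inclusion-rate policy plotted in Figure 6, and the nearest-neighbor surrogate MSE_nn used for hyperparameter selection) can be sketched as follows. This is a minimal NumPy sketch under our own assumptions: the true propensities p*(t = 1 | x) are given as an array, the estimated potential outcomes f(x, 0) and f(x, 1) are given as arrays `f0` and `f1`, and (21) is read as a self-normalized, propensity-weighted average over the units where policy and observed treatment agree. All function names are hypothetical.

```python
import numpy as np

def threshold_policy(f0, f1, lam=0.0):
    # Policy (22): treat when the estimated effect f(x,1) - f(x,0) exceeds lam
    return (f1 - f0 > lam).astype(int)

def inclusion_rate_policy(f0, f1, rate):
    # Treat the fraction `rate` of units with the largest estimated CATE
    n_treat = int(round(rate * len(f0)))
    pi = np.zeros(len(f0), dtype=int)
    pi[np.argsort(-(f1 - f0))[:n_treat]] = 1
    return pi

def policy_risk(pi, t, y, p_treat):
    # Estimate of R_Pol(pi) for outcomes in [0, 1]; only units with pi(x_i) == t_i count.
    # Undefined (division by zero) if the policy never agrees with the observed treatment.
    agree = (pi == t)
    p_agree = np.where(t == 1, p_treat, 1.0 - p_treat)  # p*(t_i | x_i) = p*(t_i = pi(x_i) | x_i) where agree
    w = agree / p_agree
    return 1.0 - np.sum(w * y) / np.sum(w)

def mse_nn(X, t, y, f0, f1):
    # Nearest-neighbor surrogate for the CATE MSE, used for hyperparameter selection
    err = np.empty(len(X))
    for i in range(len(X)):
        other = np.flatnonzero(t != t[i])
        j = other[np.argmin(np.linalg.norm(X[other] - X[i], axis=1))]
        tau_nn = (1 - 2 * t[i]) * (y[j] - y[i])   # imputed CATE from the counterfactual neighbor
        err[i] = (tau_nn - (f1[i] - f0[i])) ** 2
    return err.mean()
```

In an RCT with p*(t = 1 | x) = 0.5, every agreeing unit receives the same weight, so the estimate reduces to one minus the mean outcome among units whose observed treatment matches the policy.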
LaLonde (1986) carried out a widely known experimental study of the effect of job training on future income and employment status, based on the National Supported Work (NSW) program. Later, Smith and Todd (2005) combined the LaLonde study with observational data to form a larger dataset, which has been used frequently as a benchmark in the causal inference community. The presence of the randomized subgroup allows for straightforward estimation of average treatment effects and policy value.

The original study by Smith and Todd (2005) includes 8 covariates, such as age and education, as well as previous earnings. The treatment indicates participation in the NSW job training program. By construction, all treated subjects belong to the LaLonde experimental cohort; the observational cohort includes only controls. Additionally, the nature of the observational cohort is such that overlap is minimal at best: the experimental cohort may be separated from the observational one using a linear classifier with 96% accuracy. This means that global estimators of the control outcome applied to the treated, such as linear models or difference-in-means estimators of causal effects, are likely to suffer severe bias if not re-weighted.

Based on the original outcome measuring yearly earnings at the end of the study, we construct a binary classification task called Jobs, in which the goal is to predict unemployment. Following Dehejia and Wahba (2002), we use an expanded feature set that introduces interaction terms between some of the covariates. The task is based on the cohort used by Smith and Todd (2005), which combines the LaLonde experimental sample (297 treated, 425 control) and the "PSID" comparison group (2490 control). There were 482 (15%) subjects unemployed by the end of the study. In our experiments, we average results over 10 train/validation/test splits of the full cohort with ratios 56/24/20.
We train CFR methods with uniform weighting, according to (19), selecting the imbalance parameter α according to held-out policy risk.

Table 2: Policy risk and mean squared error in estimates of ATT on Jobs within-sample (left) and out-of-sample (right). Lower is better. † Not applicable.
Within sample: R̂_Pol, mse_ATT | Out of sample: R̂_Pol, mse_ATT
LR-S 0 . ± .
00 0 . ± .
00 0 . ± .
02 0 . ± . . ± .
00 0 . ± .
01 0 . ± .
01 0 . ± . . ± .
01 0 . ± .
01 0 . ± .
02 0 . ± . k -NN 0 . ± .
00 0 . ± .
01 0 . ± .
02 0 . ± . . ± .
00 0 . ± . † † BART 0 . ± .
00 0 . ± .
00 0 . ± .
02 0 . ± . . ± .
01 0 . ± .
01 0 . ± .
02 0 . ± . . ± .
00 0 . ± .
01 0 . ± .
02 0 . ± . . ± .
01 0 . ± .
01 0 . ± .
02 0 . ± . . ± .
01 0 . ± .
02 0 . ± .
01 0 . ± . . ± .
00 0 . ± .
01 0 . ± .
01 0 . ± . . ± .
01 0 . ± .
01 0 . ± .
01 0 . ± .

Results.
In Table 2, we give the policy risk R̂_Pol evaluated over the randomized component of Jobs, as defined in (21), and the mean squared error in the estimated average treatment effect on the treated (mse_ATT). The policy we consider in the table assigns treatment to the subjects for which the CATE is estimated to be positive (λ = 0). We can see from the results that, despite the significant lack of overlap, the difference between linear and non-linear estimators is much less pronounced than for IHDP. This is likely partly due to the features used in the Jobs dataset, which have been handcrafted to predict the outcome of interest well. In contrast, the IHDP outcome is non-linear by construction.

We also see that straightforward logistic regression does remarkably well in estimating the ATT. However, being a linear model in which the treatment enters through a single coefficient, logistic regression estimates an effect of the same sign for every subject and can therefore only ascribe a uniform policy, in this case "treat everyone". The more nuanced policies offered by non-linear methods achieve lower policy risk, though this difference is less pronounced in the out-of-sample case, indicating that part of the difference may be due to overfitting. The nearest-neighbor estimator k-NN appears to perform incredibly well within-sample, but generalizes poorly to the held-out set. Additionally, its estimate of the ATT is the worst among the baselines.

In Figure 6, we plot policy risk as a function of the treatment threshold λ, as defined in (22). This is described in the figure as varying the fraction of subjects treated in a policy that treats only the subjects with the largest estimated CATE. Overall, imbalance regularization of the CFR models offers less advantage than on IHDP. This may be due to the smaller covariate set of Jobs containing fewer redundant features than that of IHDP. Recall that the IHDP outcome coefficients have 60% sparsity in the feature set, by design. In contrast, the Jobs covariate set has been hand-picked to account for confounding.
This means that one of the benefits of imbalance regularization of representations, namely excluding variables that are only predictive of treatment, is likely to have a smaller effect in comparison.

Discussion
We have presented generalization bounds for estimation of potential outcomes and causal effects from observational data. These bounds were used to derive learning objectives for estimation algorithms that proved successful in empirical evaluation. The bounds do not rely on the so-called treatment group overlap (or positivity) assumption, common to most studies of causal effects from observational data. This assumption states that any observed subject has some probability of being prescribed either treatment option. Removing this assumption means that we cannot identify the causal effect non-parametrically but, as we show in this work, we can still bound the expected error (risk) of any hypothesis in a given class.

When can we expect overlap not to hold, yet identification to be possible? One example is when many of the covariates in the conditioning set X have a strong effect on treatment but only a weak or non-existent effect on the outcome. For example, if some of the covariates in X are actually instrumental variables, conditioning on them might substantially increase variance and reduce overlap, with no gain in estimating the CATE function (Brookhart et al., 2010; Shortreed and Ertefaie, 2017). We conjecture that this might often be the case in high-dimensional settings: in aggregate, there might not be nominal overlap with respect to the measured covariates, while at the same time many of them are actually only weak confounders, or not confounders at all; see also D'Amour et al. (2017).

Our results offer several new perspectives on causal effect estimation. In particular, they bring together two hitherto separate approaches to dealing with treatment group shift, representation learning and sample re-weighting, and give insight into when either approach is likely to be more successful and when they should be used together.
It is well known that under the overlap and ignorability assumptions, ordinary risk minimization leads to consistent estimation of causal effects in the limit of infinite samples (Pearl, 2009; Ben-David and Urner, 2012; Alaa and Schaar, 2018), but the hardness of the problem is less well understood in the finite-sample case, or when overlap is violated. Our results provide some insight in this setting.

It is customary in machine learning to evaluate methodological progress based on performance on a small number of benchmarks, such as MNIST (LeCun et al., 1998) or ImageNet (Deng et al., 2009). Similarly, IHDP has become a de facto benchmark for causal effect estimation (Hill, 2011; Shalit et al., 2017; Alaa and Schaar, 2018; Shi et al., 2019). However, IHDP is smaller than most machine learning benchmarks and even more susceptible to "test set overfitting". Even disregarding the size discrepancy, it may be argued that benchmarks for causal effect estimation suffer even more from going stale, as the strong assumptions we make (or synthesize) need not hold in the tasks we wish to apply our models to. Moreover, the relatively simple form of the outcome model, the small dimensionality, and the structured, fully observed nature of the data make IHDP a much easier challenge than what we may face in, for example, the analysis of electronic healthcare records. Towards understanding the behavior of different estimators, a dataset like IHDP provides but a single sample of the problems we may encounter in applications. We believe it is of utmost importance for the field as a whole to produce a larger set of benchmarks that reflect the diversity of real-world observational studies, and that the recent ACIC challenge (Shimoni et al., 2018) is a good step in this direction.

Acknowledgments

We thank Rajesh Ranganath, Alexander D'Amour, Zach Lipton, Jennifer Hill, Rahul Krishnan, Michael Oberst, Hunter Lang and Christina X Ji for insightful feedback and discussions.
The preparation of this manuscript was supported in part by the Office of Naval Research Award No. N00014-17-1-2791 and the MIT-IBM Watson AI Lab.
References
Abrevaya, J., Hsu, Y.-C., and Lieli, R. P. (2015). Estimating conditional average treatment effects. Journal of Business & Economic Statistics, 33(4):485–505.
Alaa, A. and Schaar, M. (2018). Limits of estimating heterogeneous treatment effects: Guidelines for practical algorithm design. In International Conference on Machine Learning, pages 129–138.
Amos, B. and Kolter, J. Z. (2017). Optnet: Differentiable optimization as a layer in neural networks. In Proceedings of the 34th International Conference on Machine Learning-Volume 70, pages 136–145. JMLR.org.
Arjovsky, M., Chintala, S., and Bottou, L. (2017). Wasserstein gan. arXiv preprint arXiv:1701.07875.
Athey, S. (2016). causalTree. https://github.com/susanathey/causalTree.
Athey, S. and Imbens, G. (2016). Recursive partitioning for heterogeneous causal effects. Proceedings of the National Academy of Sciences, 113(27):7353–7360.
Aude, G., Cuturi, M., Peyré, G., and Bach, F. (2016). Stochastic optimization for large-scale optimal transport. arXiv preprint arXiv:1605.08527.
Austin, P. C. (2011). An introduction to propensity score methods for reducing the effects of confounding in observational studies. Multivariate behavioral research, 46(3):399–424.
Belloni, A., Chernozhukov, V., and Hansen, C. (2014). Inference on treatment effects after selection among high-dimensional controls. The Review of Economic Studies, 81(2):608–650.
Ben-David, S., Blitzer, J., Crammer, K., and Pereira, F. (2007). Analysis of representations for domain adaptation. In Advances in neural information processing systems, pages 137–144.
Ben-David, S. and Urner, R. (2012). On the hardness of domain adaptation and the utility of unlabeled target samples. In International Conference on Algorithmic Learning Theory, pages 139–153. Springer.
Bengio, Y., Courville, A., and Vincent, P. (2013). Representation learning: A review and new perspectives. IEEE transactions on pattern analysis and machine intelligence, 35(8):1798–1828.
Bertsimas, D., Kallus, N., Weinstein, A. M., and Zhuo, Y. D. (2017). Personalized diabetes management using electronic medical records. Diabetes care, 40(2):210–217.
Breiman, L. (2001). Random forests. Machine learning, 45(1):5–32.
Brookhart, M. A., Stürmer, T., Glynn, R. J., Rassen, J., and Schneeweiss, S. (2010). Confounding control in healthcare database research: challenges and potential approaches. Medical care, 48(6 0):S114.
Brooks-Gunn, J., Liaw, F.-r., and Klebanov, P. K. (1992). Effects of early intervention on cognitive function of low birth weight preterm infants. The Journal of pediatrics, 120(3):350–359.
Chernozhukov, V., Chetverikov, D., Demirer, M., Duflo, E., Hansen, C., Newey, W., Robins, J., et al. (2017). Double/debiased machine learning for treatment and causal parameters. Technical report.
Chipman, H. and McCulloch, R. (2016). BayesTree: Bayesian Additive Regression Trees. https://cran.r-project.org/web/packages/BayesTree.
Chipman, H. A., George, E. I., McCulloch, R. E., et al. (2010). Bart: Bayesian additive regression trees. The Annals of Applied Statistics, 4(1):266–298.
Cortes, C., Mansour, Y., and Mohri, M. (2010). Learning bounds for importance weighting. In Advances in neural information processing systems, pages 442–450.
Cuturi, M. (2013). Sinkhorn distances: Lightspeed computation of optimal transport. In Advances in neural information processing systems, pages 2292–2300.
Cuturi, M. and Doucet, A. (2014). Fast computation of wasserstein barycenters. In International Conference on Machine Learning, pages 685–693.
D'Amour, A., Ding, P., Feller, A., Lei, L., and Sekhon, J. (2017). Overlap in observational studies with high-dimensional covariates. arXiv preprint arXiv:1711.02582.
Dehejia, R. H. and Wahba, S. (2002). Propensity score-matching methods for nonexperimental causal studies. Review of Economics and statistics, 84(1):151–161.
Deng, J., Dong, W., Socher, R., Li, L.-J., Li, K., and Fei-Fei, L. (2009). Imagenet: A large-scale hierarchical image database. Pages 248–255. IEEE.
Ding, P., VanderWeele, T., and Robins, J. (2017). Instrumental variables as bias amplifiers with general outcome and confounding. Biometrika, 104(2):291–302.
Dorie, V. (2016). NPCI: Non-parametrics for Causal Inference. https://github.com/vdorie/npci.
Eberhardt, F. (2008). Causal discovery as a game. In Proceedings of the 2008th International Conference on Causality: Objectives and Assessment-Volume 6, pages 87–96. JMLR.org.
Freedman, D. A. and Berk, R. A. (2008). Weighting regressions by propensity scores. Evaluation Review, 32(4):392–409.
Ganin, Y., Ustinova, E., Ajakan, H., Germain, P., Larochelle, H., Laviolette, F., Marchand, M., and Lempitsky, V. (2016). Domain-adversarial training of neural networks. The Journal of Machine Learning Research, 17(1):2096–2030.
Geiger, P., Zhang, K., Schoelkopf, B., Gong, M., and Janzing, D. (2015). Causal inference by identification of vector autoregressive processes with hidden components. In International Conference on Machine Learning, pages 1917–1925.
Goodfellow, I., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., Courville, A., and Bengio, Y. (2014). Generative adversarial nets. In Advances in neural information processing systems, pages 2672–2680.
Green, D. P. and Kern, H. L. (2010). Modeling heterogeneous treatment effects in large-scale experiments using bayesian additive regression trees. In The annual summer meeting of the society of political methodology.
Gretton, A., Borgwardt, K. M., Rasch, M. J., Schölkopf, B., and Smola, A. (2012). A kernel two-sample test. Journal of Machine Learning Research, 13(Mar):723–773.
Gretton, A., Smola, A. J., Huang, J., Schmittfull, M., Borgwardt, K. M., and Schölkopf, B. (2009). Covariate shift by kernel mean matching.
Gruber, S. and van der Laan, M. J. (2011). tmle: An r package for targeted maximum likelihood estimation.
Hansen, B. B. (2008). The prognostic analogue of the propensity score. Biometrika, 95(2):481–488.
Hill, J. L. (2011). Bayesian nonparametric modeling for causal inference. Journal of Computational and Graphical Statistics, 20(1).
Hoyer, P. O., Janzing, D., Mooij, J. M., Peters, J., and Schölkopf, B. (2009). Nonlinear causal discovery with additive noise models. In Advances in neural information processing systems, pages 689–696.
Hyttinen, A., Eberhardt, F., and Järvisalo, M. (2014). Constraint-based causal discovery: Conflict resolution with answer set programming. In UAI, pages 340–349.
Imbens, G. W. and Rubin, D. B. (2015). Causal inference in statistics, social, and biomedical sciences. Cambridge University Press.
Johansson, F., Shalit, U., and Sontag, D. (2016). Learning representations for counterfactual inference. In International Conference on Machine Learning, pages 3020–3029.
Johansson, F. D., Kallus, N., Shalit, U., and Sontag, D. (2018). Learning weighted representations for generalization across designs. arXiv preprint arXiv:1802.08598.
Johansson, F. D., Sontag, D., and Ranganath, R. (2019). Support and invertibility in domain-invariant representations. In Chaudhuri, K. and Sugiyama, M., editors, Proceedings of Machine Learning Research, volume 89 of Proceedings of Machine Learning Research, pages 527–536. PMLR.
Kallus, N. (2016). Generalized optimal matching methods for causal inference. arXiv preprint arXiv:1612.08321.
Kallus, N. (2017). A framework for optimal matching for causal inference. In Artificial Intelligence and Statistics, pages 372–381.
Kallus, N., Mao, X., and Zhou, A. (2018a). Interval estimation of individual-level causal effects under unobserved confounding. arXiv preprint arXiv:1810.02894.
Kallus, N., Puli, A. M., and Shalit, U. (2018b). Removing hidden confounding by experimental grounding. In Advances in neural information processing systems.
Kallus, N. and Zhou, A. (2018). Confounding-robust policy improvement. In Advances in neural information processing systems.
Künzel, S. R., Sekhon, J. S., Bickel, P. J., and Yu, B. (2017). Meta-learners for estimating heterogeneous treatment effects using machine learning. arXiv preprint arXiv:1706.03461.
LaLonde, R. J. (1986). Evaluating the econometric evaluations of training programs with experimental data. The American economic review, pages 604–620.
LeCun, Y., Bengio, Y., and Hinton, G. (2015). Deep learning. Nature, 521(7553):436.
LeCun, Y., Bottou, L., Bengio, Y., Haffner, P., et al. (1998). Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324.
Lefortier, D., Swaminathan, A., Gu, X., Joachims, T., and de Rijke, M. (2016). Large-scale validation of counterfactual learning methods: A test-bed. arXiv preprint arXiv:1612.00367.
Li, F., Morgan, K. L., and Zaslavsky, A. M. (2018). Balancing covariates via propensity score weighting. Journal of the American Statistical Association, 113(521):390–400.
Long, M., Cao, Y., Wang, J., and Jordan, M. (2015). Learning transferable features with deep adaptation networks. In International Conference on Machine Learning, pages 97–105.
Louizos, C., Shalit, U., Mooij, J., Sontag, D., Zemel, R., and Welling, M. (2017). Causal effect inference with deep latent-variable models. arXiv preprint arXiv:1705.08821.
Mansour, Y., Mohri, M., and Rostamizadeh, A. (2009). Domain adaptation: Learning bounds and algorithms. arXiv preprint arXiv:0902.3430.
Müller, A. (1997). Integral probability metrics and their generating classes of functions. Advances in Applied Probability, 29(2):429–443.
Nie, X. and Wager, S. (2017). Learning objectives for treatment effect estimation. arXiv preprint arXiv:1712.04912.
Nowozin, S., Cseke, B., and Tomioka, R. (2016). f-gan: Training generative neural samplers using variational divergence minimization. In Advances in neural information processing systems, pages 271–279.
Pearl, J. (2009). Causality. Cambridge university press.
Pearl, J. (2017). Detecting latent heterogeneity. Sociological Methods & Research, 46(3):370–389.
Pollard, D. (2012). Convergence of stochastic processes. Springer Science & Business Media.
Robins, J. M., Hernan, M. A., and Brumback, B. (2000). Marginal structural models and causal inference in epidemiology.
Rosenbaum, P. R. (2002). Overt bias in observational studies. In Observational studies, pages 71–104. Springer.
Rosenbaum, P. R. and Rubin, D. B. (1983). The central role of the propensity score in observational studies for causal effects. Biometrika, 70(1):41–55.
Rubin, D. B. (2005). Causal inference using potential outcomes: Design, modeling, decisions. Journal of the American Statistical Association, 100(469):322–331.
Schneeweiss, S., Rassen, J. A., Glynn, R. J., Avorn, J., Mogun, H., and Brookhart, M. A. (2009). High-dimensional propensity score adjustment in studies of treatment effects using health care claims data.
Epidemiology (Cambridge, Mass.) , 20(4):512.Shalit, U., Johansson, F., and Sontag, D. (2017). Estimating individual treatment effect:generalization bounds and algorithms. In
International Conference on Machine Learning .Shi, C., Blei, D. M., and Veitch, V. (2019). Adapting neural networks for the estimation oftreatment effects. arXiv preprint arXiv:1906.02120 .Shimodaira, H. (2000). Improving predictive inference under covariate shift by weightingthe log-likelihood function.
Journal of statistical planning and inference , 90(2):227–244.Shimoni, Y., Yanover, C., Karavani, E., and Goldschmnidt, Y. (2018). Benchmark-ing framework for performance-evaluation of causal inference analysis. arXiv preprintarXiv:1802.05046 .Shortreed, S. M. and Ertefaie, A. (2017). Outcome-adaptive lasso: Variable selection forcausal inference.
Biometrics , 73(4):1111–1122.Silva, R., Scheine, R., Glymour, C., and Spirtes, P. (2006). Learning the structure of linearlatent variable models.
Journal of Machine Learning Research , 7(Feb):191–246.Smith, J. A. and Todd, P. E. (2005). Does matching overcome LaLonde’s critique of non-experimental estimators?
Journal of econometrics , 125(1):305–353.36pirtes, P. and Glymour, C. (1991). An algorithm for fast recovery of sparse causal graphs.
Social science computer review , 9(1):62–72.Sriperumbudur, B. K., Fukumizu, K., Gretton, A., Sch¨olkopf, B., and Lanckriet, G. R.(2009). On integral probability metrics, \ phi-divergences and binary classification. arXivpreprint arXiv:0901.2698 .Swaminathan, A. and Joachims, T. (2015). Counterfactual risk minimization: Learningfrom logged bandit feedback. In International Conference on Machine Learning , pages814–823.Tian, J. and Pearl, J. (2001). Causal discovery from changes. In
Proceedings of the Seven-teenth conference on Uncertainty in artificial intelligence , pages 512–521. Morgan Kauf-mann Publishers Inc.Vapnik, V. (1998).
Statistical Learning Theory . Wiley, New York.Vapnik, V. (2013).
The nature of statistical learning theory . Springer science & businessmedia.Villani, C. (2008).
Optimal transport: old and new , volume 338. Springer Science & BusinessMedia.Wager, S. and Athey, S. (2015). Estimation and inference of heterogeneous treatmenteffects using random forests. arXiv preprint arXiv:1510.04342. https: // github. com/susanathey/ causalTree .Wager, S. and Athey, S. (2017). Estimation and inference of heterogeneous treatment effectsusing random forests.
Journal of the American Statistical Association , (just-accepted).Zhang, K., Sch¨olkopf, B., Muandet, K., and Wang, Z. (2013). Domain adaptation undertarget and conditional shift. In
International Conference on Machine Learning , pages819–827. 37
A Proof of Theorem 3
We prove Theorem 3 in a generalized form. In particular, we consider the risk in predicting the outcome $Y$ in expectation over a treatment policy $p_\pi(T \mid X)$ based on observations from a policy $p_\mu(T \mid X)$. The risk in predicting a single potential outcome $t$ follows as the special case $\pi(T = t \mid X) = 1$. With this in mind, let $R_\pi(f) = \mathbb{E}_X\big[\mathbb{E}_{T \mid X \sim p_\pi}[\ell_f(X, T)]\big]$, where $\ell_f(x, t) = \mathbb{E}[L(f(x, t), Y(t)) \mid X = x, T = t]$. As previously, we consider hypotheses $f(x, t) = h(\Phi(x), t)$ for functions $h \in \mathcal{F}$ and embeddings $\Phi \in \mathcal{E}$.

Theorem 3 (Restated).
Suppose $\mathcal{H}$ is a reproducing kernel Hilbert space given by a bounded kernel. Suppose weak overlap holds in that $\mathbb{E}\big[(p_\pi(x, t)/p_\mu(x, t))^2\big] < \infty$. Assume that $n$ labeled samples $\{(x_i, t_i, y_i)\}_{i=1}^{n} \sim p_\mu$ and $m$ unlabeled samples $\{(x_i, t_i)\}_{i=n+1}^{m+n} \sim p_\pi$ are available. Then,
$$\min_{h, \Phi, w} R_\pi(h, \Phi, w; \beta) \le \min_{f \in \mathcal{F}} R_\pi(f) + O\big(1/\sqrt{n} + 1/\sqrt{m}\big).$$
Proof.
Let $f^* = h^* \circ \Phi^* \in \arg\min_{f \in \mathcal{F}} R_\pi(f)$ and let $w^*(x, t) = p_{\pi,\Phi}(\Phi^*(x), t)/p_{\mu,\Phi}(\Phi^*(x), t)$. Since $\min_{h, \Phi, w} R_\pi(h, \Phi, w; \beta) \le R_\pi(h^*, \Phi^*, w^*; \beta)$, it suffices to show that
$$R_\pi(h^*, \Phi^*, w^*; \beta) = R_\pi(f^*) + O\big(1/\sqrt{n} + 1/\sqrt{m}\big).$$
We will work term by term:
$$R_\pi(h^*, \Phi^*, w^*; \beta) = \underbrace{\frac{1}{n}\sum_{i=1}^{n} w_i^*\, \ell_{h^*}(\Phi^*(x_i), t_i)}_{A} + \underbrace{\lambda_h \frac{\mathcal{R}(h^*)}{\sqrt{n}}}_{B} + \underbrace{\alpha\, \mathrm{IPM}_{\mathcal{G}}\big(\hat{q}_{\Phi^*}, \hat{p}^{w^*}_{\Phi^*}\big)}_{C} + \underbrace{\lambda_w \frac{\|w^*\|_2}{n}}_{D}.$$
For term $D$, letting $w_i^* = w^*(x_i, t_i)$, we have by weak overlap that
$$D^2 = \frac{\lambda_w^2}{n} \times \frac{1}{n}\sum_{i=1}^{n} (w_i^*)^2 = O_p(1/n),$$
so that $D = O_p(1/\sqrt{n})$. For term $A$, under ignorability, each term in the sum has expectation equal to $R_\pi(f^*)$, so by weak overlap and bounded second moments of the loss, $A = R_\pi(f^*) + O_p(1/\sqrt{n})$. For term $B$, since $h^*$ is fixed, we have deterministically that $B = O(1/\sqrt{n})$.

Finally, we address term $C$, which when expanded can be written as
$$C = \alpha \sup_{\|h\|_{\mathcal{H}} \le 1} \Big( \frac{1}{m}\sum_{i=1}^{m} h(\Phi^*(x_i'), t_i') - \frac{1}{n}\sum_{i=1}^{n} w_i^*\, h(\Phi^*(x_i), t_i) \Big),$$
where $(x_i', t_i')$, $i = 1, \dots, m$, are the unlabeled samples drawn from $p_\pi$. Let $(x_i'', t_i'')$ for $i = 1, \dots, m$ and $(x_i''', t_i''')$ for $i = 1, \dots, n$ be new iid replicates of $(x', t')$, i.e., new ghost samples drawn from the target design.
Since both ghost samples are drawn from the target design, $\frac{1}{m}\sum_{i=1}^{m} \mathbb{E}[h(\Phi^*(x_i''), t_i'')] = \frac{1}{n}\sum_{i=1}^{n} \mathbb{E}[h(\Phi^*(x_i'''), t_i''')]$ for every $h$, so these expectations can be inserted without changing $C$. Hence, by Jensen's inequality and $(a - b)^2 \le 2a^2 + 2b^2$,
$$\begin{aligned}
\frac{\mathbb{E}[C^2]}{\alpha^2} &= \mathbb{E}\Big[\sup_{\|h\|_{\mathcal{H}} \le 1} \Big( \frac{1}{m}\sum_{i=1}^{m} h(\Phi^*(x_i'), t_i') - \frac{1}{n}\sum_{i=1}^{n} w_i^*\, h(\Phi^*(x_i), t_i) \Big)^2\Big] \\
&= \mathbb{E}\Big[\sup_{\|h\|_{\mathcal{H}} \le 1} \Big( \frac{1}{m}\sum_{i=1}^{m} \big(h(\Phi^*(x_i'), t_i') - \mathbb{E}[h(\Phi^*(x_i''), t_i'')]\big) - \frac{1}{n}\sum_{i=1}^{n} \big(w_i^*\, h(\Phi^*(x_i), t_i) - \mathbb{E}[h(\Phi^*(x_i'''), t_i''')]\big) \Big)^2\Big] \\
&\le \mathbb{E}\Big[\sup_{\|h\|_{\mathcal{H}} \le 1} \Big( \frac{1}{m}\sum_{i=1}^{m} \big(h(\Phi^*(x_i'), t_i') - h(\Phi^*(x_i''), t_i'')\big) - \frac{1}{n}\sum_{i=1}^{n} \big(w_i^*\, h(\Phi^*(x_i), t_i) - h(\Phi^*(x_i'''), t_i''')\big) \Big)^2\Big] \\
&\le 2\,\mathbb{E}\Big[\sup_{\|h\|_{\mathcal{H}} \le 1} \Big( \frac{1}{m}\sum_{i=1}^{m} \big(h(\Phi^*(x_i'), t_i') - h(\Phi^*(x_i''), t_i'')\big) \Big)^2\Big] + 2\,\mathbb{E}\Big[\sup_{\|h\|_{\mathcal{H}} \le 1} \Big( \frac{1}{n}\sum_{i=1}^{n} \big(w_i^*\, h(\Phi^*(x_i), t_i) - h(\Phi^*(x_i'''), t_i''')\big) \Big)^2\Big].
\end{aligned}$$
Let $\xi_i(h) = h(\Phi^*(x_i'), t_i') - h(\Phi^*(x_i''), t_i'')$ and let $\zeta_i(h) = w_i^*\, h(\Phi^*(x_i), t_i) - h(\Phi^*(x_i'''), t_i''')$. Note that for every $h$, $\mathbb{E}[\zeta_i(h)] = \mathbb{E}[\xi_i(h)] = 0$; for $\zeta_i$ this holds because $h(\Phi^*(x), t)$ depends on $x$ only through $\Phi^*(x)$ and $w^*$ is the corresponding density ratio. By the reproducing property, $\xi_i$ and $\zeta_i$ can be identified with elements of $\mathcal{H}$. Moreover, letting $M$ bound the kernel $K$,
$$\mathbb{E}[\|\xi_i\|^2] \le 4\,\mathbb{E}\big[K\big((\Phi^*(x_i'), t_i'), (\Phi^*(x_i'), t_i')\big)\big] \le 4M.$$
Similarly, $\mathbb{E}[\|\zeta_i\|^2] \le 2\,\mathbb{E}[(w_i^*)^2]\, M + 2M \le M' < \infty$ because of weak overlap. Let $\zeta_i'$ for $i = 1, \dots, n$ be iid replicates of $\zeta_i$ (a ghost sample) and let $\epsilon_i$ be iid Rademacher random variables. Because $\mathcal{H}$ is a Hilbert space, $\sup_{\|h\| \le 1} (A(h))^2 = \|A\|^2 = \langle A, A \rangle$ for any $A \in \mathcal{H}$, where $A(h) = \langle A, h \rangle$. Therefore, by Jensen's inequality,
$$\begin{aligned}
\mathbb{E}\Big[\sup_{\|h\|_{\mathcal{H}} \le 1} \Big( \frac{1}{n}\sum_{i=1}^{n} \big(w_i^*\, h(\Phi^*(x_i), t_i) - h(\Phi^*(x_i'''), t_i''')\big) \Big)^2\Big]
&= \mathbb{E}\Big[\sup_{\|h\|_{\mathcal{H}} \le 1} \Big( \frac{1}{n}\sum_{i=1}^{n} \zeta_i(h) \Big)^2\Big] \\
&= \mathbb{E}\Big[\sup_{\|h\|_{\mathcal{H}} \le 1} \Big( \frac{1}{n}\sum_{i=1}^{n} \big(\zeta_i(h) - \mathbb{E}[\zeta_i'(h)]\big) \Big)^2\Big] \\
&\le \mathbb{E}\Big[\sup_{\|h\|_{\mathcal{H}} \le 1} \Big( \frac{1}{n}\sum_{i=1}^{n} \big(\zeta_i(h) - \zeta_i'(h)\big) \Big)^2\Big] \\
&= \mathbb{E}\Big[\sup_{\|h\|_{\mathcal{H}} \le 1} \Big( \frac{1}{n}\sum_{i=1}^{n} \epsilon_i \big(\zeta_i(h) - \zeta_i'(h)\big) \Big)^2\Big] \\
&\le \frac{4}{n^2}\, \mathbb{E}\Big[\sup_{\|h\|_{\mathcal{H}} \le 1} \Big( \sum_{i=1}^{n} \epsilon_i \zeta_i(h) \Big)^2\Big] \\
&= \frac{4}{n^2}\, \mathbb{E}\Big[\Big\| \sum_{i=1}^{n} \epsilon_i \zeta_i \Big\|^2\Big] \\
&= \frac{4}{n^2}\, \mathbb{E}\Big[\sum_{i,j=1}^{n} \epsilon_i \epsilon_j \langle \zeta_i, \zeta_j \rangle\Big] \\
&= \frac{4}{n^2}\, \mathbb{E}\Big[\sum_{i=1}^{n} \|\zeta_i\|^2\Big] = \frac{4}{n^2} \sum_{i=1}^{n} \mathbb{E}[\|\zeta_i\|^2] \le \frac{4M'}{n}.
\end{aligned}$$
An analogous argument applied to the $\xi_i$'s bounds the first term by a quantity of order $1/m$. Hence $\mathbb{E}[C^2] = O(1/n + 1/m)$, and $C = O_p(1/\sqrt{n} + 1/\sqrt{m})$ by Markov's inequality.

B Experiment details
See Table 3 for a description of hyperparameters and search ranges for TARNet, CFR Wass and CFR MMD.

Table 3: Hyperparameters and ranges.
Parameter | Range
Imbalance parameter, α | { k/ } k = −
Num. of representation layers | { , , }
Num. of hypothesis layers | { , , }
Dim. of representation layers | { , , , }
Dim. of hypothesis layers | { , , , }
Batch size | { , , , }
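CFR MMD, as its name indicates, uses a maximum mean discrepancy as the imbalance penalty. When the IPM function class is the unit ball of an RKHS $\mathcal{H}$ with kernel $K$, the supremum has a closed form, which is the same identity used for term $C$ in the proof of Theorem 3: $\sup_{\|h\|_{\mathcal{H}} \le 1} \big( \frac{1}{m}\sum_j h(a_j) - \frac{1}{n}\sum_i w_i h(b_i) \big) = \big\| \frac{1}{m}\sum_j K(a_j, \cdot) - \frac{1}{n}\sum_i w_i K(b_i, \cdot) \big\|_{\mathcal{H}}$. The sketch below evaluates this weighted MMD between target-design points $a_j$ and re-weighted source points $b_i$. It assumes a Gaussian RBF kernel with a hypothetical bandwidth `sigma`; the kernel and bandwidth used in the experiments are not given in this section, and the function names are illustrative only.

```python
import math
from typing import Sequence

Point = Sequence[float]


def rbf_kernel(a: Point, b: Point, sigma: float = 1.0) -> float:
    """Gaussian RBF kernel exp(-||a - b||^2 / (2 sigma^2)); bounded by 1 (assumed kernel)."""
    sq_dist = sum((x - y) ** 2 for x, y in zip(a, b))
    return math.exp(-sq_dist / (2.0 * sigma ** 2))


def weighted_mmd(target: Sequence[Point], source: Sequence[Point],
                 weights: Sequence[float], sigma: float = 1.0) -> float:
    """RKHS norm of (1/m) sum_j K(a_j, .) - (1/n) sum_i w_i K(b_i, .).

    By the reproducing property this equals the supremum over the RKHS unit
    ball of (1/m) sum_j h(a_j) - (1/n) sum_i w_i h(b_i).
    """
    if len(weights) != len(source):
        raise ValueError("need one weight per source point")
    m, n = len(target), len(source)
    # Signed coefficients of the combination of feature maps K(p, .).
    points = list(target) + list(source)
    coefs = [1.0 / m] * m + [-w / n for w in weights]
    # Squared RKHS norm: sum_{k,l} c_k c_l K(p_k, p_l).
    sq_norm = sum(
        ck * cl * rbf_kernel(pk, pl, sigma)
        for ck, pk in zip(coefs, points)
        for cl, pl in zip(coefs, points)
    )
    # Clamp tiny negative values caused by floating-point round-off.
    return math.sqrt(max(sq_norm, 0.0))


# Usage: moving all source weight onto the point that matches the target
# removes the discrepancy, while uniform weights leave some imbalance.
target = [[1.0]]
source = [[0.0], [1.0]]
print(weighted_mmd(target, source, [1.0, 1.0]))
print(weighted_mmd(target, source, [0.0, 2.0]))
```

Computing the norm through the kernel matrix avoids any explicit optimization over $h$, which is why an MMD penalty is cheap to evaluate on a mini-batch.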
C Architecture for joint learning of sample weights
For an illustration of the re-weighted CFR estimator, see Figure 7.
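The re-weighted estimator is trained on an objective with four parts, labelled A, B, C and D in the proof of Theorem 3: a weighted factual risk, a complexity penalty on $h$, an IPM between the target design and the re-weighted source design in representation space, and a penalty on the weights. The following is a minimal sketch that only evaluates this objective for fixed predictions and weights. It assumes squared loss, a precomputed regularizer value `reg_h` standing in for $\mathcal{R}(h)$, an IPM value computed elsewhere, and a weight penalty of the form $\lambda_w \|w\|_2 / n$; all argument names are hypothetical and this is not the training code used in the experiments.

```python
import math
from typing import Sequence


def rcfr_objective(
    predictions: Sequence[float],  # h(Phi(x_i), t_i) for the n factual samples
    outcomes: Sequence[float],     # observed factual outcomes y_i
    weights: Sequence[float],      # sample weights w_i
    *,
    reg_h: float,       # value of the complexity regularizer R(h), computed elsewhere
    ipm_value: float,   # IPM between target and re-weighted source designs
    lambda_h: float,
    alpha: float,
    lambda_w: float,
) -> float:
    """Evaluate A + B + C + D for fixed h, Phi and w (squared loss assumed)."""
    n = len(outcomes)
    if not (len(predictions) == len(weights) == n):
        raise ValueError("predictions, outcomes and weights must have equal length")
    # A: weighted empirical factual risk.
    term_a = sum(w * (p - y) ** 2 for p, y, w in zip(predictions, outcomes, weights)) / n
    # B: complexity penalty on h, scaled by 1/sqrt(n).
    term_b = lambda_h * reg_h / math.sqrt(n)
    # C: imbalance between the designs in representation space.
    term_c = alpha * ipm_value
    # D: Euclidean norm of the weight vector, scaled by 1/n.
    term_d = lambda_w * math.sqrt(sum(w * w for w in weights)) / n
    return term_a + term_b + term_c + term_d


# Usage with toy numbers: two factual samples with unit weights.
value = rcfr_objective([1.0, 2.0], [1.0, 0.0], [1.0, 1.0],
                       reg_h=2.0, ipm_value=0.5,
                       lambda_h=1.0, alpha=2.0, lambda_w=1.0)
print(value)
```

Because $\|w\|_2$ grows when a few samples carry most of the weight, a larger `lambda_w` discourages extreme weights, while `alpha` trades factual fit against balance between the designs.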
D Minimization of approximate Wasserstein distances
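The entropic approximation described in this appendix can be made concrete with a small example. The following is a minimal pure-Python sketch of the standard Sinkhorn-Knopp scaling of Cuturi (2013) between a batch of treated and a batch of control representations. It assumes uniform marginals on each group, the kernel $K = e^{-\lambda M}$ built from the Euclidean distance matrix $M$ of Algorithm 1, and a fixed iteration count. The names `lam` and `n_iter` are hypothetical; the sketch computes forward values only (it does not differentiate through the iterations as the paper does), and it uses the two-step row/column scaling rather than the single-state recurrence for $u_t$ given below.

```python
import math
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def pairwise_distances(treated: Sequence[Sequence[float]],
                       control: Sequence[Sequence[float]]) -> Matrix:
    """M[k][l] = Euclidean distance between treated unit k and control unit l."""
    return [[math.dist(p, q) for q in control] for p in treated]


def sinkhorn(cost: Matrix, lam: float = 10.0, n_iter: int = 200) -> Tuple[Matrix, float]:
    """Entropy-regularized transport plan T and its cost <T, M> via Sinkhorn-Knopp."""
    m, m_prime = len(cost), len(cost[0])
    a = [1.0 / m] * m              # uniform mass on treated units (assumption)
    b = [1.0 / m_prime] * m_prime  # uniform mass on control units (assumption)
    kernel = [[math.exp(-lam * c) for c in row] for row in cost]  # K = exp(-lam * M)
    u = [1.0] * m
    v = [1.0] * m_prime
    for _ in range(n_iter):
        # Rescale rows so the plan's row sums match a ...
        u = [a[i] / sum(kernel[i][j] * v[j] for j in range(m_prime)) for i in range(m)]
        # ... then rescale columns so its column sums match b.
        v = [b[j] / sum(kernel[i][j] * u[i] for i in range(m)) for j in range(m_prime)]
    plan = [[u[i] * kernel[i][j] * v[j] for j in range(m_prime)] for i in range(m)]
    transport_cost = sum(plan[i][j] * cost[i][j] for i in range(m) for j in range(m_prime))
    return plan, transport_cost


# Usage: two identical one-dimensional batches are transported almost for free.
plan, cost = sinkhorn(pairwise_distances([[0.0], [1.0]], [[0.0], [1.0]]))
print(cost)
```

For large `lam` the entries of `kernel` can underflow to zero and the divisions fail; a log-domain implementation would be needed in that regime.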
Computing (and minimizing) the Wasserstein distance traditionally involves solving a linear program, which may be prohibitively expensive for many practical applications. Cuturi (2013) showed that introducing entropic regularization in the optimization problem results in an approximation computable through the Sinkhorn-Knopp matrix scaling algorithm, at speeds orders of magnitude faster. The approximation, called Sinkhorn distances, is computed using a fixed-point iteration involving repeated multiplication with a kernel matrix $K$. We use the algorithm of Cuturi (2013) in our framework by differentiating through the iterations. See Algorithm 1 for an overview of how to compute the gradient $g$ in Algorithm 19. When computing $g$, disregarding the gradient $\nabla_W T^*$ amounts to minimizing an upper bound on the Sinkhorn transport. More advanced ideas for stochastic optimization of this distance have recently been proposed by Aude et al. (2016), and might be used in future work.

[Figure 7 diagram. Labels: covariates X; neural network layers; shared representation Φ; predicted potential outcomes; intervention T; outcome Y; sample weighting w; group distance; weighted risk.]
Figure 7: Illustration of the Re-weighted Counterfactual Regression (RCFR) estimator. Green boxes indicate inputs, white boxes outputs and loss terms, yellow boxes shared representations and blue/red boxes estimators of potential outcomes. Solid lines indicate transformations that are part of the prediction function and dashed lines indicate computations that are part of the learning procedure.

Algorithm 1 Computing the stochastic gradient of the Wasserstein distance
Input: Factual $(x_1, t_1, y_1), \dots, (x_n, t_n, y_n)$, representation network $\Phi_W$ with current weights $W$
Randomly sample a mini-batch with $m$ treated and $m'$ control units $(x_{i_1}, 1, y_{i_1}), \dots, (x_{i_m}, 1, y_{i_m}), (x_{i_{m+1}}, 0, y_{i_{m+1}}), \dots, (x_{i_{m+m'}}, 0, y_{i_{m+m'}})$
Calculate the $m \times m'$ pairwise distance matrix between all treatment and control pairs $M(\Phi_W)$: $M_{kl}(\Phi_W) = \|\Phi_W(x_{i_k}) - \Phi_W(x_{i_{m+l}})\|$
Calculate the approximate optimal transport matrix $T^*$ using Algorithm 3 of Cuturi and Doucet (2014), with input $M(\Phi_W)$
Calculate the gradient: $g = \nabla_W \langle T^*, M(\Phi_W) \rangle$

While our framework is agnostic to the parameterization of $\Phi$, our experiments focus on the case where $\Phi$ is a neural network. For convenience of implementation, we may represent the fixed-point iterations of the Sinkhorn algorithm as a recurrent neural network, where the states $u_t$ evolve according to
$$u_{t+1} = n_t \;./\; \big(n_c K (1 \;./\; (u_t^\top K)^\top)\big).$$
Here, $K$ is a kernel matrix corresponding to a metric such as the Euclidean distance, $K_{ij} = e^{-\lambda \|\Phi(x_i) - \Phi(x_j)\|}$, and $n_c$, $n_t$