{"title": "Bayesian Inference of Individualized Treatment Effects using Multi-task Gaussian Processes", "book": "Advances in Neural Information Processing Systems", "page_first": 3424, "page_last": 3432, "abstract": "Predicated on the increasing abundance of electronic health records, we investigate the problem of inferring individualized treatment effects using observational data. Stemming from the potential outcomes model, we propose a novel multi-task learning framework in which factual and counterfactual outcomes are modeled as the outputs of a function in a vector-valued reproducing kernel Hilbert space (vvRKHS). We develop a nonparametric Bayesian method for learning the treatment effects using a multi-task Gaussian process (GP) with a linear coregionalization kernel as a prior over the vvRKHS. The Bayesian approach allows us to compute individualized measures of confidence in our estimates via pointwise credible intervals, which are crucial for realizing the full potential of precision medicine. The impact of selection bias is alleviated via a risk-based empirical Bayes method for adapting the multi-task GP prior, which jointly minimizes the empirical error in factual outcomes and the uncertainty in (unobserved) counterfactual outcomes. We conduct experiments on observational datasets for an interventional social program applied to premature infants, and a left ventricular assist device applied to cardiac patients wait-listed for a heart transplant. In both experiments, we show that our method significantly outperforms the state-of-the-art.", "full_text": "Bayesian Inference of Individualized Treatment\n\nEffects using Multi-task Gaussian Processes\n\nAhmed M. 
Alaa

Electrical Engineering Department
University of California, Los Angeles

Mihaela van der Schaar

Department of Engineering Science
University of Oxford

Abstract

Predicated on the increasing abundance of electronic health records, we investigate the problem of inferring individualized treatment effects using observational data. Stemming from the potential outcomes model, we propose a novel multi-task learning framework in which factual and counterfactual outcomes are modeled as the outputs of a function in a vector-valued reproducing kernel Hilbert space (vvRKHS). We develop a nonparametric Bayesian method for learning the treatment effects using a multi-task Gaussian process (GP) with a linear coregionalization kernel as a prior over the vvRKHS. The Bayesian approach allows us to compute individualized measures of confidence in our estimates via pointwise credible intervals, which are crucial for realizing the full potential of precision medicine. The impact of selection bias is alleviated via a risk-based empirical Bayes method for adapting the multi-task GP prior, which jointly minimizes the empirical error in factual outcomes and the uncertainty in (unobserved) counterfactual outcomes. We conduct experiments on observational datasets for an interventional social program applied to premature infants, and a left ventricular assist device applied to cardiac patients wait-listed for a heart transplant. In both experiments, we show that our method significantly outperforms the state-of-the-art.

1 Introduction

Clinical trials entail enormous costs: the average costs of multi-phase trials in vital therapeutic areas such as the respiratory system, anesthesia and oncology are $115.3 million, $105.4 million, and $78.6 million, respectively [1]. 
Moreover, due to the difficulty of patient recruitment, randomized controlled trials often exhibit small sample sizes, which hinders the discovery of heterogeneous therapeutic effects across different patient subgroups [2]. Observational studies are cheaper and quicker alternatives to clinical trials [3, 4]. With the advent of electronic health records (EHRs), currently deployed in more than 75% of hospitals in the U.S. according to the latest ONC data brief1, there is a growing interest in using machine learning to infer heterogeneous treatment effects from readily available observational data in EHRs. This interest is reflected in recent initiatives such as STRATOS [3], which focuses on guiding observational medical research, in addition to various recent works on causal inference from observational data developed by the machine learning community [4-11].
Motivated by the plethora of EHR data and the potentiality of precision medicine, we address the problem of estimating individualized treatment effects (i.e. causal inference) using observational data. The problem differs from standard supervised learning in that for every subject in an observational cohort, we only observe the "factual" outcome for a specific treatment assignment, but never observe the corresponding "counterfactual" outcome2, without which we can never know the true

1https://www.healthit.gov/sites/default/files/briefs/
2Some works refer to this setting as the "logged bandits with feedback" [12, 13].

31st Conference on Neural Information Processing Systems (NIPS 2017), Long Beach, CA, USA.

treatment effect [4-9]. Selection bias creates a discrepancy in the feature distributions for the treated and control patient groups, which makes the problem even harder. 
Much of the classical work has focused on the simpler problem of estimating average treatment effects via unbiased estimators based on propensity score weighting (see [14] and the references therein). More recent works learn individualized treatment effects via regression models that view the subjects' treatment assignments as input features [4-13]. We provide a thorough review of these works in Section 3.
Contribution At the heart of this paper lies a novel conception of causal inference as a multi-task learning problem. That is, we view a subject's potential outcomes as the outputs of a vector-valued function in a reproducing kernel Hilbert space (vvRKHS) [15]. We propose a Bayesian approach for learning the treatment effects through a multi-task Gaussian process (GP) prior over the population's potential outcomes. The Bayesian perspective on the multi-task learning problem allows reasoning about the unobserved counterfactual outcomes, giving rise to a loss function that quantifies the Bayesian risk of the estimated treatment effects while taking into account the uncertainty in counterfactual outcomes without explicit propensity modeling. Furthermore, we show that optimizing the multi-task GP hyper-parameters via risk-based empirical Bayes [16] is equivalent to minimizing the empirical error in the factual outcomes, with a regularizer that is proportional to the posterior uncertainty (variance) in counterfactual outcomes. We provide a feature space interpretation of our method showing its relation to previous works on domain adaptation [6, 8], empirical risk minimization [13], and tree-based learning [4, 5, 7, 9].
The Bayesian approach allows us to compute individualized measures of confidence in our estimates via pointwise credible intervals. 
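As a concrete illustration of the last point, a pointwise credible interval under a Gaussian posterior has a simple closed form. The sketch below is a hypothetical helper, not the paper's implementation: given a posterior mean and variance for the treatment effect at one feature point, it returns the central interval with Bayesian coverage gamma.

```python
from statistics import NormalDist

def credible_interval(post_mean: float, post_var: float, gamma: float = 0.95):
    """Pointwise credible interval for the ITE under a Gaussian posterior.

    Assuming T(x) | D ~ N(post_mean, post_var), returns the central
    interval C_gamma(x) satisfying P(T(x) in C_gamma(x)) = gamma.
    """
    # Standard-normal quantile at (1 + gamma) / 2, e.g. ~1.96 for gamma = 0.95.
    z = NormalDist().inv_cdf(0.5 * (1.0 + gamma))
    half = z * post_var ** 0.5
    return post_mean - half, post_mean + half
```

The interval width grows with the posterior variance, so subjects whose counterfactuals are poorly constrained by the data automatically receive wider (more cautious) intervals.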
With the exception of [5] and [9], none of the previous works associate their estimates with confidence measures, which hinders their applicability in formal medical research. While Bayesian credible sets do not guarantee frequentist coverage, recent results on the "honesty" (i.e. frequentist coverage) of adaptive credible sets in nonparametric regression may extend to our setting [16]. In particular, [Theorem 1, 16] shows that, under some extrapolation conditions, adapting a GP prior via risk-based empirical Bayes guarantees honest credible sets: investigating the validity of these results in our setting is an interesting topic for future research.

2 Problem Setup

We consider the setting in which a specific treatment is applied to a population of subjects, where each subject i possesses a d-dimensional feature X_i in X, and two (random) potential outcomes Y_i^(1), Y_i^(0) in R that are drawn from a distribution (Y_i^(1), Y_i^(0)) | X_i = x ~ P(. | X_i = x). Y_i^(1) and Y_i^(0) correspond to the subject's response with and without the treatment, respectively. The realized causal effect of the treatment on subject i manifests through the random variable (Y_i^(1) - Y_i^(0)) | X_i = x. Hence, we define the individualized treatment effect (ITE) for subjects with a feature X_i = x as

T(x) = E[ Y_i^(1) - Y_i^(0) | X_i = x ].   (1)

Our goal is to conduct the causal inference task of estimating the function T(x) from an observational dataset D, which typically comprises n independent samples of the random tuple {X_i, W_i, Y_i^(W_i)}, where W_i in {0, 1} is a treatment assignment indicator that indicates whether or not subject i has received the treatment under consideration. The outcomes Y_i^(W_i) and Y_i^(1-W_i) are known as the factual and the counterfactual outcomes, respectively [6, 9]. 
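To make this setup concrete, the following synthetic sketch (all functional forms are illustrative assumptions, not taken from the paper) builds an observational dataset in which every subject has two potential outcomes, but only the factual one selected by the treatment assignment W_i is ever recorded:

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 1000, 5
X = rng.normal(size=(n, d))

# Hypothetical response surfaces; the true ITE is f1 - f0.
f0 = X[:, 0]
f1 = X[:, 0] + 2.0 * np.tanh(X[:, 1])

# Selection bias: the treatment assignment depends on the features,
# so W_i is not independent of X_i.
propensity = 1.0 / (1.0 + np.exp(-X[:, 1]))   # P(W = 1 | X)
W = rng.binomial(1, propensity)

# Only the factual outcome Y^(W_i) enters the observational dataset D.
Y_factual = np.where(W == 1, f1, f0) + 0.1 * rng.normal(size=n)

D = (X, W, Y_factual)    # the tuple {X_i, W_i, Y_i^(W_i)}
true_ite = f1 - f0       # never recoverable from D for any single subject
```

Note that `true_ite` exists only because the data here are synthetic; in a real observational study it is exactly the quantity that must be inferred.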
Treatment assignments are generally dependent on features, i.e. W_i is not independent of X_i. The conditional distribution P(W_i = 1 | X_i = x), also known as the propensity score of subject i [13, 14], reflects the underlying policy for assigning the treatment to subjects. Throughout this paper, we respect the standard assumptions of unconfoundedness (or ignorability) and overlap: this setting is known in the literature as the "potential outcomes model with unconfoundedness" [4-11].
Individual-based causal inference using observational data is challenging. Since we only observe one of the potential outcomes for every subject i, we never observe the treatment effect Y_i^(1) - Y_i^(0) for any of the subjects, and hence we cannot resort to standard supervised learning to estimate T(x). Moreover, the dataset D exhibits selection bias, which may render the estimates of T(x) inaccurate if the treatment assignment for individuals with X_i = x is strongly biased (i.e. P(W_i = 1 | X_i = x) is close to 0 or 1). Since our primary motivation for addressing this problem comes from its application potential in precision medicine, it is important to associate our estimate of T(.) with a pointwise measure of confidence in order to properly guide therapeutic decisions for individual patients.

3 Multi-task Learning for Causal Inference

Vector-valued Potential Outcomes Function We adopt the following signal-in-white-noise model for the potential outcomes:

Y_i^(w) = f_w(X_i) + epsilon_{i,w}, w in {0, 1},   (2)

where epsilon_{i,w} ~ N(0, sigma_w^2) is a Gaussian noise variable. It follows from (2) that E[ Y_i^(w) | X_i = x ] = f_w(x), and hence the ITE can be estimated as ^T(x) = ^f_1(x) - ^f_0(x). Most previous works that estimate T(x) via direct modeling learn a single-output regression model that treats the treatment assignment as an input feature, i.e. f_w(x) = f(x, w), f(., .) : X x {0, 1} -> R, and estimate the ITE as ^T(x) = ^f(x, 1) - ^f(x, 0) [5-9]. We take a different perspective by introducing a new multi-output regression model comprising a potential outcomes (PO) function f(.) : X -> R^2, with d inputs (features) and 2 outputs (potential outcomes); the ITE estimate is the projection of the estimated PO function on the vector e = [-1 1]^T, i.e. ^T(x) = ^f^T(x) e.
Consistent pointwise estimation of the ITE function T(x) requires restricting the PO function f(x) to a smooth function class [9]. To this end, we model the PO function f(x) as belonging to a vector-valued Reproducing Kernel Hilbert Space (vvRKHS) H_K equipped with an inner product <., .>_{H_K}, and with a reproducing kernel K : X x X -> R^{2x2}, where K is a (symmetric) positive semi-definite matrix-valued function [15]. Our choice for the vvRKHS is motivated by its algorithmic advantages; by virtue of the representer theorem, we know that learning the PO function entails estimating a finite number of coefficients evaluated at the input points {X_i}_{i=1}^n [17].

Multi-task Learning The vector-valued model for the PO function conceptualizes causal inference as a multi-task learning problem. That is, D = {X_i, W_i, Y_i^(W_i)}_{i=1}^n can be thought of as comprising training data for two learning tasks with target functions f_0(.) and f_1(.), with W_i acting as the "task index" for the ith training point [15]. For an estimated PO function ^f(x), the true loss functional is

L(^f) = Integral_{x in X} ( ^f^T(x) e - T(x) )^2 P(X = x) dx.   (3)

The loss functional in (3) is known as the precision in estimating heterogeneous effects (PEHE), and is commonly used to quantify the "goodness" of ^T(x) as an estimate of T(x) [4-6, 8]. 
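The finite-sample analogue of this loss can be sketched as follows; it is computable only in (semi-)synthetic evaluations where the true response surfaces are known (the function name and signature are illustrative, not from the paper):

```python
import numpy as np

def pehe(ite_hat, f1, f0):
    """Finite-sample square-root PEHE.

    ite_hat : estimated treatment effects ^T(X_i)
    f1, f0  : true (noiseless) response surfaces evaluated at the X_i,
              available only in (semi-)synthetic experiments.
    Returns the root mean squared error of ite_hat against f1 - f0.
    """
    ite_hat, f1, f0 = map(np.asarray, (ite_hat, f1, f0))
    return float(np.sqrt(np.mean((ite_hat - (f1 - f0)) ** 2)))
```

This is the quantity reported (as its square root) in the experiments of Section 5.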
A conspicuous challenge that arises when learning the "PEHE-optimal" PO function f is that we cannot compute the empirical PEHE for a particular f in H_K, since the treatment effect samples {Y_i^(1) - Y_i^(0)}_{i=1}^n are not available in D. On the other hand, using a loss function that evaluates the losses of f_0(x) and f_1(x) separately (as in conventional multi-task learning [Sec. 3.2, 15]) can be highly problematic: in the presence of a strong selection bias, the empirical loss for f(.) with respect to factual outcomes may not generalize to counterfactual outcomes, leading to a large PEHE loss. In order to gain insight into the structure of the optimal PO function, we consider an "oracle" that has access to counterfactual outcomes. For such an oracle, the finite-sample empirical PEHE is

^L(^f; K, Y^(W), Y^(1-W)) = (1/n) Σ_{i=1}^n ( ^f^T(X_i) e - (1 - 2W_i)(Y_i^(1-W_i) - Y_i^(W_i)) )^2,   (4)

where Y^(W) = [Y_i^(W_i)]_i and Y^(1-W) = [Y_i^(1-W_i)]_i. When Y^(1-W) is accessible, the PEHE-optimal PO function f(.) is given by the following representer Theorem.

Theorem 1 (Representer Theorem for Oracle Causal Inference). For any ^f* in H_K satisfying

^f* = arg min_{^f in H_K} ^L(^f; K, Y^(W), Y^(1-W)) + lambda ||^f||^2_{H_K}, lambda in R+,   (5)

we have that ^T*(.) = e^T ^f*(.) in span{ ~K(., X_1), ..., ~K(., X_n) }, where ~K(., .) = e^T K(., .) e. That is, ^T*(.) admits a representation

^T*(.) = Σ_{i=1}^n alpha_i ~K(., X_i), alpha = [alpha_1, ..., alpha_n]^T, where
alpha = ( ~K(X, X) + n lambda I )^{-1} ( (1 - 2W) ⊙ (Y^(1-W) - Y^(W)) ),   (6)

where ⊙ denotes the component-wise product, ~K(X, X) = ( ~K(X_i, X_j) )_{i,j}, and W = [W_1, ..., W_n]^T.

A Bayesian Perspective Theorem 1 follows directly from the generalized representer Theorem [17] (a proof is provided in [17]), and it implies that regularized empirical PEHE minimization in the vvRKHS is equivalent to Bayesian inference with a Gaussian process (GP) prior [Sec. 2.2, 15]. Therefore, we can interpret ^T*(.) as the posterior mean of T(.) given a GP prior with a covariance kernel ~K, i.e. T ~ GP(0, ~K). We know from Theorem 1 that ~K = e^T K e, hence the prior on T(.) is equivalent to a multi-task GP prior on the PO function f(.) with a kernel K, i.e. f ~ GP(0, K).
The Bayesian view of the problem is advantageous for two reasons. First, as discussed earlier, it allows computing individualized (pointwise) measures of uncertainty in ^T(.) via posterior credible intervals. Second, it allows reasoning about the unobserved counterfactual outcomes in a Bayesian fashion, and hence provides a natural proxy for the oracle learner's empirical PEHE in (4). Let theta in Theta be a kernel hyper-parameter that parametrizes the multi-task GP kernel K_theta. We define the Bayesian PEHE risk R(theta, ^f; D) for a point estimate ^f as follows

R(theta, ^f; D) = E_theta[ ^L(^f; K_theta, Y^(W), Y^(1-W)) | D ].   (7)

The expectation in (7) is taken with respect to Y^(1-W) | D. 
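Under oracle access to counterfactuals, the solution in Theorem 1 is an ordinary kernel ridge regression on the signed outcome differences. A minimal numpy sketch, assuming a scalar RBF kernel as a stand-in for e^T K(., .) e (the function and its defaults are illustrative):

```python
import numpy as np

def oracle_ite(X, W, Y_f, Y_cf, lam=0.1, length_scale=1.0):
    """Oracle ITE estimate from Theorem 1 (requires counterfactuals Y_cf).

    Solves alpha = (K + n*lam*I)^{-1} ((1 - 2W) * (Y_cf - Y_f)) and
    returns the fitted treatment effects K @ alpha at the training points.
    An RBF kernel plays the role of the scalar kernel ~K = e^T K e.
    """
    n = X.shape[0]
    sq = ((X[:, None, :] - X[None, :, :]) ** 2).sum(-1)
    K = np.exp(-0.5 * sq / length_scale ** 2)
    # (1 - 2W_i)(Y_i^(1-W_i) - Y_i^(W_i)) equals Y_i^(1) - Y_i^(0) for
    # both values of W_i, so the regression target is the per-subject effect.
    target = (1.0 - 2.0 * W) * (Y_cf - Y_f)
    alpha = np.linalg.solve(K + n * lam * np.eye(n), target)
    return K @ alpha
```

The oracle is of course unrealizable; its role in the paper is to motivate replacing the unobserved targets by their posterior distribution, which is what the Bayesian PEHE risk in (7) does.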
The Bayesian PEHE risk R(theta, ^f; D) is simply the oracle learner's empirical loss in (4) marginalized over the posterior distribution of the unobserved counterfactuals Y^(1-W), and hence it incorporates the posterior uncertainty in counterfactual outcomes without explicit propensity modeling. The optimal hyper-parameter theta* and interpolant ^f*(.) that minimize the Bayesian PEHE risk are given in the following Theorem.

Theorem 2 (Risk-based Empirical Bayes). The minimizer (^f*, theta*) of R(theta, ^f; D) is given by

^f* = E_{theta*}[ f | D ],  theta* = arg min_{theta in Theta} [ || Y^(W) - E_theta[ f | D ] ||_2^2 + || Var_theta[ Y^(1-W) | D ] ||_1 ],

where the first term is the empirical factual error, the second term is the posterior counterfactual variance, Var_theta[. | .] is the posterior variance, and ||.||_p is the p-norm.

The proof is provided in Appendix A. Theorem 2 shows that hyper-parameter selection via risk-based empirical Bayes is instrumental in alleviating the impact of selection bias. This is because, as the Theorem states, theta* minimizes the empirical loss of ^f* with respect to factual outcomes, and uses the posterior variance of the counterfactual outcomes as a regularizer. Hence, theta* carves a kernel that not only fits factual outcomes, but also generalizes well to counterfactuals. It comes as no surprise that ^f* = E_{theta*}[ f | D ]: E_{theta*}[ f | D, Y^(1-W) ] is equivalent to the oracle's solution in Theorem 1, hence by the law of iterated expectations, E_{theta*}[ f | D ] = E_{theta*}[ E_{theta*}[ f | D, Y^(1-W) ] | D ] is the oracle's solution marginalized over the posterior distribution of counterfactuals.

Figure 1: Pictorial depiction for model selection via risk-based empirical Bayes.

Related Works A feature space interpretation of Theorem 2 helps create a conceptual equivalence between our method and previous works. For simplicity of exposition, consider a finite-dimensional vvRKHS in which the PO function resides: we can describe such a space in terms of a feature map Phi : X -> R^p, where K(x, x') = <Phi(x), Phi(x')> [Sec. 2.3, 15]. Every PO function f in H_K can be represented as f = <alpha, Phi(x)>, and hence the two response surfaces f_0(.) and f_1(.) are represented as hyperplanes in the transformed feature space as depicted in Fig. 1 (right). The risk-based empirical Bayes method attempts to find a feature map Phi and two hyperplanes that best fit the factual outcomes (right panel in Fig. 1) while minimizing the posterior variance in counterfactual outcomes (middle panel in Fig. 1). This conception is related to that of counterfactual regression [6, 8], which builds on ideas from co-variate shift and domain adaptation [19] in order to jointly learn a response function f and a "balanced" representation Phi that makes the distributions P(Phi(X_i = x) | W_i = 1) and P(Phi(X_i = x) | W_i = 0) similar. Our work differs from [6, 8] in the following aspects. 
First, our Bayesian multi-task formulation provides a direct estimate of the PEHE: (7) is an unbiased estimator of the finite-sample version of (3). Contrarily, [Eq. 2, 6] creates a coarse proxy for the PEHE by using the nearest-neighbor factual outcomes in replacement of counterfactuals, whereas [Eq. 3, 8] optimizes a generalization bound which may largely overestimate the true PEHE for particular hypothesis classes. [6] optimizes the algorithm's hyper-parameters by assuming (unrealistically) that counterfactuals are available in a held-out sample, whereas [8] uses an ad hoc nearest-neighbor approximation. Moreover, unlike the case in [6], our multi-task formulation protects the interactions between W_i and X_i from being lost in high-dimensional feature spaces.
Most of the previous works estimate the ITE via co-variate adjustment (G-computation formula) [4, 5, 7, 11, 20]; the most remarkable of these methods are the nonparametric Bayesian additive regression trees [5] and causal forests [4, 9]. We provide numerical comparisons with both methods in Section 5. [11] also uses Gaussian processes, but with the focus of modeling treatment response curves over time. Counterfactual risk minimization is another framework that is applicable only when the propensity score P(W_i = 1 | X_i = x) is known [12, 13]. [25] uses deep networks to infer counterfactuals, but requires some of the data to be drawn from a randomized trial.

4 Causal Multi-task Gaussian Processes (CMGPs)

In this Section, we provide a recipe for Bayesian causal inference with the prior f ~ GP(0, K_theta). We call this model a Causal Multi-task Gaussian Process (CMGP).
Constructing the CMGP Kernel As is often the case in medical settings, the two response surfaces f_0(.) and f_1(.) may display different levels of heterogeneity (smoothness), and may have different relevant features. Standard intrinsic coregionalization models for constructing vector-valued kernels impose the same covariance parameters for all outputs [18], which limits the interaction between the treatment assignments and the patients' features. To that end, we construct a linear model of coregionalization (LMC) [15], which mixes two intrinsic coregionalization models as follows

K_theta(x, x') = A_0 k_0(x, x') + A_1 k_1(x, x'),   (8)

where k_w(x, x'), w in {0, 1}, is the radial basis function (RBF) with automatic relevance determination, i.e. k_w(x, x') = exp( -(1/2) (x - x')^T R_w^{-1} (x - x') ), with R_w = diag(l_{1,w}^2, l_{2,w}^2, ..., l_{d,w}^2), and with l_{d,w} being the length scale parameter of the dth feature in k_w(., .), whereas A_0 and A_1 are given by

A_0 = [ beta_{00}^2  rho_0 ; rho_0  beta_{01}^2 ],  A_1 = [ beta_{10}^2  rho_1 ; rho_1  beta_{11}^2 ].   (9)

The parameters (beta_{ij}^2)_{ij} and (rho_i)_i determine the variances and correlations of the two response surfaces f_0(x) and f_1(x). The LMC kernel introduces degrees of freedom that allow the two response surfaces to have different covariance functions and relevant features. When beta_{00} >> beta_{01} and beta_{11} >> beta_{10}, the length scale parameter l_{d,w} can be interpreted as the relevance of the dth feature to the response surface f_w(.). 
The set of all hyper-parameters is theta = (sigma_0, sigma_1, R_0, R_1, A_0, A_1).

Adapting the Prior via Risk-based Empirical Bayes In order to avoid overfitting to the factual outcomes Y^(W), we evaluate the empirical error in factual outcomes via leave-one-out cross-validation (LOO-CV) with Bayesian regularization [24]; the regularized objective function is thus given by

^R(theta; D) = eta_0 Q(theta) + eta_1 ||theta||_2^2, where Q(theta) = Σ_{i=1}^n ( Y_i^(W_i) - E_theta[ f(X_i) | D_{-i} ] )^2 + || Var_theta[ Y^(1-W) | D ] ||_1,   (10)

and D_{-i} is the dataset D with subject i removed, whereas eta_0 and eta_1 are the Bayesian regularization parameters. For the second level of inference, we use the improper Jeffrey's prior as an ignorance prior for the regularization parameters, i.e. P(eta_0) proportional to 1/eta_0 and P(eta_1) proportional to 1/eta_1. This allows us to integrate out the regularization parameters [Sec. 2.1, 24], leading to a revised objective function ^R(theta; D) = n log(Q(theta)) + (10 + 2d) log(||theta||_2^2) [Eq. (15), 24]. It is important to note that LOO-CV with squared loss has often been considered to be unfavorable in ordinary GP regression as it leaves one degree of freedom undetermined [Sec. 5.4.2, 23]; this problem does not arise in our setting since the term || Var_theta[ Y^(1-W) | D ] ||_1 involves all the variance parameters, and hence the objective function ^R(theta; D) does not depend solely on the posterior mean.

Causal Inference via CMGPs Algorithm 1 sums up the entire causal inference procedure. It first invokes the routine Initialize_hyperparameters, which uses the sample variance and up-crossing rate of Y^(W) to initialize theta (see Appendix B). 
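The revised objective can be sketched as follows, assuming the leave-one-out residuals and the counterfactual posterior variances have already been computed from the CMGP posterior (that computation is the expensive part and is omitted here; the function is a hypothetical helper):

```python
import numpy as np

def revised_objective(theta, loo_residuals, cf_posterior_var, d):
    """Risk objective after integrating out the regularization weights:
    R(theta; D) = n * log Q(theta) + (10 + 2d) * log ||theta||_2^2,
    with Q(theta) = sum_i (Y_i^(W_i) - E[f(X_i) | D_{-i}])^2
                  + || Var[Y^(1-W) | D] ||_1.

    loo_residuals    : factual leave-one-out residuals under theta
    cf_posterior_var : posterior variances of the counterfactual outcomes
    (both assumed precomputed from the CMGP posterior under theta).
    """
    n = len(loo_residuals)
    Q = np.sum(np.asarray(loo_residuals) ** 2) + np.sum(np.abs(cf_posterior_var))
    return n * np.log(Q) + (10 + 2 * d) * np.log(np.sum(np.asarray(theta) ** 2))
```

The counterfactual variance term is what distinguishes this objective from plain LOO-CV: a hyper-parameter setting that fits factual outcomes but leaves counterfactuals highly uncertain is penalized.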
Such an automated initialization procedure allows running our method without any user-defined inputs, which facilitates its usage by researchers conducting observational studies. Having initialized theta (line 3), the algorithm finds a locally optimal theta* using gradient descent (lines 5-12), and then estimates the ITE function and the associated credible intervals (lines 13-17). (Notation: X = [{X_i}_{W_i=0}, {X_i}_{W_i=1}]^T, Y = [{Y_i^(W_i)}_{W_i=0}, {Y_i^(W_i)}_{W_i=1}]^T, Sigma = diag(sigma_0^2 I_{n-n_1}, sigma_1^2 I_{n_1}), n_1 = Σ_i W_i, erf(x) = (1/sqrt(pi)) Integral_{-x}^{x} e^{-y^2} dy, and K_theta(x) = (K_theta(x, X_i))_i.)

Algorithm 1 Causal Inference via CMGPs
1: Input: Observational dataset D, Bayesian coverage gamma
2: Output: ITE function ^T(x), credible intervals C_gamma(x)
3: theta <- Initialize_hyperparameters(D)
4: phi_0 <- exp(theta); t <- 0; m_t <- 0; v_t <- 0
5: repeat
6:   m_{t+1} <- beta_1 m_t + (1 - beta_1) phi_t ⊙ grad_phi ^R(log(phi_t); D)
7:   v_{t+1} <- beta_2 v_t + (1 - beta_2) (phi_t ⊙ grad_phi ^R(log(phi_t); D))^2
8:   ^m_{t+1} <- m_{t+1}/(1 - beta_1^{t+1}); ^v_{t+1} <- v_{t+1}/(1 - beta_2^{t+1})
9:   phi_{t+1} <- phi_t ⊙ exp( -eta ^m_{t+1}/(sqrt(^v_{t+1}) + eps) )
10:  t <- t + 1
11: until convergence
12: theta* <- log(phi_{t-1})
13: Lambda_{theta*} <- (K_{theta*}(X, X) + Sigma)^{-1}
14: ^T(x) <- (K^T_{theta*}(x) Lambda_{theta*} Y)^T e
15: V(x) <- K_{theta*}(x, x) - K_{theta*}(x) Lambda_{theta*} K^T_{theta*}(x)
16: ^I(x) <- erf^{-1}(gamma) (2 e^T V(x) e)^{1/2}
17: C_gamma(x) <- [ ^T(x) - ^I(x), ^T(x) + ^I(x) ]

We use a re-parametrized version of the Adaptive Moment Estimation (ADAM) gradient descent algorithm for optimizing theta [21]; we first apply a transformation phi = exp(theta) to ensure that all covariance parameters remain positive, and then run ADAM to minimize ^R(log(phi_t); D). The ITE function is estimated as the posterior mean of the CMGP (line 14). The credible interval C_gamma(x) with a Bayesian coverage of gamma for a subject with feature x is defined as P_theta(T(x) in C_gamma(x)) = gamma, and is computed straightforwardly using the error function of the normal distribution (lines 15-17). The computational burden of Algorithm 1 is dominated by the O(n^3) matrix inversion in line 13; for large observational studies, this can be ameliorated using conventional sparse approximations [Sec. 8.4, 23].

5 Experiments

Since the ground truth counterfactual outcomes are never available in real-world observational datasets, evaluating causal inference algorithms is not straightforward. We follow the semi-synthetic experimental setup in [5, 6, 8], where covariates and treatment assignments are real but outcomes are simulated. Experiments are conducted using the IHDP dataset introduced in [5]. We also introduce a new experimental setup using the UNOS dataset: an observational dataset involving end-stage cardiovascular patients wait-listed for heart transplantation. Finally, we illustrate the clinical utility and significance of our algorithm by applying it to the real outcomes in the UNOS dataset.

The IHDP dataset The Infant Health and Development Program (IHDP) is intended to enhance the cognitive and health status of low birth weight, premature infants through pediatric follow-ups and parent support groups [5]. The semi-simulated dataset in [5, 6, 8] is based on covariates from a real randomized experiment that evaluated the impact of the IHDP on the subjects' IQ scores at the age of three: selection bias is introduced by removing a subset of the treated population. All outcomes (response surfaces) are simulated. The response surface data generation process was not designed to favor our method: we used the standard non-linear "Response Surface B" setting in [5] (also used in [6] and [8]). The dataset comprises 747 subjects (608 control and 139 treated), and there are 25 covariates associated with each subject.

The UNOS dataset3 The United Network for Organ Sharing (UNOS) dataset contains information on every heart transplantation event in the U.S. since 1987. The dataset also contains information on patients registered in the heart transplantation wait-list over the years, including those who died before undergoing a transplant. Left Ventricular Assistance Devices (LVADs) were introduced in 2001 as a life-saving therapy for patients awaiting a heart donor [26]; the survival benefits of LVADs are very heterogeneous across the patients' population, and it is unclear to practitioners how outcomes vary across patient subgroups. It is important to learn the heterogeneous survival benefits of LVADs
It is important to learn the heterogeneous survival bene\ufb01ts of LVADs\nin order to appropriately re-design the current transplant priority allocation scheme [26].\nWe extracted a cohort of patients enrolled in the wait-list in 2010; we chose this year since by that\ntime the current continuous-\ufb02ow LVAD technology became dominant in practice, and patients have\nbeen followed up suf\ufb01ciently long to assess their survival. (Details of data processing is provided in\nAppendix C.) After excluding pediatric patients, the cohort comprised 1,006 patients (774 control\nand 232 treated), and there were 14 covariates associated with each patient. The outcomes (survival\ntimes) generation model is described as follows: (cid:27)0 = (cid:27)1 = 1, f0(x) = exp((x + 1\n2 ) \u2126); and\nf1(x) = \u2126 x (cid:0) !; where \u2126 is a random vector of regression coef\ufb01cients sampled uniformly from\n[0; 0:1; 0:2; 0:3; 0:4], and ! is selected for a given \u2126 so as to adjust the average survival bene\ufb01t to 5\nyears. In order to increase the selection bias, we estimate the propensity score P(Wi = 1jXi = x)\nusing logistic-regression, and then, sequentially, with probability 0.5 we remove the control patient\nwhose propensity score is closest to 1, and with probability 0.5 we remove a random control patient.\nA total of 200 patients are removed, leading to a cohort with 806 patients. 
The resulting dataset is more biased than IHDP, and hence poses a greater inferential challenge.

Table 1: Results on the IHDP and UNOS datasets (lower sqrt(PEHE) is better).

Method | IHDP in-sample | IHDP out-of-sample | UNOS in-sample | UNOS out-of-sample
CMGP   | 0.59 +/- 0.01  | 0.76 +/- 0.01      | 1.7 +/- 0.10   | 1.8 +/- 0.13
GP     | 2.1 +/- 0.11   | 2.3 +/- 0.14       | 4.1 +/- 0.15   | 4.5 +/- 0.20
BART   | 2.0 +/- 0.13   | 2.2 +/- 0.17       | 3.5 +/- 0.17   | 3.9 +/- 0.23
CF     | 2.4 +/- 0.21   | 2.8 +/- 0.23       | 3.8 +/- 0.25   | 4.3 +/- 0.31
VTRF   | 1.4 +/- 0.07   | 2.2 +/- 0.16       | 4.5 +/- 0.35   | 4.9 +/- 0.41
CFRF   | 2.7 +/- 0.24   | 2.9 +/- 0.25       | 4.7 +/- 0.21   | 5.2 +/- 0.32
BLR    | 5.9 +/- 0.31   | 6.1 +/- 0.41       | 5.7 +/- 0.21   | 6.2 +/- 0.30
BNN    | 2.1 +/- 0.11   | 2.2 +/- 0.13       | 3.2 +/- 0.10   | 3.3 +/- 0.12
CFRW   | 1.0 +/- 0.07   | 1.2 +/- 0.08       | 2.7 +/- 0.07   | 2.9 +/- 0.11
kNN    | 3.2 +/- 0.12   | 4.2 +/- 0.22       | 5.2 +/- 0.11   | 5.4 +/- 0.12
PSM    | 4.9 +/- 0.31   | 4.9 +/- 0.31       | 4.6 +/- 0.12   | 4.8 +/- 0.16
TML    | 5.2 +/- 0.35   | 5.2 +/- 0.35       | 6.2 +/- 0.31   | 6.2 +/- 0.31

All entries are sqrt(PEHE) values.

Benchmarks We compare our algorithm with: tree-based methods (BART [5], causal forests (CF) [4, 9], virtual-twin random forests (VTRF) [7], and counterfactual random forests (CFRF) [7]); balancing counterfactual regression (balancing linear regression (BLR) [6], balancing neural networks (BNN) [6], and counterfactual regression with the Wasserstein distance metric (CFRW) [8]); propensity-based and matching methods (k-nearest-neighbor (kNN) and propensity score matching (PSM)); doubly-robust methods (targeted maximum likelihood (TML) [22]); and Gaussian process-based methods (separate GP regression for the treated and control populations with marginal likelihood maximization (GP)). 
Details of all these benchmarks are provided in Appendix D. Following [4-9], we evaluate the performance of all algorithms by reporting the square root of

PEHE = (1/n) Σ_{i=1}^{n} ((f1(Xi) − f0(Xi)) − E[Y_i^(1) − Y_i^(0) | Xi = x])²,

where f1(Xi) − f0(Xi) is the estimated treatment effect. (The UNOS data are available at https://www.unos.org/data/.) We evaluate the PEHE via a Monte Carlo simulation with 1000 realizations of both the IHDP and UNOS datasets, where in each experiment we run all the benchmarks with 60/20/20 train-validation-test splits. Counterfactuals are never made available to any of the benchmarks. We run Algorithm 1 with a learning rate of 0.01 and with the standard setting prescribed in [21] (i.e., β1 = 0.9, β2 = 0.999, ϵ = 10⁻⁸). We report both the in-sample and out-of-sample PEHE estimates: the former corresponds to the accuracy of the estimated ITE in a retrospective cohort study, whereas the latter corresponds to the performance of a clinical decision support system that provides out-of-sample patients with ITE estimates [8]. The in-sample PEHE metric is non-trivial since we never observe counterfactuals, even in the training phase.

Results As can be seen in Table 1, CMGPs outperform all other benchmarks in terms of the PEHE on both the IHDP and UNOS datasets. The benefit of the risk-based empirical Bayes method manifests in the comparison with ordinary GP regression, which fits the treated and control populations by evidence maximization. The performance gain of CMGPs with respect to GPs increases on the UNOS dataset as it exhibits a larger selection bias; hence naïve GP regression tends to fit a function to the factual outcomes that does not generalize well to counterfactuals. Our algorithm also performs better than all other nonparametric tree-based algorithms.
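The evaluation metric above amounts to a root-mean-squared error between the estimated and true conditional treatment effects. A minimal sketch (function and variable names are illustrative, not from the paper's code; the true effects come from the known simulated response surfaces):

```python
import numpy as np

def sqrt_pehe(tau_hat, mu1, mu0):
    """Square root of the Precision in Estimation of Heterogeneous Effects:
    RMSE between estimated effects tau_hat and the true conditional effects
    mu1 - mu0, averaged over the n patients in the cohort."""
    tau_true = mu1 - mu0  # E[Y(1) - Y(0) | X = x] under the known model
    return np.sqrt(np.mean((tau_hat - tau_true) ** 2))

# Toy usage with hypothetical arrays; in the experiments, mu1/mu0 are the
# simulated response surfaces and tau_hat is each benchmark's estimate.
mu0 = np.array([1.0, 2.0, 3.0])
mu1 = np.array([2.0, 4.0, 3.5])
tau_hat = np.array([1.1, 1.8, 0.4])
print(sqrt_pehe(tau_hat, mu1, mu0))  # ≈ 0.1414
```

Since the counterfactual half of each (mu1, mu0) pair is never observed, this metric can only be computed in (semi-)synthetic settings where the true response surfaces are known, which is why both experiments use simulated outcomes on real covariates.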
In comparison to BART, our algorithm places an adaptive prior on a smooth function space, and hence it is capable of achieving faster posterior contraction rates than BART, which places a prior on a space of discontinuous functions [16]. Similar insights apply to the frequentist random forest algorithms. CMGPs also outperform the different variants of counterfactual regression in both datasets, though CFRW is competitive in the IHDP experiment. BLR performs poorly in both datasets as it balances the distributions of the treated and control populations by variable selection, and hence it discards informative features for the sake of correcting the selection bias. The performance gain of CMGPs with respect to BNN and CFRW shows that the multi-task learning framework is advantageous: through the linear coregionalization kernel, CMGPs preserve the interactions between Wi and Xi, and hence are capable of capturing highly non-linear (heterogeneous) response surfaces.

Figure 2: Pathway for a representative patient in the UNOS dataset.

6 Discussion: Towards Precision Medicine

To provide insights into the clinical utility of CMGPs, we ran our algorithm on all patients in the UNOS dataset who were wait-listed in the period 2005-2010, and used the real patient survival times as outcomes. The current transplant priority allocation scheme relies on a coarse categorization of patients that does not take into account their individual risks; for instance, all patients who have an LVAD are thought of as benefiting from it equally. We found substantial evidence in the data that this leads to wrong clinical decisions. In particular, we found that 10.3% of wait-list patients for whom an LVAD was implanted exhibit a delayed assignment to a high-priority allocation in the wait-list. The pathway of one such patient is depicted in Fig.
2: she was assigned a high priority (status 1A) in June 2013, but died shortly after, before her turn to receive a heart transplant. Her late assignment to the high-priority status was caused by an overestimated benefit of the LVAD she had implanted in 2010; that is, the wait-list allocation scheme assumed she would attain the "populational average" survival benefit from the LVAD. Our algorithm gave a much more conservative estimate of her survival: since she was diabetic, her individual benefit from the LVAD was less than the populational average. We envision a new priority allocation scheme in which our algorithm is used to allocate priorities based on individual risks in a personalized manner.

References

[1] C. Adams and V. Brantner. Spending on New Drug Development. Health Economics, 19(2): 130-141, 2010.
[2] J. C. Foster, M. G. T. Jeremy, and S. J. Ruberg. Subgroup Identification from Randomized Clinical Trial Data. Statistics in Medicine, 30(24): 2867-2880, 2011.
[3] W. Sauerbrei, M. Abrahamowicz, D. G. Altman, S. Cessie, and J. Carpenter. Strengthening Analytical Thinking for Observational Studies: the STRATOS Initiative. Statistics in Medicine, 33(30): 5413-5432, 2014.
[4] S. Athey and G. Imbens. Recursive Partitioning for Heterogeneous Causal Effects. Proceedings of the National Academy of Sciences, 113(27): 7353-7360, 2016.
[5] J. L. Hill. Bayesian Nonparametric Modeling for Causal Inference. Journal of Computational and Graphical Statistics, 2012.
[6] F. D. Johansson, U. Shalit, and D. Sontag. Learning Representations for Counterfactual Inference. In ICML, 2016.
[7] M. Lu, S. Sadiq, D. J. Feaster, and H. Ishwaran. Estimating Individual Treatment Effect in Observational Data using Random Forest Methods. arXiv:1701.05306, 2017.
[8] U. Shalit, F. Johansson, and D. Sontag. Estimating Individual Treatment Effect: Generalization Bounds and Algorithms. arXiv:1606.03976, 2016.
[9] S. Wager and S. Athey.
Estimation and Inference of Heterogeneous Treatment Effects using Random Forests. arXiv:1510.04342, 2015.
[10] Y. Xie, J. E. Brand, and B. Jann. Estimating Heterogeneous Treatment Effects with Observational Data. Sociological Methodology, 42(1): 314-347, 2012.
[11] Y. Xu, Y. Xu, and S. Saria. A Bayesian Nonparametric Approach for Estimating Individualized Treatment-Response Curves. arXiv:1608.05182, 2016.
[12] M. Dudík, J. Langford, and L. Li. Doubly Robust Policy Evaluation and Learning. In ICML, 2011.
[13] A. Swaminathan and T. Joachims. Batch Learning from Logged Bandit Feedback Through Counterfactual Risk Minimization. Journal of Machine Learning Research, 16(1): 1731-1755, 2015.
[14] A. Abadie and G. Imbens. Matching on the Estimated Propensity Score. Econometrica, 84(2): 781-807, 2016.
[15] M. A. Alvarez, L. Rosasco, and N. D. Lawrence. Kernels for Vector-valued Functions: A Review. Foundations and Trends® in Machine Learning, 4(3): 195-266, 2012.
[16] S. Sniekers and A. van der Vaart. Adaptive Bayesian Credible Sets in Regression with a Gaussian Process Prior. Electronic Journal of Statistics, 9(2): 2475-2527, 2015.
[17] B. Schölkopf, R. Herbrich, and A. J. Smola. A Generalized Representer Theorem. International Conference on Computational Learning Theory, 2001.
[18] E. V. Bonilla, K. M. Chai, and C. Williams. Multi-task Gaussian Process Prediction. In NIPS, 2007.
[19] S. Bickel, M. Brückner, and T. Scheffer. Discriminative Learning under Covariate Shift. Journal of Machine Learning Research, 10(9): 2137-2155, 2009.
[20] V. Chernozhukov, D. Chetverikov, M. Demirer, E. Duflo, and C. Hansen. Double Machine Learning for Treatment and Causal Parameters. arXiv:1608.00060, 2016.
[21] D. Kingma and J. Ba. ADAM: A Method for Stochastic Optimization. arXiv:1412.6980, 2014.
[22] K. E. Porter, S. Gruber, M. J. van der Laan, and J. S. Sekhon. The Relative Performance of Targeted Maximum Likelihood Estimators.
The International Journal of Biostatistics, 7(1): 1-34, 2011.
[23] C. E. Rasmussen. Gaussian Processes for Machine Learning. MIT Press, 2006.
[24] G. C. Cawley and N. L. C. Talbot. Preventing Over-fitting During Model Selection via Bayesian Regularisation of the Hyper-parameters. Journal of Machine Learning Research, 841-861, 2007.
[25] J. Hartford, G. Lewis, K. Leyton-Brown, and M. Taddy. Counterfactual Prediction with Deep Instrumental Variables Networks. arXiv:1612.09596, 2016.
[26] M. S. Slaughter, et al. Advanced Heart Failure Treated with Continuous-flow Left Ventricular Assist Device. New England Journal of Medicine, 361(23): 2241-2251, 2009.