{"title": "Adapting Neural Networks for the Estimation of Treatment Effects", "book": "Advances in Neural Information Processing Systems", "page_first": 2507, "page_last": 2517, "abstract": "This paper addresses the use of neural networks for the estimation of treatment effects from observational data. Generally, estimation proceeds in two stages. First, we fit models for the expected outcome and the probability of treatment (propensity score). Second, we plug these fitted models into a downstream estimator. Neural networks are a natural choice for the models in the first step. Our question is: how can we adapt the design and training of the neural networks used in this first step in order to improve the quality of the final estimate of the treatment effect? We propose two adaptations based on insights from the statistical literature on the estimation of treatment effects. The first is a new architecture, the Dragonnet, that exploits the sufficiency of the propensity score for estimation adjustment. The second is a regularization procedure, targeted regularization, that induces a bias towards models that have non-parametrically optimal asymptotic properties ‘out-of-the-box’. Studies on benchmark datasets for causal inference show these adaptations outperform existing methods.", "full_text": "Adapting Neural Networks for the Estimation of Treatment Effects

Claudia Shi1, David M. Blei1,2, and Victor Veitch2

1Department of Computer Science, Columbia University
2Department of Statistics, Columbia University

Abstract

This paper addresses the use of neural networks for the estimation of treatment effects from observational data. Generally, estimation proceeds in two stages. First, we fit models for the expected outcome and the probability of treatment (propensity score) for each unit. Second, we plug these fitted models into a downstream estimator of the effect.
Neural networks are a natural choice for the models in the first step. The question we address is: how can we adapt the design and training of the neural networks used in the first step in order to improve the quality of the final estimate of the treatment effect? We propose two adaptations based on insights from the statistical literature on the estimation of treatment effects. The first is a new architecture, the Dragonnet, that exploits the sufficiency of the propensity score for estimation adjustment. The second is a regularization procedure, targeted regularization, that induces a bias towards models that have non-parametrically optimal asymptotic properties ‘out-of-the-box’. Studies on benchmark datasets for causal inference show these adaptations outperform existing methods. Code is available at github.com/claudiashi57/dragonnet.

1 Introduction

We consider the estimation of causal effects from observational data. Observational data is often readily available in situations where randomized controlled trials (RCTs) are expensive or impossible. However, causal inference from observational data must address (possible) confounding factors that affect both treatment and outcome. Failure to adjust for confounders can lead to incorrect conclusions. To address this, a practitioner collects covariate information in addition to treatment and outcome status. The causal effect can be identified if the covariates contain all confounding variables. We will work in this ‘no hidden confounding’ setting throughout the paper. The task we consider is the estimation of the effect of a treatment T (e.g., a patient receives a drug) on an outcome Y (whether they recover) adjusting for covariates X (e.g., illness severity or socioeconomic status).

We consider how to use neural networks to estimate the treatment effect. The estimation of treatment effects proceeds in two stages.
First, we fit models for the conditional outcome Q(t, x) = E[Y | t, x] and the propensity score g(x) = P(T = 1 | x). Then, we plug these fitted models into a downstream estimator. The strong predictive performance of neural networks motivates their use for effect estimation [e.g., SJS16; JSS16; Lou+17; AS17; AWS17; SLK18; YJS18; FLM18]. We will use neural networks as models for the conditional outcome and propensity score.

In principle, using neural networks for the conditional outcome and propensity score models is straightforward. We can use a standard net to predict the outcome Y from the treatment and covariates, and another to predict the treatment from the covariates. With a suitable choice of training objective, the trained models will yield consistent estimates of the conditional outcomes and propensity scores. However, neural network research has focused on predictive performance. What is important for causal inference is the quality of the downstream estimation. This leads to our main question: how can we modify the design and training of neural networks in order to improve the quality of treatment effect estimation?

33rd Conference on Neural Information Processing Systems (NeurIPS 2019), Vancouver, Canada.

We address this question by adapting results from the statistical literature on the estimation of treatment effects. The contributions of this paper are:

1. A neural network architecture—the Dragonnet—based on the sufficiency of the propensity score for causal estimation.
2. A regularization procedure—targeted regularization—based on non-parametric estimation theory.
3. An empirical study of these methods on established benchmark datasets. We find the methods substantially improve estimation quality in comparison to existing neural-network based approaches.
This holds even when the methods degrade predictive performance.

Setup. For concreteness, we consider the estimation of the average effect of a binary treatment, though the methods apply broadly. The data (Y_i, T_i, X_i) are drawn i.i.d. from P. The average treatment effect (ATE) ψ is

ψ = E[Y | do(T = 1)] − E[Y | do(T = 0)].

The use of Pearl's do notation indicates that the effect of interest is causal. It corresponds to what happens if we intervene by assigning a new patient the drug. If the observed covariates X include all common causes of the treatment and outcome—i.e., block all backdoor paths—then the causal effect is equal to a parameter of the observational distribution P,

ψ = E[E[Y | X, T = 1] − E[Y | X, T = 0]].   (1.1)

We want to estimate ψ using a finite sample from P. Following equation 1.1, an estimator is

ψ̂_Q = (1/n) Σ_i [Q̂(1, x_i) − Q̂(0, x_i)],   (1.2)

where Q̂ is an estimate of the conditional outcome Q(t, x) = E[Y | t, x]. There are also more sophisticated estimators that additionally rely on estimates ĝ of the propensity score g(x) = P(T = 1 | x); see section 3.

We now state our question of interest plainly. We want to use neural networks to model Q and g. How should we adapt the design and training of these networks so that ψ̂ is a good estimate of ψ?

2 Dragonnet

Our starting point is a classic result, [RR83, Thm. 3].

Theorem 2.1 (Sufficiency of Propensity Score).
If the average treatment effect ψ is identifiable from observational data by adjusting for X, i.e., ψ = E[E[Y | X, T = 1] − E[Y | X, T = 0]], then adjusting for the propensity score also suffices:

ψ = E[E[Y | g(X), T = 1] − E[Y | g(X), T = 0]].

In words: it suffices to adjust for only the information in X that is relevant for predicting the treatment. Consider the parts of X that are relevant for predicting the outcome but not the treatment. Those parts are irrelevant for the estimation of the causal effect, and are effectively noise for the adjustment. As such, we expect conditioning on these parts to hurt finite-sample performance. Instead, we should discard this information.1 For example, when computing the expected outcome estimator ψ̂_Q (equation 1.2), we should train Q̂ to predict Y from only the part of X relevant for T, even though this may degrade the predictive performance of Q̂.

Here is one way to use neural networks to find the relevant parts of X. First, train a deep net to predict T. Then remove the final (predictive) layer. Finally, use the activation of the remaining net as features for predicting the outcome. In other contexts (e.g., images) this is a standard procedure [e.g., Gir+14]. The hope is that the first net will distill the covariates into the features relevant for treatment prediction, i.e., relevant to the propensity score ĝ. Then, conditioning on the features is equivalent to conditioning on the propensity score itself. However, this process is cumbersome. With finite data, estimation errors in the propensity score model ĝ may propagate to the conditional outcome model.

1A caveat: this intuition applies to the complexity of learning the outcome model. In the case of a high-variance outcome and a small sample size, modeling all covariates may help as a variance reduction technique.
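As a concrete illustration of the two-stage recipe, here is a minimal sketch of the simple plug-in estimator ψ̂_Q (equation 1.2). A least-squares linear model stands in for the fitted outcome network Q̂; the synthetic data-generating process and all variable names are our own illustrative assumptions, not part of the paper.

```python
import numpy as np

def psi_hat_q(q1, q0):
    """Plug-in ATE estimate (eq. 1.2): average of Q-hat(1, x_i) - Q-hat(0, x_i)."""
    return np.mean(q1 - q0)

# Toy data: a constant treatment effect of 2.0 plus linear covariate effects.
rng = np.random.default_rng(0)
n = 2000
x = rng.normal(size=(n, 3))
t = rng.binomial(1, 0.5, size=n)
y = 2.0 * t + x @ np.array([1.0, -1.0, 0.5]) + rng.normal(size=n)

# Stage one: fit Q(t, x) by least squares on features [1, t, x]
# (a linear stand-in for the neural outcome model).
feats = np.column_stack([np.ones(n), t, x])
beta, *_ = np.linalg.lstsq(feats, y, rcond=None)

# Stage two: plug the fitted model into the downstream estimator.
q1 = np.column_stack([np.ones(n), np.ones(n), x]) @ beta   # Q-hat(1, x_i)
q0 = np.column_stack([np.ones(n), np.zeros(n), x]) @ beta  # Q-hat(0, x_i)
print(psi_hat_q(q1, q0))  # close to the true effect of 2.0
```

With a correctly specified outcome model, the estimate concentrates around the true effect; the paper's question is how to make the first-stage fit serve this second-stage estimate well.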
Ideally, the model itself should choose a tradeoff between predictive accuracy and the propensity-score representation.

This method inspires Dragonnet,2 a three-headed architecture that provides an end-to-end procedure for predicting propensity score and conditional outcome from covariates and treatment information. See Figure 1. We use a deep net to produce a representation layer Z(X) ∈ R^p, and then predict both the treatment and outcome from this shared representation. We use 2-hidden-layer neural networks for each of the outcome models Q̂(0, ·) : R^p → R and Q̂(1, ·) : R^p → R. In contrast, we use a simple linear map (followed by a sigmoid) for the propensity score model ĝ. The simple map forces the representation layer to tightly couple to the estimated propensity scores.

Figure 1: Dragonnet architecture.

Dragonnet has parameters θ and output heads Q^nn(t_i, x_i; θ) and g^nn(x_i; θ). We train the model by minimizing an objective function,

θ̂ = argmin_θ R̂(θ; X), where   (2.1)

R̂(θ; X) = (1/n) Σ_i [(Q^nn(t_i, x_i; θ) − y_i)² + α CrossEntropy(g^nn(x_i; θ), t_i)],   (2.2)

where α ∈ R⁺ is a hyperparameter weighting the loss components. The fitted model is Q̂ = Q^nn(·, ·; θ̂) and ĝ = g^nn(·; θ̂). With the fitted outcome model Q̂ in hand, we can estimate the treatment effect with the estimator ψ̂_Q (equation 1.2).

In principle, the end-to-end training and high capacity of Dragonnet might allow it to avoid throwing away any information. In section 5, we study the Dragonnet's behaviour empirically and find evidence that it does indeed trade off prediction quality to achieve a good representation of the propensity score.
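The combined objective of equation 2.2 is simple to write down. The sketch below computes it with NumPy, taking the head outputs Q^nn and g^nn as given arrays; in the paper these come from the Dragonnet heads and the objective is minimized over θ by gradient descent, so this is only an illustration of the loss itself.

```python
import numpy as np

def dragonnet_objective(q_pred, g_pred, y, t, alpha=1.0):
    """Empirical risk of eq. 2.2: outcome squared error plus an
    alpha-weighted cross-entropy term for the propensity head."""
    sq_err = (q_pred - y) ** 2
    cross_ent = -(t * np.log(g_pred) + (1 - t) * np.log(1 - g_pred))
    return np.mean(sq_err + alpha * cross_ent)

# A propensity head that outputs g = 0.5 everywhere contributes log(2)
# of cross-entropy per unit, regardless of the observed treatment.
y = np.array([1.0, 1.0])
t = np.array([1.0, 0.0])
q_pred = np.array([1.0, 0.0])
g_pred = np.array([0.5, 0.5])
print(dragonnet_objective(q_pred, g_pred, y, t))
```

The hyperparameter α sets how strongly the shared representation is pulled toward the treatment-prediction task.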
Further, this trade-off improves ATE estimation even when we use a downstream estimator, such as ψ̂_Q, that does not use the estimated propensity scores.

If the propensity-score head is removed from Dragonnet, the resulting architecture is (essentially) the TARNET architecture from Shalit et al. [SJS16]. We compare with TARNET in section 5. We also compare to the multiple-stage method described above.

3 Targeted Regularization

We now turn to targeted regularization, a modification to the objective function used for neural network training. This modified objective is based on non-parametric estimation theory. It yields a fitted model that, with a suitable downstream estimator, guarantees desirable asymptotic properties.

We review some necessary results from semi-parametric estimation theory, and then explain targeted regularization. The summary of this section is:

1. ψ̂ has good asymptotic properties if it satisfies a certain equation (equation 3.1) with Q̂ and ĝ.
2. Targeted regularization (equation 3.2) is a modification to the training objective.
3. Minimizing this objective forces (Q̂_treg, ĝ, ψ̂_treg) to satisfy the required equation, where Q̂_treg and ψ̂_treg are particular choices for Q̂ and ψ̂.

Setup. Recall that the general recipe for estimating a treatment effect has two steps: (i) fit models for the conditional outcome Q and the propensity score g; (ii) plug the fitted models Q̂ and ĝ into a downstream estimator ψ̂. The estimator ψ̂_Q in equation 1.2 is the simplest example. There are a wealth of alternatives that, in theory, offer better performance.

2“Dragonnet” because the dragon has three heads.

Such estimators are studied in the semi-parametric estimation literature; see Kennedy [Ken16] for a readable introduction.
We restrict ourselves to the (simpler) fully non-parametric case; i.e., we make no assumptions on the form of the true data generating distribution. For our purposes, the key results from non-parametric theory are of the form: if the tuple (Q̂, ĝ, ψ̂) satisfies a certain equation (equation 3.1 below), then, asymptotically, the estimator ψ̂ will have various good properties. For instance,

1. robustness in the double machine-learning sense [Che+17a; Che+17b]—ψ̂ converges to ψ at a fast rate (in the sample complexity sense) even if Q̂ and ĝ converge slowly; and
2. efficiency—asymptotically, ψ̂ has the lowest variance of any consistent estimator of ψ. That is, the estimator ψ̂ is asymptotically the most data efficient estimator possible.

These asymptotic guarantees hold if (i) Q̂ and ĝ are consistent estimators for the conditional outcome and propensity scores, and (ii) the tuple satisfies the non-parametric estimating equation,

0 = (1/n) Σ_i φ(y_i, t_i, x_i; Q̂, ĝ, ψ̂),   (3.1)

where φ is the efficient influence curve of ψ,

φ(y, t, x; Q, g, ψ) = Q(1, x) − Q(0, x) + ( t/g(x) − (1 − t)/(1 − g(x)) ) {y − Q(t, x)} − ψ.

See, e.g., Chernozhukov et al. [Che+17b] and van der Laan and Rose [vR11] for details.

A natural way to construct a tuple satisfying the non-parametric estimating equation is to estimate Q̂ and ĝ in a manner agnostic to the downstream estimation task, and then choose ψ̂ so that equation 3.1 is satisfied. This yields the A-IPTW estimator [RRL00; Rob00]. Unfortunately, the presence of ĝ in the denominator of some terms can cause the A-IPTW to be unstable in finite samples, despite its asymptotic optimality.
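For concreteness, here is a sketch of the A-IPTW construction just described: given fitted values of Q̂ and ĝ, the estimator is the value of ψ that zeroes the empirical mean of the influence curve φ. The function name and array-based interface are our own; NumPy arrays stand in for the fitted models.

```python
import numpy as np

def aiptw(q0, q1, g, y, t):
    """Solve the non-parametric estimating equation (eq. 3.1) for psi:
    the empirical mean of phi is zero exactly at the returned value."""
    q_obs = np.where(t == 1, q1, q0)                       # Q-hat(t_i, x_i)
    correction = (t / g - (1 - t) / (1 - g)) * (y - q_obs)  # inverse-propensity term
    return np.mean(q1 - q0 + correction)

# When the outcome model is perfect (y == Q(t, x)), the correction term
# vanishes and A-IPTW reduces to the simple plug-in estimator psi-hat_Q.
t = np.array([1, 0, 1, 0])
g = np.array([0.5, 0.5, 0.5, 0.5])
q1 = np.array([2.0, 3.0, 4.0, 5.0])
q0 = np.array([1.0, 1.0, 1.0, 1.0])
y = np.where(t == 1, q1, q0)
print(aiptw(q0, q1, g, y, t))
```

The `t / g` and `(1 - t) / (1 - g)` factors in `correction` are the ĝ-in-the-denominator terms responsible for the finite-sample instability described above: when ĝ is close to 0 or 1, single units can dominate the average.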
(In our experiments, the A-IPTW estimator consistently under-performs the naive estimator ψ̂_Q.)

Targeted minimum loss estimation (TMLE) [vR11] is an alternative strategy that mitigates the finite-sample instability. The TMLE relies on (task-agnostic) fitted models Q̂ and ĝ. The idea is to perturb the estimate Q̂—with perturbation depending on ĝ—such that the simple estimator ψ̂_Q satisfies the non-parametric estimating equation (equation 3.1). Because the simple estimator is free of ĝ in denominators, it is stable with finite data. Thus, the TMLE yields an estimate that has both good asymptotic properties and good finite-sample performance. The ideas that underpin TMLE are the main inspiration for targeted regularization.

Targeted regularization. We now describe targeted regularization. We require Q and g to be modeled by a neural network (such as Dragonnet) with output heads Q^nn(t_i, x_i; θ) and g^nn(x_i; θ). By default, the neural network is trained by minimizing a differentiable objective function R̂(θ; X), e.g., equation 2.2.

Targeted regularization is a modification to the objective function. We introduce an extra model parameter ε and a regularization term γ(y, t, x; θ, ε) defined by

Q̃(t_i, x_i; θ, ε) = Q^nn(t_i, x_i; θ) + ε [ t_i / g^nn(x_i; θ) − (1 − t_i) / (1 − g^nn(x_i; θ)) ],
γ(y_i, t_i, x_i; θ, ε) = (y_i − Q̃(t_i, x_i; θ, ε))².   (3.2)

We then train the model by minimizing the modified objective,

θ̂, ε̂ = argmin_{θ,ε} [ R̂(θ; X) + β (1/n) Σ_i γ(y_i, t_i, x_i; θ, ε) ].   (3.3)

The variable β ∈ R⁺ is a hyperparameter.
Next, we define an estimator ψ̂_treg as:

ψ̂_treg = (1/n) Σ_i [Q̂_treg(1, x_i) − Q̂_treg(0, x_i)], where Q̂_treg = Q̃(·, ·; θ̂, ε̂).   (3.4)

The key observation is

0 = ∂_ε ( R̂(θ; X) + β (1/n) Σ_i γ(y_i, t_i, x_i; θ, ε) ) |_{θ̂, ε̂} = β (1/n) Σ_i φ(y_i, t_i, x_i; Q̂_treg, ĝ, ψ̂_treg).   (3.5)

That is, minimizing the targeted regularization term forces Q̂_treg, ĝ, ψ̂_treg to satisfy the non-parametric estimating equation 3.1.

Accordingly, the estimator ψ̂_treg will have the good non-parametric asymptotic properties so long as Q̂_treg and ĝ are consistent. Consistency is plausible—even with the addition of the targeted regularization term—because the model can choose to set ε to 0, which (essentially) recovers the original training objective. For instance, if Q̂ and ĝ are consistent in the original model then the targeted regularization estimates will also be consistent. In detail, the targeted regularization model preserves finite VC dimension (we add only 1 parameter), so the limiting model is an argmin of the true (population) risk. The true risk for the targeted regularization loss has a minimum at Q̂ = E[Y | x, t], ĝ = P(T = 1 | x), and ε̂ = 0. This is because the original risk is minimized at these values (by consistency), and the targeted regularization term (a squared error) is minimized at Q̂ + ε̂H(ĝ) = E[Y | x, t], which is achieved at ε̂ = 0.

The key idea, equation 3.5, is inspired by TMLE. Like targeted regularization, TMLE introduces an extra model parameter ε. It then chooses ε̂ so that an ε̂-perturbation of Q̂ satisfies equation 3.1 with ψ̂_Q.
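The mechanics of equation 3.5 can be checked numerically. In the sketch below we hold the fitted heads fixed and minimize over ε alone, which then has a closed-form least-squares solution; this is a simplification of the paper's joint minimization over (θ, ε), and the function names and synthetic data are our own. At the minimizer, the influence curve averages to exactly zero, as equation 3.5 promises.

```python
import numpy as np

def targeted_regularization_step(q0, q1, g, y, t):
    """Fit the single extra parameter eps by least squares (the eps-only
    minimization of the targeted regularization term), then form the
    perturbed model Q-tilde and the estimator psi-treg of eq. 3.4."""
    h = t / g - (1 - t) / (1 - g)                 # perturbation direction H(g)
    q_obs = np.where(t == 1, q1, q0)
    eps = np.sum(h * (y - q_obs)) / np.sum(h**2)  # argmin_eps of sum (y - q_obs - eps*h)^2
    q1_tilde = q1 + eps / g                       # Q-tilde(1, x): t = 1 gives h = 1/g
    q0_tilde = q0 - eps / (1 - g)                 # Q-tilde(0, x): t = 0 gives h = -1/(1-g)
    psi_treg = np.mean(q1_tilde - q0_tilde)
    return eps, psi_treg, q1_tilde, q0_tilde

# Check: at the minimizer, the influence curve phi averages to zero.
rng = np.random.default_rng(1)
n = 500
g = rng.uniform(0.2, 0.8, n)
t = rng.binomial(1, g)
q0, q1 = rng.normal(size=n), rng.normal(size=n)
y = np.where(t == 1, q1, q0) + rng.normal(size=n)

eps, psi, q1t, q0t = targeted_regularization_step(q0, q1, g, y, t)
h = t / g - (1 - t) / (1 - g)
q_obs_tilde = np.where(t == 1, q1t, q0t)
phi = q1t - q0t + h * (y - q_obs_tilde) - psi
print(abs(np.mean(phi)))  # effectively zero: the estimating equation holds
```

The algebra mirrors equation 3.5: zeroing the ε-derivative of the squared-error term forces Σ_i H_i (y_i − Q̃_i) = 0, and ψ̂_treg absorbs the remaining Q̃(1, x) − Q̃(0, x) average.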
However, TMLE uses only the parameter ε to ensure that the non-parametric estimating equation is satisfied, while targeted regularization adapts the entire model. Both TMLE and targeted regularization are designed to yield an estimate with stable finite-sample behavior and strong asymptotic guarantees. We compare these methods in section 5.

We note that estimators satisfying the non-parametric estimating equation are also ‘doubly robust’: the effect estimate is consistent if either Q̂ or ĝ is consistent. This property also holds for the targeted regularization estimator.

4 Related Work

The methods connect to different areas in causal inference and estimation theory.

Representations for causal inference. Dragonnet is related to papers using representation learning ideas for treatment effect estimation. The Dragonnet architecture resembles TARNET, a two-headed outcome-only model used as the baseline in Shalit et al. [SJS16]. One approach in the literature emphasizes learning a covariate representation that has a balanced distribution across treatment and outcome; e.g., BNNs [JSS16] and CFRNET [SJS16]. Other work combines deep generative models with standard causal identification results. CEVAE [Lou+17], GANITE [YJS18], and CMGP [AS17] use VAEs, GANs, and multi-task Gaussian processes, respectively, to estimate treatment effects. Another approach combines (pre-trained) propensity scores with neural networks; e.g., Propensity Dropout [AWS17] and Perfect Matching [SLK18]. Dragonnet complements these approaches. Exploiting the sufficiency of the propensity score is a distinct approach, and it may be possible to combine it with other strategies.

Non-parametric estimation and machine learning. Targeted regularization relates to a body of work combining machine learning methods with semi-parametric estimation theory.
As mentioned above, the main inspiration for the method is targeted minimum loss estimation [vR11]. Chernozhukov et al. [Che+17a; Che+17b] develop theory for ‘double machine learning’, showing that if certain estimating equations are satisfied then treatment estimates will converge at a parametric (O(1/√n)) rate even if the conditional outcome and propensity models converge much more slowly. Farrell et al. [FLM18] prove that neural networks converge at a fast enough rate to invoke the double machine learning results. This gives theoretical justification for the use of neural networks to model propensity scores and conditional expected outcomes. Targeted regularization is complementary: we rely on the asymptotic results for motivation, and address the finite-sample behavior.

5 Experiments

Do Dragonnet and targeted regularization improve treatment effect estimation in practice? Dragonnet is a high-capacity model trained end-to-end: does it actually discard information irrelevant to the propensity score? TMLE already offers an approach for balancing asymptotic guarantees with finite sample performance: does targeted regularization improve over this?

We study the methods empirically using two semi-synthetic benchmarking tools. We find that Dragonnet and targeted regularization substantially improve estimation quality. Moreover, we find that Dragonnet exploits propensity score sufficiency, and that targeted regularization improves on TMLE.

Table 1: Dragonnet with targeted regularization is state-of-the-art among neural network methods on the IHDP benchmark data. Entries are mean absolute error (and standard error) across simulations. Estimators are computed with the training and validation data (Δin), heldout data (Δout), and all data (Δall). Note that using all the data for both training and estimation improves estimation relative to data splitting.
Values from previous work are as reported in the cited papers.

Method              Δin          Δout         Δall
BNN [JSS16]         0.37 ± .03   0.42 ± .03   —
TARNET [SJS16]      0.26 ± .01   0.28 ± .01   —
CFR Wass [SJS16]    0.25 ± .01   0.27 ± .01   —
CEVAE [Lou+17]      0.34 ± .01   0.46 ± .02   —
GANITE [YJS18]      0.43 ± .05   0.49 ± .05   —
baseline (TARNET)   0.16 ± .01   0.21 ± .01   0.13 ± .00
baseline + t-reg    0.15 ± .01   0.20 ± .01   0.12 ± .00
Dragonnet           0.14 ± .01   0.21 ± .01   0.12 ± .00
Dragonnet + t-reg   0.14 ± .01   0.20 ± .01   0.11 ± .00

5.1 Setup

Ground truth causal effects are rarely available for real-world data. Accordingly, empirical evaluation of causal estimation procedures relies on semi-synthetic data. For the conclusions to be useful, the semi-synthetic data must have good fidelity to the real world. We use two pre-established causal benchmarking tools.

IHDP. Hill [Hil11] introduced a semi-synthetic dataset constructed from the Infant Health and Development Program (IHDP). This dataset is based on a randomized experiment investigating the effect of home visits by specialists on future cognitive scores. Following [SJS16], we use 1000 realizations from the NPCI package [Dor16].3 The data has 747 observations.

ACIC 2018. We also use the IBM causal inference benchmarking framework, which was developed for the 2018 Atlantic Causal Inference Conference competition (ACIC 2018) [Shi+18]. This is a collection of semi-synthetic datasets derived from the linked birth and infant death data (LBIDD) [MA98]. Importantly, the simulation is comprehensive—including 63 distinct data generating process settings—and the data are relatively large. Each competition dataset is a sample from a distinct distribution, which is itself drawn randomly according to the data generating process setting.
For each data generating process setting, we randomly pick 3 datasets of size either 5k or 10k. Some of the datasets have overlap violations. That is, P(T = 1 | x) can be very close to 0 or 1 for many values of x. Although overlap violations are an important area of study, this is not our focus, and the methods of this paper are not expected to be appropriate in this setting. As a simple heuristic, we exclude all datasets where the heldout treatment accuracy for Dragonnet is higher than 90%; high classification accuracy indicates a strong separation between the treated and control populations. Subject to this criterion, 101 datasets remain.

3There is a typo in Shalit et al. [SJS16]. They use setting A of the NPCI package, which corresponds to setting B in Hill [Hil11].

Model and Baseline Settings. Our main baseline is an implementation of the 2-headed TARNET architecture from Shalit et al. [SJS16]. This model predicts only the outcome, and is equivalent to the Dragonnet architecture with the propensity head removed.

Table 2: Dragonnet and targeted regularization improve estimation on average on ACIC 2018. Table entries are mean absolute error over all datasets.

Method              Δall
baseline (TARNET)   1.45
baseline + t-reg    1.40
Dragonnet           0.55
Dragonnet + t-reg   0.35

Table 3: Dragonnet and targeted regularization improve over the baseline about half the time, but improvement is substantial when it does happen. Error values are mean absolute error on ACIC 2018.

                   %improve   ↑avg   ↓avg
baseline: ψ̂_Q      0%         0      0
+ t-reg            42%        0.30   0.11
+ dragon           63%        1.42   0.01
+ dragon & t-reg   46%        2.37   0.01

For Dragonnet and targeted regularization, we set the hyperparameters α in equation 2.2 and β in equation 3.2 to 1. For the targeted regularization baseline, we use TARNET as the outcome model and logistic regression as the propensity score model.
We train TARNET and logistic regression jointly using the targeted regularization objective.

For all models, the hidden layer size is 200 for the shared representation layers and 100 for the conditional outcome layers. We train using stochastic gradient descent with momentum. Empirically, the choice of optimizer has a significant impact on estimation performance for the baseline and for Dragonnet and targeted regularization. Among the optimizers we tried, stochastic gradient descent with momentum resulted in the best performance for the baseline.

For IHDP experiments, we follow established practice [e.g., SJS16]. We randomly split the data into train/validation/test sets with proportion 63/27/10 and report the in-sample and out-of-sample estimation errors. However, this procedure is not clearly motivated for parameter estimation, so we also report the estimation errors from using all the data for both training and estimation.

For the ACIC 2018 experiments, we re-run each estimation procedure 25 times, use all the data for training and estimation, and report the average estimate errors.

Estimators and metrics. For the ACIC experiments, we report the mean absolute error of the average treatment effect estimate, Δ = |ψ̂ − ψ|. For IHDP, following established procedure, we report the mean absolute difference between the estimate and the sample ATE, Δ = |ψ̂ − (1/n) Σ_i [Q(1, x_i) − Q(0, x_i)]|. By default, we use ψ̂_Q as our estimator, except for models with targeted regularization, where we report ψ̂_treg (equation 3.4). For estimation, we exclude any data point with estimated propensity score outside [0.01, 0.99].

5.2 Effect on Treatment Estimation

The IHDP simulation is the de-facto standard benchmark for neural network treatment effect estimation methods.
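The propensity-score trimming used in the estimation step above (dropping units with ĝ outside [0.01, 0.99]) can be sketched in a few lines; the function name is our own, and a NumPy array stands in for the fitted propensity model.

```python
import numpy as np

def trimmed_psi_hat_q(q0, q1, g, lo=0.01, hi=0.99):
    """Plug-in ATE estimate (eq. 1.2) computed only on units whose
    estimated propensity score lies inside [lo, hi]."""
    keep = (g >= lo) & (g <= hi)
    return np.mean(q1[keep] - q0[keep])

# A unit with an extreme estimated propensity score is excluded from
# the average, so it cannot dominate the estimate.
g = np.array([0.50, 0.40, 0.999])
q1 = np.array([1.0, 1.0, 10.0])
q0 = np.array([0.0, 0.0, 0.0])
print(trimmed_psi_hat_q(q0, q1, g))  # averages over the first two units only
```

This mirrors the motivation for trimming: units with ĝ near 0 or 1 are exactly the ones that destabilize propensity-weighted corrections.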
In table 1 we report the estimation error of a number of approaches. Dragonnet with targeted regularization is state-of-the-art among these methods. However, the small sample size and limited simulation settings of IHDP make it difficult to draw conclusions about the methods. The main takeaways of table 1 are: i) our baseline method is a strong comparator, and ii) reusing the same data for fitting the model and computing the estimate works better than data splitting.

The remaining experiments use the Atlantic Causal Inference Conference 2018 competition (ACIC 2018) dataset. In table 2 we report the mean absolute error over the included datasets. The main observation is that Dragonnet improves estimation relative to the baseline (TARNET), and adding targeted regularization to Dragonnet improves estimation further. Additionally, we observe that despite its asymptotically optimal properties, TMLE hurts more than it helps on average. Doubly robust estimators such as the TMLE are known to be sensitive to violations of assumptions in other contexts [KS07]. We note that targeted regularization can improve performance even where TMLE does not.

In table 2, we report average estimation error across simulations. We see that Dragonnet and targeted regularization improve the baseline estimation. Is this because of small improvement on most datasets or major improvement on a subset of datasets? In table 3 we present an alternative comparison. We divide the datasets according to whether each method improves estimation relative to the baseline. We report the average improvement in positive cases, and the degradation in negative ones. We observe that Dragonnet and targeted regularization help about half the time. When the methods do help, the improvement is substantial. When the methods don't help, the degradation is mild.

5.3 Why does Dragonnet work?

Dragonnet was motivated as an end-to-end version of a multi-stage approach.
Does the end-to-end network work better? We now compare to the multi-stage procedure, which we call NEDnet.4 NEDnet has essentially the same architecture as Dragonnet. NEDnet is first trained using a pure treatment prediction objective. The final layer (treatment prediction head) is then removed, and replaced with an outcome-prediction neural network matching the one used by Dragonnet. The representation layers are then frozen, and the outcome-prediction network is trained on the pure outcome prediction task. NEDnet and Dragonnet are compared in table 4. The end-to-end Dragonnet produces more accurate estimates.

Table 4: Dragonnet produces more accurate estimates than NEDnet, a multi-stage alternative. Table entries are mean absolute error over all datasets.

IHDP        ψ̂_Q           ψ̂_TMLE
Dragonnet   0.12 ± 0.00   0.12 ± 0.00
NEDnet      0.15 ± 0.01   0.12 ± 0.00

ACIC        ψ̂_Q    ψ̂_TMLE
Dragonnet   0.55   1.97
NEDnet      1.49   2.80

We motivated the Dragonnet architecture by the sufficiency of the propensity score for causal adjustment. This architecture improves estimation performance. Is this because it is exploiting the sufficiency? Three observations suggest this is the case.

First, compared to TARNET, Dragonnet has worse performance as a predictor for the outcome, but better performance as an estimator. See Figure 2. This is the case even when we use the simple estimator ψ̂_Q, which does not use the output of the propensity-score head of Dragonnet. This suggests that, as intended, the shared representation adapts to the treatment prediction task, at the price of worse predictive performance for the outcome prediction task.

Second, Dragonnet is supposed to predict the outcome from only information relevant to T.
If this holds, we expect Dragonnet to improve significantly over the baseline when there is a large number of covariates that influence only Y (i.e., not T). These covariates are effectively "noise" for the causal estimation, since they are irrelevant for confounding. As illustrated in Figure 3, when most of the effect on Y is from confounding variables, the differences between Dragonnet and the baseline are not significant. As the number of covariates that only influence Y increases, Dragonnet becomes the better estimator.

Third, Dragonnet and TARNET should perform equally well with infinite data. With finite data, we expect Dragonnet to be more data efficient, as it discards covariates that are irrelevant for confounding. We verify this intuition by comparing model performance with varying amounts of data. We find Dragonnet's improvement is more drastic with smaller datasets. See Appendix A for details.

Figure 2: Dragonnet has worse prediction loss on the held-out data than the baseline, but better estimation quality. The estimation error and loss are from a separate run on the ACIC dataset where we held out 30% of the data to compute the loss.

Figure 3: Dragonnet improves over the baseline if many covariates are irrelevant for treatment. We stratified the ACIC datasets by the number of irrelevant covariates and compared the median MAE across strata.

5.4 When does targeted regularization work?

The guarantees from non-parametric theory are asymptotic, and apply in regimes where the estimated models closely approximate the true values. We divide the datasets according to the error of the simple (Q-only) baseline estimator. As shown in table 6, in cases where the initial estimator is good, TMLE and targeted regularization behave similarly. In cases where the initial estimator is poor, TMLE significantly degrades estimation quality, but targeted regularization does not. It appears that adapting the entire model to satisfy the non-parametric estimating equation avoids some bad finite-sample effects. We do not have a satisfactory theoretical explanation for this. Understanding this phenomenon is an important direction for future work.

4"NEDnet" because the network is beheaded after the first stage.

6 Discussion

There are a number of directions for future work. Foremost, although TMLE and targeted regularization are conceptually similar, the two methods perform differently in our experiments. Understanding the root causes of this behavior may shed light on the practical use of non-parametric estimation methods. Relatedly, another promising direction is to adapt the well-developed literature on TMLE [e.g., vR11; LG16] to more advanced targeted regularization methods. For instance, there are a number of TMLE approaches to estimating the average treatment effect on the treated (ATT). It is unclear which of these, if any, will yield a good targeted-regularization-type procedure. Generally, extending the methods here to other causal estimands and to mediation analysis is an important problem.

There are also interesting questions about Dragonnet-type architectures. We motivated Dragonnet with the intuition that we should use only the covariate information that is relevant to both the treatment assignment and the outcome. Our empirical results support this intuition. However, in other contexts, this intuition breaks down. For example, in RCTs, where covariates only affect the outcome, adjustment can increase power [SSI13]. This is because adjustment may reduce the effective variance of the outcome, and doubly robust methods can be used to ensure consistency (the treatment assignment model is known trivially).
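For concreteness, the simple plug-in estimator ψ̂Q and a doubly robust (AIPTW-style) one-step correction, the kind of adjustment that TMLE refines, can be sketched as follows. This is a schematic in NumPy, with function names of our own choosing, not the paper's released implementation.

```python
import numpy as np

def psi_plugin(q0, q1):
    """Plug-in ATE estimate: average the predicted outcome difference."""
    return np.mean(q1 - q0)

def psi_aiptw(q0, q1, g, t, y):
    """Doubly robust one-step correction to the plug-in estimate.

    q0, q1: fitted outcome predictions under control / treatment
    g: fitted propensity scores P(T = 1 | x)
    t: observed binary treatment; y: observed outcomes
    """
    q_obs = np.where(t == 1, q1, q0)  # prediction for the observed arm
    # Inverse-probability-weighted residual; it vanishes when the outcome
    # model fits the observed data perfectly.
    correction = (t / g - (1 - t) / (1 - g)) * (y - q_obs)
    return np.mean(q1 - q0 + correction)
```

The weighted residual term is why such estimators are sensitive to extreme propensity scores: a small g inflates the weights, which echoes the sensitivity of doubly robust estimators to assumption violations noted above [KS07].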
Our motivating intuition is well supported in the large-data, unknown propensity-score model case we consider. It would be valuable to have a clear articulation of the trade-offs involved and practical guidelines for choosing covariates for causal adjustment. As an illustration, recent papers have used Dragonnet-type models for causal adjustment with black-box embedding methods [VWB19; VSB19]. They achieve good estimation accuracy, but it remains unclear exactly what trade-offs are being made.

In a different direction, our experiments do not support the routine use of data splitting in effect estimation. Existing methods commonly split the data into train and test sets and use predictions on the test set to compute the downstream estimator. This technique has some theoretical justification [Che+17b] (in a K-fold variant), but significantly degrades performance in our experiments. We note that this is also true in our preliminary (unreported) experiments with K-fold data splitting. A clearer understanding of why and when data splitting is appropriate would have high impact. We note that Farrell et al. [FLM18] prove that data reuse does not invalidate estimation when using neural networks.

Acknowledgements

We are thankful to Yixin Wang, Dhanya Sridhar, Jackson Loper, Roy Adams, and Shira Mitchell for helpful comments and discussions. This work was supported by ONR N00014-15-1-2209, ONR 133691-5102004, NIH 5100481-5500001084, NSF CCF-1740833, FA 8750-14-2-0009, the Alfred P. Sloan Foundation, the John Simon Guggenheim Foundation, Facebook, Amazon, IBM, and the government of Canada through NSERC. The GPUs used for this research were donated by the NVIDIA Corporation.

References

[AS17] A. Alaa and M. van der Schaar. "Bayesian inference of individualized treatment effects using multi-task Gaussian processes". In: arXiv e-prints arXiv:1704.02801 (2017).

[AWS17] A. M. Alaa, M. Weisz, and M.
van der Schaar. "Deep counterfactual networks with propensity-dropout". In: arXiv e-prints arXiv:1706.05966 (2017).

[Bot+12] L. Bottou, J. Peters, J. Quiñonero-Candela, D. X. Charles, D. M. Chickering, E. Portugaly, D. Ray, P. Simard, and E. Snelson. "Counterfactual reasoning and learning systems". In: arXiv e-prints arXiv:1209.2355 (2012).

[Che+17a] V. Chernozhukov, D. Chetverikov, M. Demirer, E. Duflo, C. Hansen, W. Newey, and J. Robins. "Double/debiased machine learning for treatment and structural parameters". In: The Econometrics Journal (2017).

[Che+17b] V. Chernozhukov, D. Chetverikov, M. Demirer, E. Duflo, C. Hansen, and W. Newey. "Double/debiased/Neyman machine learning of treatment effects". In: American Economic Review 5 (2017).

[CH08] S. R. Cole and M. A. Hernán. "Constructing inverse probability weights for marginal structural models". In: American Journal of Epidemiology (2008).

[Dor16] V. Dorie. Non-parametrics for Causal Inference. https://github.com/vdorie/npci. 2016.

[FLM18] M. H. Farrell, T. Liang, and S. Misra. "Deep neural networks for estimation and inference: application to causal effects and other semiparametric estimands". In: arXiv e-prints arXiv:1809.09953 (2018).

[Gir+14] R. B. Girshick, J. Donahue, T. Darrell, and J. Malik. "Rich feature hierarchies for accurate object detection and semantic segmentation". In: 2014 IEEE Conference on Computer Vision and Pattern Recognition (2014).

[Hil11] J. L. Hill. "Bayesian nonparametric modeling for causal inference". In: Journal of Computational and Graphical Statistics 1 (2011).

[JSS16] F. D. Johansson, U. Shalit, and D. Sontag. "Learning representations for counterfactual inference". In: arXiv e-prints arXiv:1605.03661 (2016).

[KS07] J. D. Y. Kang and J. L. Schafer. "Demystifying double robustness: a comparison of alternative strategies for estimating a population mean from incomplete data". In: Statistical Science 4 (2007).

[Ken16] E. H. Kennedy. "Semiparametric theory and empirical processes in causal inference". In: Statistical Causal Inferences and Their Applications in Public Health Research. 2016.

[LG16] M. van der Laan and S. Gruber. "One-step targeted minimum loss-based estimation based on universal least favorable one-dimensional submodels". In: The International Journal of Biostatistics (2016).

[Lou+17] C. Louizos, U. Shalit, J. M. Mooij, D. Sontag, R. Zemel, and M. Welling. "Causal effect inference with deep latent-variable models". In: Advances in Neural Information Processing Systems. 2017.

[MA98] M. F. MacDorman and J. O. Atkinson. "Infant mortality statistics from the linked birth/infant death". In: Mon Vital Stat Rep 46(suppl 2):1–22 (1998).

[Pot93] F. J. Potter. "The effect of weight trimming on nonlinear survey estimates". In: Proceedings of the American Statistical Association, Section on Survey Research Methods. 1993.

[Rob00] J. M. Robins. "Robust estimation in sequentially ignorable missing data and causal inference models". In: ASA Proceedings of the Section on Bayesian Statistical Science (2000).

[RRL00] J. M. Robins, A. Rotnitzky, and M. van der Laan. "On profile likelihood: comment". In: Journal of the American Statistical Association 450 (2000).

[RR83] P. R. Rosenbaum and D. B. Rubin. "The central role of the propensity score in observational studies for causal effects". In: Biometrika 1 (1983).

[SSI13] N. Saquib, J. Saquib, and J. P. A. Ioannidis. "Practices and impact of primary outcome adjustment in randomized controlled trials: meta-epidemiologic study". In: BMJ (2013).

[SRR99] D. O. Scharfstein, A. Rotnitzky, and J. M. Robins. "Adjusting for nonignorable drop-out using semiparametric nonresponse models". In: Journal of the American Statistical Association 448 (1999).

[SLK18] P. Schwab, L. Linhardt, and W. Karlen. "Perfect match: a simple method for learning representations for counterfactual inference with neural networks". In: arXiv e-prints arXiv:1810.00656 (2018).

[SJS16] U. Shalit, F. D. Johansson, and D. Sontag. "Estimating individual treatment effect: generalization bounds and algorithms". In: arXiv e-prints arXiv:1606.03976 (2016).

[Shi+18] Y. Shimoni, C. Yanover, E. Karavani, and Y. Goldschmidt. "Benchmarking framework for performance-evaluation of causal inference analysis". In: arXiv e-prints arXiv:1802.05046 (2018).

[vR11] M. van der Laan and S. Rose. Targeted Learning: Causal Inference for Observational and Experimental Data. 2011.

[VWB19] V. Veitch, Y. Wang, and D. M. Blei. "Using embeddings to correct for unobserved confounding in networks". In: Advances in Neural Information Processing Systems. 2019.

[VSB19] V. Veitch, D. Sridhar, and D. M. Blei. "Using text embeddings for causal inference". In: arXiv e-prints (2019).

[YJS18] J. Yoon, J. Jordon, and M. van der Schaar. "GANITE: estimation of individualized treatment effects using generative adversarial nets". In: International Conference on Learning Representations. 2018.