Conditional Average Treatment Effects (CATE) with DoWhy and EconML

This is an experimental feature where we use EconML methods from DoWhy. Using EconML allows CATE estimation with a variety of methods.

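All four steps of causal inference in DoWhy stay the same: model, identify, estimate, and refute. The key difference is that we now call EconML methods in the estimation step. There is also a simple example using linear regression to build intuition for how CATE estimators work.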
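For reference, this is the standard definition of the quantity estimated throughout, written for this notebook's continuous treatment v0 moved from control_value=0 to treatment_value=1, with X denoting the effect modifiers (X0, X1). The ATE is the same contrast averaged over the distribution of X.

```latex
\mathrm{CATE}(x) = \mathbb{E}\left[\, y \mid \mathrm{do}(v_0 = 1),\ X = x \,\right]
                 - \mathbb{E}\left[\, y \mid \mathrm{do}(v_0 = 0),\ X = x \,\right],
\qquad
\mathrm{ATE} = \mathbb{E}_X\left[\, \mathrm{CATE}(X) \,\right]
```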

[1]:
import os, sys
sys.path.insert(1, os.path.abspath("../../../"))  # for dowhy source code
[1]:
import numpy as np
import pandas as pd
import logging

import dowhy
from dowhy import CausalModel
import dowhy.datasets

import econml
import warnings
warnings.filterwarnings('ignore')
[5]:
data = dowhy.datasets.linear_dataset(beta=10, num_common_causes=4, num_samples=10000,
                                    num_instruments=2, num_effect_modifiers=2,
                                    num_treatments=1,
                                    treatment_is_binary=False,
#                                     num_discrete_common_causes=2,
#                                     num_discrete_effect_modifiers=0,
#                                     one_hot_encode=False
                                    )
df=data['df']
df.head()
[5]:
X0 X1 Z0 Z1 W0 W1 W2 W3 v0 y
0 -2.521464 0.155409 1.0 0.559138 2.007452 0.194488 -0.737834 0.949250 18.744430 144.319845
1 -0.272124 -0.594038 0.0 0.562656 2.412827 -0.966005 1.208103 -0.916247 5.594197 56.033465
2 1.471747 0.949394 0.0 0.537858 0.078511 -0.334681 1.886087 -0.414642 10.707244 143.527164
3 -0.714296 -1.157574 1.0 0.650554 -0.396513 2.081661 -1.689673 0.317751 16.975189 130.548563
4 -0.895241 0.463742 1.0 0.773794 -0.099374 0.707962 1.917918 -2.526865 16.929208 165.811436
[7]:
model = CausalModel(data=data["df"],
                    treatment=data["treatment_name"], outcome=data["outcome_name"],
                    graph=data["gml_graph"])

model.view_model()

from IPython.display import Image, display
display(Image(filename="causal_model.png"))
INFO:dowhy.causal_model:Model to find the causal effect of treatment ['v0'] on outcome ['y']
[Figure: causal graph of the generated dataset, as drawn by model.view_model()]
[8]:
identified_estimand= model.identify_effect(proceed_when_unidentifiable=True)
print(identified_estimand)
INFO:dowhy.causal_identifier:Common causes of treatment and outcome:['Unobserved Confounders', 'W0', 'W1', 'W3', 'W2']
WARNING:dowhy.causal_identifier:If this is observed data (not from a randomized experiment), there might always be missing confounders. Causal effect cannot be identified perfectly.
INFO:dowhy.causal_identifier:Continuing by ignoring these unobserved confounders because proceed_when_unidentifiable flag is True.
INFO:dowhy.causal_identifier:Instrumental variables for treatment and outcome:['Z0', 'Z1']
Estimand type: nonparametric-ate
### Estimand : 1
Estimand name: backdoor
Estimand expression:
  d
─────(Expectation(y|W0,W1,W3,W2))
d[v₀]
Estimand assumption 1, Unconfoundedness: If U→{v0} and U→y then P(y|v0,W0,W1,W3,W2,U) = P(y|v0,W0,W1,W3,W2)
### Estimand : 2
Estimand name: iv
Estimand expression:
Expectation(Derivative(y, [Z0, Z1])*Derivative([v0], [Z0, Z1])**(-1))
Estimand assumption 1, As-if-random: If U→→y then ¬(U →→{Z0,Z1})
Estimand assumption 2, Exclusion: If we remove {Z0,Z1}→{v0}, then ¬({Z0,Z1}→y)

Linear model

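First, let us build some intuition by estimating the CATE with a linear model. The effect modifiers (the variables that make the causal effect heterogeneous) can be modeled as interaction terms with the treatment, so their values modulate the effect of the treatment.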
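To make the interaction-term idea concrete, here is a minimal numpy sketch on simulated data (an assumption for illustration, not the notebook's dataset). It fits y on v, w, x, and v*x by least squares, which mirrors the realized estimand y~v0+W0+W1+W3+W2+v0*X1+v0*X0 that DoWhy prints below. The effect of moving v from 0 to 1 at a given x is then the coefficient on v plus the interaction coefficient times x. The helper name `linear_cate` is hypothetical.

```python
import numpy as np

# Simulated data (assumption): the true CATE at effect-modifier value x is 10 + 3*x
rng = np.random.default_rng(0)
n = 5000
w = rng.normal(size=n)             # common cause of treatment and outcome
x = rng.normal(size=n)             # effect modifier
v = 0.5 * w + rng.normal(size=n)   # continuous treatment, confounded by w
y = 2.0 * w + (10.0 + 3.0 * x) * v + rng.normal(size=n)

# Design matrix: intercept, treatment, common cause, modifier main effect, interaction v*x
design = np.column_stack([np.ones(n), v, w, x, v * x])
coef, *_ = np.linalg.lstsq(design, y, rcond=None)
b_v, b_vx = coef[1], coef[4]

def linear_cate(x_value, control=0.0, treatment=1.0):
    """Hypothetical helper: effect of moving v from control to treatment at modifier value x."""
    return (b_v + b_vx * x_value) * (treatment - control)

# Averaging the per-unit effect over all units gives an ATE-style summary
print(b_v, b_vx, float(np.mean(linear_cate(x))))
```

Because the outcome is linear in v for every fixed x, the contrast between treatment_value=1 and control_value=0 is exactly the slope in v at that x.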

Below is the estimated effect of changing the treatment from 0 to 1.

[9]:
linear_estimate = model.estimate_effect(identified_estimand,
                                        method_name="backdoor.linear_regression",
                                        control_value=0,
                                        treatment_value=1)
print(linear_estimate)
INFO:dowhy.causal_estimator:INFO: Using Linear Regression Estimator
INFO:dowhy.causal_estimator:b: y~v0+W0+W1+W3+W2+v0*X1+v0*X0
*** Causal Estimate ***

## Target estimand
Estimand type: nonparametric-ate
### Estimand : 1
Estimand name: backdoor
Estimand expression:
  d
─────(Expectation(y|W0,W1,W3,W2))
d[v₀]
Estimand assumption 1, Unconfoundedness: If U→{v0} and U→y then P(y|v0,W0,W1,W3,W2,U) = P(y|v0,W0,W1,W3,W2)
### Estimand : 2
Estimand name: iv
Estimand expression:
Expectation(Derivative(y, [Z0, Z1])*Derivative([v0], [Z0, Z1])**(-1))
Estimand assumption 1, As-if-random: If U→→y then ¬(U →→{Z0,Z1})
Estimand assumption 2, Exclusion: If we remove {Z0,Z1}→{v0}, then ¬({Z0,Z1}→y)

## Realized estimand
b: y~v0+W0+W1+W3+W2+v0*X1+v0*X0
## Estimate
Value: 10.00000000000001

EconML methods

Now we move on to more advanced CATE estimation methods from the EconML package.

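First, let us look at the double machine learning estimator. method_name corresponds to the fully qualified name of the class we want to use. For double ML, it is "econml.dml.DMLCateEstimator", which DoWhy expects prefixed with the estimand type, i.e. "backdoor.econml.dml.DMLCateEstimator".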
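As intuition for what double ML does, here is a minimal numpy sketch of cross-fitted "partialling out" with a linear final stage θ·[1, X], which is what a featurizer like PolynomialFeatures(degree=1, include_bias=True) produces. It is not EconML's implementation. Assumptions: plain least squares replaces GradientBoostingRegressor for the nuisance models and LassoCV for the final model, and the data are simulated. The names `ols_fit_predict` and `dml_linear_cate` are hypothetical.

```python
import numpy as np

def ols_fit_predict(train_X, train_y, test_X):
    """Least-squares fit with an intercept on the training fold; predict on the test fold."""
    A = np.column_stack([np.ones(len(train_X)), train_X])
    coef, *_ = np.linalg.lstsq(A, train_y, rcond=None)
    return np.column_stack([np.ones(len(test_X)), test_X]) @ coef

def dml_linear_cate(y, t, X, W, n_folds=2, seed=0):
    """Cross-fitted DML: residualize y and t on (X, W), then regress y_res on t_res * [1, X]."""
    n = len(y)
    controls = np.column_stack([X, W])                  # nuisance models condition on X and W
    folds = np.random.default_rng(seed).permutation(n) % n_folds
    y_res, t_res = np.empty(n), np.empty(n)
    for k in range(n_folds):
        train, test = folds != k, folds == k            # fit on other folds, predict on fold k
        y_res[test] = y[test] - ols_fit_predict(controls[train], y[train], controls[test])
        t_res[test] = t[test] - ols_fit_predict(controls[train], t[train], controls[test])
    phi = np.column_stack([np.ones(n), X])              # degree-1 features with a bias column
    theta, *_ = np.linalg.lstsq(t_res[:, None] * phi, y_res, rcond=None)
    return theta                                        # CATE(x) = theta[0] + theta[1:] @ x

# Simulated data (assumption): true CATE(x) = 10 + 3 * x0
rng = np.random.default_rng(1)
n = 5000
W = rng.normal(size=(n, 2))
X = rng.normal(size=(n, 1))
t = W @ np.array([0.8, -0.5]) + rng.normal(size=n)
y = W @ np.array([1.5, 2.0]) + (10.0 + 3.0 * X[:, 0]) * t + rng.normal(size=n)
theta = dml_linear_cate(y, t, X, W)
print(theta)
```

Cross-fitting (predicting each fold's residuals with models trained on the other folds) keeps overfitting in the nuisance models from leaking into the final-stage estimate.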

target_units defines the units over which the causal estimate is computed. It can be a lambda function filter on the original dataframe, a new Pandas dataframe, or a string naming one of the three main kinds of target units ("ate", "att" and "atc"). Below we show an example of a lambda function.
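Conceptually, each target_units option picks the set of rows over which a per-unit CATE is averaged. The sketch below illustrates that idea with numpy arrays instead of a dataframe. It is not DoWhy's code: `cate` is a made-up fitted model, `target_unit_effect` is a hypothetical helper, and the "att"/"atc" masks assume a binary 0/1 treatment (this notebook's treatment is continuous).

```python
import numpy as np

def cate(X):
    """Hypothetical fitted CATE model (made-up coefficients): 10 + 3 * x0."""
    return 10.0 + 3.0 * X[:, 0]

def target_unit_effect(X, t, target_units):
    """Hypothetical helper: average the per-unit CATE over the selected target units."""
    if isinstance(target_units, str):
        # Assumes a binary treatment t for "att" (treated) and "atc" (controls)
        masks = {"ate": np.ones(len(X), dtype=bool), "att": t == 1, "atc": t == 0}
        return float(np.mean(cate(X[masks[target_units]])))
    if callable(target_units):
        # Filter function, analogous to lambda df: df["X0"] > 1
        return float(np.mean(cate(X[target_units(X)])))
    # Otherwise treat the argument as new rows of effect-modifier values
    return float(np.mean(cate(np.asarray(target_units))))

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 2))
t = rng.integers(0, 2, size=1000)
print(target_unit_effect(X, t, "ate"))
print(target_unit_effect(X, t, lambda X: X[:, 0] > 1))
print(target_unit_effect(X, t, np.array([[0.0, 0.5], [1.0, 0.2]])))  # → 11.5
```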

method_params are passed directly to EconML. For details on the allowed parameters, refer to the EconML documentation.
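The dictionary has two parts. As we understand it, DoWhy roughly unpacks init_params into the EconML class constructor and fit_params into its fit call, which is why model_y, model_t, model_final and featurizer appear under init_params while inference="bootstrap" appears under fit_params later in this notebook. The toy class below is hypothetical (not an EconML class) and only illustrates that unpacking pattern.

```python
class ToyCateEstimator:
    """Hypothetical stand-in for an EconML estimator class; not a real EconML API."""

    def __init__(self, model_y=None, model_t=None, model_final=None, featurizer=None):
        # Constructor arguments correspond to the "init_params" entries
        self.model_y, self.model_t = model_y, model_t
        self.model_final, self.featurizer = model_final, featurizer
        self.fit_kwargs = None

    def fit(self, Y, T, X=None, W=None, inference=None):
        # Keyword arguments of fit correspond to the "fit_params" entries
        self.fit_kwargs = {"inference": inference}
        return self

# Placeholder strings stand in for the scikit-learn model objects used in the notebook
method_params = {
    "init_params": {"model_y": "gbr", "model_t": "gbr", "model_final": "lasso", "featurizer": "poly1"},
    "fit_params": {"inference": "bootstrap"},
}

est = ToyCateEstimator(**method_params["init_params"])
est.fit([1.0, 2.0], [0.0, 1.0], **method_params["fit_params"])
print(est.featurizer, est.fit_kwargs)
```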

[10]:
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LassoCV
from sklearn.ensemble import GradientBoostingRegressor
dml_estimate = model.estimate_effect(identified_estimand,
                                     method_name="backdoor.econml.dml.DMLCateEstimator",
                                     control_value=0,
                                     treatment_value=1,
                                     target_units=lambda df: df["X0"] > 1,  # condition used for CATE
                                     confidence_intervals=False,
                                     method_params={"init_params": {"model_y": GradientBoostingRegressor(),
                                                                    "model_t": GradientBoostingRegressor(),
                                                                    "model_final": LassoCV(),
                                                                    "featurizer": PolynomialFeatures(degree=1, include_bias=True)},
                                                    "fit_params": {}})
print(dml_estimate)
INFO:dowhy.causal_estimator:INFO: Using EconML Estimator
INFO:dowhy.causal_estimator:b: y~v0+W0+W1+W3+W2
INFO:numexpr.utils:NumExpr defaulting to 4 threads.
*** Causal Estimate ***

## Target estimand
Estimand type: nonparametric-ate
### Estimand : 1
Estimand name: backdoor
Estimand expression:
  d
─────(Expectation(y|W0,W1,W3,W2))
d[v₀]
Estimand assumption 1, Unconfoundedness: If U→{v0} and U→y then P(y|v0,W0,W1,W3,W2,U) = P(y|v0,W0,W1,W3,W2)
### Estimand : 2
Estimand name: iv
Estimand expression:
Expectation(Derivative(y, [Z0, Z1])*Derivative([v0], [Z0, Z1])**(-1))
Estimand assumption 1, As-if-random: If U→→y then ¬(U →→{Z0,Z1})
Estimand assumption 2, Exclusion: If we remove {Z0,Z1}→{v0}, then ¬({Z0,Z1}→y)

## Realized estimand
b: y~v0+W0+W1+W3+W2
## Estimate
Value: 12.377928464975215

[11]:
print("True causal estimate is", data["ate"])
True causal estimate is 9.978706628277308
[12]:
dml_estimate = model.estimate_effect(identified_estimand,
                                     method_name="backdoor.econml.dml.DMLCateEstimator",
                                     control_value=0,
                                     treatment_value=1,
                                     target_units=1,  # no filter condition here (compare with the lambda in the previous cell)
                                     confidence_intervals=False,
                                     method_params={"init_params": {"model_y": GradientBoostingRegressor(),
                                                                    "model_t": GradientBoostingRegressor(),
                                                                    "model_final": LassoCV(),
                                                                    "featurizer": PolynomialFeatures(degree=1, include_bias=True)},
                                                    "fit_params": {}})
print(dml_estimate)
INFO:dowhy.causal_estimator:INFO: Using EconML Estimator
INFO:dowhy.causal_estimator:b: y~v0+W0+W1+W3+W2
*** Causal Estimate ***

## Target estimand
Estimand type: nonparametric-ate
### Estimand : 1
Estimand name: backdoor
Estimand expression:
  d
─────(Expectation(y|W0,W1,W3,W2))
d[v₀]
Estimand assumption 1, Unconfoundedness: If U→{v0} and U→y then P(y|v0,W0,W1,W3,W2,U) = P(y|v0,W0,W1,W3,W2)
### Estimand : 2
Estimand name: iv
Estimand expression:
Expectation(Derivative(y, [Z0, Z1])*Derivative([v0], [Z0, Z1])**(-1))
Estimand assumption 1, As-if-random: If U→→y then ¬(U →→{Z0,Z1})
Estimand assumption 2, Exclusion: If we remove {Z0,Z1}→{v0}, then ¬({Z0,Z1}→y)

## Realized estimand
b: y~v0+W0+W1+W3+W2
## Estimate
Value: 9.919618287281304

CATE object and confidence intervals

[13]:
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LassoCV
from sklearn.ensemble import GradientBoostingRegressor
dml_estimate = model.estimate_effect(identified_estimand,
                                     method_name="backdoor.econml.dml.DMLCateEstimator",
                                     target_units=lambda df: df["X0"] > 1,
                                     confidence_intervals=True,
                                     method_params={"init_params": {"model_y": GradientBoostingRegressor(),
                                                                    "model_t": GradientBoostingRegressor(),
                                                                    "model_final": LassoCV(),
                                                                    "featurizer": PolynomialFeatures(degree=1, include_bias=True)},
                                                    "fit_params": {"inference": "bootstrap"}})
print(dml_estimate)
print(dml_estimate.cate_estimates[:10])
print(dml_estimate.effect_intervals)
INFO:dowhy.causal_estimator:INFO: Using EconML Estimator
INFO:dowhy.causal_estimator:b: y~v0+W0+W1+W3+W2
[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 4 concurrent workers.
[Parallel(n_jobs=-1)]: Done  24 tasks      | elapsed:   17.9s
[Parallel(n_jobs=-1)]: Done 100 out of 100 | elapsed:  1.2min finished
[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 4 concurrent workers.
[Parallel(n_jobs=-1)]: Done  24 tasks      | elapsed:    0.0s
[Parallel(n_jobs=-1)]: Done 100 out of 100 | elapsed:    0.1s finished
*** Causal Estimate ***

## Target estimand
Estimand type: nonparametric-ate
### Estimand : 1
Estimand name: backdoor
Estimand expression:
  d
─────(Expectation(y|W0,W1,W3,W2))
d[v₀]
Estimand assumption 1, Unconfoundedness: If U→{v0} and U→y then P(y|v0,W0,W1,W3,W2,U) = P(y|v0,W0,W1,W3,W2)
### Estimand : 2
Estimand name: iv
Estimand expression:
Expectation(Derivative(y, [Z0, Z1])*Derivative([v0], [Z0, Z1])**(-1))
Estimand assumption 1, As-if-random: If U→→y then ¬(U →→{Z0,Z1})
Estimand assumption 2, Exclusion: If we remove {Z0,Z1}→{v0}, then ¬({Z0,Z1}→y)

## Realized estimand
b: y~v0+W0+W1+W3+W2
## Estimate
Value: 12.29197180689786

[[13.07887647]
 [10.77809375]
 [13.71470538]
 [12.2937543 ]
 [11.72053516]
 [10.7891811 ]
 [12.77571337]
 [13.15299296]
 [12.73256256]
 [ 9.99162377]]
(array([[12.87146448],
       [10.63969365],
       [13.49162136],
       [12.10427793],
       [11.55423673],
       [10.65016767],
       [12.53653385],
       [12.94199474],
       [12.51816049],
       [ 9.88026419],
       [11.87817487],
       [11.71871174],
       [ 9.4148536 ],
       [14.34614983],
       [10.46412705],
       [12.56785538],
       [13.83984034],
       [13.4679312 ],
       [14.35042059],
       [12.35224681],
       [11.61582257],
       [10.29811664],
       [13.59903865],
       [ 8.55052931],
       [13.45880418],
       [13.77643288],
       [12.19587955],
       [10.83290248],
       [10.7777583 ],
       [13.94236224],
       [13.01055602],
       [12.50870928],
       [14.25350222],
       [11.481481  ],
       [13.74603396],
       [10.69244058],
       [13.45605579],
       [11.66522   ],
       [14.25761843],
       [11.36297136],
       [14.88296605],
       [13.55224443],
       [14.17487503],
       [11.1713894 ],
       [10.68895161],
       [15.59339931],
       [10.93335862],
       [11.61122488],
       [13.07660271],
       [13.48912346],
       [10.73406571],
       [10.62513206],
       [13.56689894],
       [12.15347773],
       [13.92805453],
       [10.76597466],
       [12.465917  ],
       [13.00804687],
       [13.78734038],
       [11.24303223],
       [13.34725069],
       [14.0715816 ],
       [10.90278155],
       [11.42205258],
       [11.32769271],
       [ 9.86228705],
       [11.39895884],
       [12.73089948],
       [11.50841076],
       [13.59595896],
       [14.09253805],
       [ 8.55421913],
       [12.43520073],
       [13.38545776],
       [ 9.92893184],
       [14.1753038 ],
       [10.10168566],
       [12.11887461],
       [10.64007044],
       [10.66021766],
       [14.03788148],
       [10.2426245 ],
       [13.48656247],
       [10.28762964],
       [13.32024058],
       [11.18720081],
       [13.17964278],
       [12.20751813],
       [13.27962142],
       [11.07746677],
       [13.96154738],
       [10.52742145],
       [14.66574205],
       [11.45737787],
       [12.13263397],
       [ 7.28664916],
       [ 9.1369182 ],
       [12.10918153],
       [12.58778935],
       [10.72293532],
       [13.5342271 ],
       [10.49410814],
       [ 9.33320491],
       [12.45948473],
       [11.45174567],
       [10.24555438],
       [13.62177991],
       [ 9.9130219 ],
       [14.34600261],
       [14.43809285],
       [ 9.66878299],
       [10.07858596],
       [12.05202067],
       [12.70471663],
       [12.0662915 ],
       [10.51067217],
       [ 8.8605647 ],
       [ 9.38853094],
       [10.79353383],
       [13.65557565],
       [11.48356879],
       [10.45198197],
       [13.42241315],
       [11.64821402],
       [12.32981612],
       [12.3897112 ],
       [12.78012245],
       [ 9.3541085 ],
       [12.39124123],
       [11.49473942],
       [15.60255523],
       [11.03254036],
       [14.48455651],
       [12.42579355],
       [13.89661424],
       [10.86360454],
       [11.57337619],
       [13.54398826],
       [11.89090609],
       [10.84262207],
       [13.54875206],
       [11.57927462],
       [12.50655234],
       [13.27503301],
       [12.27626145],
       [10.89800332],
       [13.80514938],
       [10.56707828],
       [ 9.66292243],
       [13.79325233],
       [13.05331069],
       [13.08415392],
       [ 9.88640943],
       [11.1954592 ],
       [11.67724679],
       [14.79144052],
       [13.73341175],
       [10.29388788],
       [13.25140371],
       [10.18888078],
       [12.00530852],
       [11.13523626],
       [11.42142756],
       [13.00333764],
       [10.55216177],
       [13.7069508 ],
       [ 9.939213  ],
       [ 8.44011804],
       [11.36589106],
       [12.67601568],
       [12.7563392 ],
       [15.03531603],
       [14.06365864],
       [13.64215571],
       [12.34555147],
       [10.77237273],
       [10.44920636],
       [13.75170402],
       [13.22655426],
       [11.97592994],
       [11.06976618],
       [10.6006746 ],
       [14.70619444],
       [12.53276561],
       [ 9.77386486],
       [10.24176779],
       [14.65933942],
       [13.23228319],
       [13.41149638],
       [15.63274955],
       [11.54459279],
       [12.16636735],
       [ 9.57313977],
       [ 9.68443792],
       [11.92627929],
       [10.75033796],
       [13.75137642],
       [ 9.37851181],
       [13.59691607],
       [11.99425813],
       [13.77435554],
       [12.32905697],
       [10.0170267 ],
       [11.98794031],
       [12.3031148 ],
       [11.46157776],
       [16.51520556],
       [12.82192023],
       [11.48265185],
       [12.23133619],
       [12.41014321],
       [13.41517545],
       [10.00409418],
       [12.93733235],
       [10.92526768],
       [11.21242851],
       [11.76314686],
       [13.48521122],
       [12.84019928],
       [10.52288634],
       [13.84317597],
       [11.52142684],
       [11.18409827],
       [12.50920479],
       [11.56451181],
       [11.77683623],
       [12.00109072],
       [10.36001371],
       [13.97656052],
       [13.72559367],
       [ 8.93137977],
       [11.50617527],
       [11.17749833],
       [ 9.66613345],
       [11.83224923],
       [10.41976534],
       [13.39673001],
       [12.29996792],
       [12.58169478],
       [11.88588282],
       [ 8.56474626],
       [13.35990226],
       [11.12620209],
       [ 9.97506967],
       [11.82271745],
       [11.35257882],
       [13.33420405],
       [11.69218399],
       [12.69285982],
       [11.43339827],
       [11.56324602],
       [11.90482041],
       [13.49130632],
       [12.27140265],
       [14.18647427],
       [14.03699062],
       [13.74436576],
       [11.53214309],
       [11.05060965],
       [12.71232129],
       [12.10131898],
       [12.95535412],
       [ 8.28023453],
       [12.65453404],
       [13.01294309],
       [11.49852355],
       [12.81940313],
       [12.36242016],
       [14.55547993],
       [12.94852281],
       [12.09273988],
       [12.59353021],
       [11.24883235],
       [ 9.72728484],
       [14.56589286],
       [11.71223304],
       [13.38675392],
       [11.35694388],
       [ 9.16338904],
       [11.33505823],
       [ 9.6709106 ],
       [11.37968799],
       [12.27489278],
       [ 9.20931755],
       [15.72028541],
       [12.6571529 ],
       [14.12789102],
       [13.78419097],
       [13.94380542],
       [11.41614025],
       [11.91983963],
       [11.95342121],
       [11.78427156],
       [10.53913693],
       [12.61911117],
       [13.0776739 ],
       [12.99906085],
       [13.16019801],
       [ 9.43333915],
       [14.1179648 ],
       [13.68698569],
       [10.79001288],
       [11.45795032],
       [10.87676189],
       [ 9.70138916],
       [12.30902341],
       [11.02233042],
       [11.11225944],
       [10.59053996],
       [13.67246118],
       [13.01201609],
       [12.4485459 ],
       [12.14441008],
       [14.65459005],
       [14.2205289 ],
       [14.33755189],
       [ 9.76317934],
       [10.39151872],
       [14.06892891],
       [11.69943786],
       [14.13612428],
       [10.31022793],
       [13.59109518],
       [12.43317473],
       [12.73157091],
       [11.7809359 ],
       [12.05809952],
       [13.57855211],
       [11.81946644],
       [12.03365272],
       [ 9.69284722],
       [12.61077958],
       [10.1974851 ],
       [10.69373057],
       [10.14790369],
       [ 9.5376557 ],
       [13.62622577],
       [13.69692541],
       [12.96195904],
       [13.68401556],
       [11.64764038],
       [13.35338454],
       [12.2906384 ],
       [13.63119092],
       [11.84476677],
       [ 9.77711082],
       [16.14865539],
       [12.66458613],
       [13.91513767],
       [10.89029429],
       [12.15157721],
       [13.60672934],
       [ 9.85200774],
       [15.24536968],
       [ 8.75469341],
       [12.56884484],
       [11.59634557],
       [11.83788921],
       [11.94704183],
       [10.86275636],
       [14.29852981],
       [13.49962816],
       [12.44167422],
       [14.78408918],
       [12.69767754],
       [13.62519429],
       [14.36100415],
       [14.37605106],
       [11.5872207 ],
       [12.22068667],
       [14.75781157],
       [12.51409019],
       [10.86837543],
       [12.46988342],
       [12.61025858],
       [11.24015865],
       [11.18835223],
       [10.79716704],
       [11.16863208],
       [17.8891018 ],
       [ 9.02508256],
       [10.93685663],
       [11.68417659],
       [13.13925115],
       [12.59156738],
       [10.55995775],
       [14.05418231],
       [10.10209501],
       [ 9.6423258 ],
       [12.97532986],
       [12.29441667],
       [12.55389477],
       [10.68272711],
       [10.62653066],
       [13.6699227 ],
       [10.52616547],
       [11.70239017],
       [10.14030339],
       [13.13935322],
       [11.95437062],
       [12.17341794],
       [12.91915496],
       [11.32004661],
       [11.64838252],
       [12.56425687],
       [11.77844128],
       [14.65835339],
       [ 9.70707398],
       [14.9300481 ],
       [12.41583695],
       [10.51673836],
       [15.44266874],
       [10.29119328],
       [12.4698904 ],
       [11.66472781],
       [13.25203398],
       [14.60736181],
       [11.7148831 ],
       [ 8.71781083],
       [11.67637186],
       [14.19799786],
       [11.69866959],
       [ 9.92887925],
       [ 9.71813268],
       [12.02030602],
       [12.41740233],
       [10.47762016],
       [12.20119777],
       [10.48521344],
       [10.67327157],
       [11.87827654],
       [11.58606279],
       [ 8.9557734 ],
       [12.64025621],
       [11.79850084],
       [14.84465251],
       [11.85336354],
       [15.0924792 ],
       [11.49130477],
       [13.30705909],
       [10.28816464],
       [10.83173928],
       [10.90623076],
       [12.63904975],
       [11.19355693],
       [14.31466264],
       [11.92317275],
       [12.35268268],
       [11.0891759 ],
       [12.68793201],
       [12.61899466],
       [12.5101173 ],
       [12.87242751],
       [14.39594312],
       [12.5929976 ],
       [14.19507331],
       [11.5081742 ],
       [11.09412734],
       [10.90389304],
       [12.19205614],
       [14.54447031],
       [12.03885945],
       [11.07151343],
       [12.15146514],
       [13.4617387 ]]), array([[13.07894792],
       [10.84575642],
       [13.69590258],
       [12.31232205],
       [11.75347935],
       [10.85714818],
       [12.84426434],
       [13.1444251 ],
       [12.76482436],
       [10.06939464],
       [12.06387706],
       [11.89619409],
       [ 9.61847512],
       [14.58858477],
       [10.63703602],
       [12.74834127],
       [14.07150676],
       [13.69165376],
       [14.59545699],
       [12.6189523 ],
       [11.77948879],
       [10.55910703],
       [13.81663594],
       [ 8.80099987],
       [13.66111391],
       [13.99053663],
       [12.40390807],
       [11.01583279],
       [10.94622089],
       [14.1953548 ],
       [13.22704109],
       [12.69790294],
       [14.48992041],
       [11.71906794],
       [13.98686407],
       [10.91538098],
       [13.66939961],
       [11.91477002],
       [14.49411742],
       [11.59968286],
       [15.17467009],
       [13.78771275],
       [14.40876507],
       [11.38332409],
       [10.86291076],
       [15.91855765],
       [11.12202599],
       [11.8139945 ],
       [13.2785676 ],
       [13.75410866],
       [10.94500534],
       [10.81665283],
       [13.77728066],
       [12.35437864],
       [14.15992469],
       [10.96932092],
       [12.65282316],
       [13.20062563],
       [14.05003482],
       [11.40968384],
       [13.54395628],
       [14.32499006],
       [11.10612758],
       [11.653047  ],
       [11.5233189 ],
       [10.13225458],
       [11.57943789],
       [12.96977652],
       [11.67634853],
       [13.80597741],
       [14.32265421],
       [ 8.78074485],
       [12.62539097],
       [13.62372143],
       [10.13220664],
       [14.41165466],
       [10.30313441],
       [12.29137817],
       [10.85908844],
       [10.86997202],
       [14.31642074],
       [10.45987268],
       [13.69175387],
       [10.49835313],
       [13.52140227],
       [11.40790524],
       [13.40323376],
       [12.38705698],
       [13.47856897],
       [11.25691355],
       [14.19044117],
       [10.70758337],
       [14.92757739],
       [11.62831345],
       [12.30570674],
       [ 7.64964432],
       [ 9.34532699],
       [12.28128793],
       [12.77242366],
       [10.91473451],
       [13.7431299 ],
       [10.66586304],
       [ 9.60098011],
       [12.63828052],
       [11.61905829],
       [10.452038  ],
       [13.83878977],
       [10.16672066],
       [14.5952389 ],
       [14.69404573],
       [ 9.89345553],
       [10.27842252],
       [12.22190878],
       [12.90384068],
       [12.23505028],
       [10.71683181],
       [ 9.17537065],
       [ 9.61016646],
       [10.95801958],
       [13.87546345],
       [11.64645976],
       [10.65094227],
       [13.65933812],
       [11.83807144],
       [12.5110824 ],
       [12.56438347],
       [12.98134699],
       [ 9.55618822],
       [12.57505298],
       [11.70467296],
       [15.9181746 ],
       [11.20365837],
       [14.72973041],
       [12.61204303],
       [14.1176988 ],
       [11.02560145],
       [11.73771415],
       [13.75031897],
       [12.05927797],
       [11.06983217],
       [13.75961738],
       [11.74870414],
       [12.68565666],
       [13.47632584],
       [12.45289081],
       [11.11247564],
       [14.0316403 ],
       [10.73797911],
       [ 9.85968283],
       [14.01599045],
       [13.26567286],
       [13.28723205],
       [10.12892574],
       [11.36787909],
       [11.96631688],
       [15.07638762],
       [13.95008643],
       [10.47858977],
       [13.47164153],
       [10.37580755],
       [12.19070515],
       [11.29649832],
       [11.62707912],
       [13.19232037],
       [10.72747008],
       [13.92321316],
       [10.1347228 ],
       [ 8.67952567],
       [11.52706188],
       [12.86262296],
       [13.00120632],
       [15.34565835],
       [14.33304574],
       [13.88104819],
       [12.58757772],
       [10.94205099],
       [10.63880514],
       [13.98795285],
       [13.43748301],
       [12.14712398],
       [11.23981735],
       [10.77813238],
       [14.98107644],
       [12.75519106],
       [10.0270005 ],
       [10.46155806],
       [14.91236854],
       [13.42385938],
       [13.61151155],
       [15.93569038],
       [11.70527232],
       [12.36329822],
       [ 9.79087368],
       [ 9.89006077],
       [12.10639068],
       [10.91767124],
       [13.97416093],
       [ 9.63250623],
       [13.82480848],
       [12.16197299],
       [13.98801989],
       [12.52974574],
       [10.19889155],
       [12.15568266],
       [12.48171293],
       [11.62335875],
       [16.86451547],
       [13.01885299],
       [11.68630172],
       [12.4091232 ],
       [12.59474146],
       [13.61374044],
       [10.18719813],
       [13.14822494],
       [11.10638275],
       [11.37706048],
       [11.92855488],
       [13.70415573],
       [13.06224186],
       [10.7037102 ],
       [14.07465468],
       [11.6932775 ],
       [11.39504106],
       [12.68941289],
       [11.73351064],
       [11.94484994],
       [12.19332483],
       [10.55123003],
       [14.25136772],
       [13.94359492],
       [ 9.24508724],
       [11.73317058],
       [11.34539596],
       [ 9.86223715],
       [12.0394726 ],
       [10.61688767],
       [13.59578139],
       [12.47307441],
       [12.84799256],
       [12.12199599],
       [ 8.82787628],
       [13.59190191],
       [11.30036597],
       [10.17103972],
       [12.00173376],
       [11.56842465],
       [13.56036342],
       [11.88751906],
       [12.89196416],
       [11.59627951],
       [11.73206905],
       [12.07283589],
       [13.73674178],
       [12.4633716 ],
       [14.42055929],
       [14.28186157],
       [13.96295546],
       [11.79873451],
       [11.36397286],
       [12.91362479],
       [12.28429717],
       [13.14080192],
       [ 8.53768983],
       [12.84093397],
       [13.22400308],
       [11.6941683 ],
       [13.02647421],
       [12.5470457 ],
       [14.80662333],
       [13.13563831],
       [12.26415901],
       [12.78020333],
       [11.44372489],
       [ 9.91833421],
       [14.81724042],
       [11.94907772],
       [13.61663202],
       [11.52133777],
       [ 9.36330527],
       [11.50277897],
       [ 9.86089524],
       [11.59119489],
       [12.45387866],
       [ 9.45278179],
       [16.05417394],
       [12.85200218],
       [14.38519643],
       [13.99780066],
       [14.17443287],
       [11.60165443],
       [12.11103529],
       [12.15113256],
       [11.95415841],
       [10.71023757],
       [12.79762825],
       [13.2981882 ],
       [13.20891865],
       [13.42377608],
       [ 9.64827348],
       [14.35456871],
       [13.92566437],
       [10.99939327],
       [11.65528914],
       [11.04994846],
       [ 9.92106329],
       [12.48806927],
       [11.18931425],
       [11.33297902],
       [10.77017194],
       [13.91126122],
       [13.2107673 ],
       [12.63473459],
       [12.33326877],
       [14.98380379],
       [14.46041511],
       [14.58334991],
       [ 9.95281541],
       [10.59463414],
       [14.30753279],
       [11.87654692],
       [14.3931904 ],
       [10.55168569],
       [13.80315741],
       [12.61872836],
       [12.92412375],
       [11.94505195],
       [12.25156781],
       [13.8079198 ],
       [12.02113797],
       [12.2489677 ],
       [ 9.88561159],
       [12.80330951],
       [10.39572344],
       [10.9004194 ],
       [10.32944656],
       [ 9.76611252],
       [13.86443622],
       [13.92716434],
       [13.16217697],
       [13.89469329],
       [11.83208676],
       [13.58327243],
       [12.46610004],
       [13.85931052],
       [12.09642662],
       [ 9.97731336],
       [16.50113118],
       [12.9005419 ],
       [14.14390891],
       [11.09512146],
       [12.32543621],
       [13.82527797],
       [10.05273339],
       [15.53150804],
       [ 8.97826605],
       [12.74632713],
       [11.80154538],
       [12.01529029],
       [12.11330114],
       [11.06446525],
       [14.53694978],
       [13.75575814],
       [12.62757453],
       [15.04564175],
       [12.88230155],
       [13.86984434],
       [14.65061708],
       [14.64039702],
       [11.75746314],
       [12.39670343],
       [15.01843328],
       [12.73950981],
       [11.0456934 ],
       [12.64732602],
       [12.91153646],
       [11.41520105],
       [11.40522134],
       [10.98135761],
       [11.34642021],
       [18.30848949],
       [ 9.24009505],
       [11.10670151],
       [11.84699945],
       [13.35522616],
       [12.94369597],
       [10.75195423],
       [14.29798799],
       [10.3404489 ],
       [ 9.84047489],
       [13.18302594],
       [12.46754479],
       [12.77318462],
       [10.88590586],
       [10.8146032 ],
       [13.88969196],
       [10.69894254],
       [11.86616752],
       [10.31778116],
       [13.33549707],
       [12.12081293],
       [12.34818768],
       [13.21361609],
       [11.55605457],
       [11.81496619],
       [12.77248692],
       [11.94411167],
       [14.91386647],
       [ 9.90423407],
       [15.21423249],
       [12.59001717],
       [10.76051269],
       [15.74643183],
       [10.50964235],
       [12.65098043],
       [11.91006861],
       [13.50735731],
       [14.86053726],
       [11.90193117],
       [ 8.94482444],
       [11.84485628],
       [14.43969222],
       [11.86809715],
       [10.12221182],
       [ 9.96428682],
       [12.19763954],
       [12.60229413],
       [10.78121232],
       [12.45916522],
       [10.68948636],
       [10.85049627],
       [12.04439998],
       [11.76811358],
       [ 9.19299667],
       [12.83913671],
       [12.01661934],
       [15.12081177],
       [12.02667628],
       [15.36597381],
       [11.67436218],
       [13.54464104],
       [10.46902241],
       [11.02452611],
       [11.10377992],
       [12.83162831],
       [11.41408326],
       [14.56105048],
       [12.12005516],
       [12.54298625],
       [11.33483039],
       [12.88319648],
       [12.86458803],
       [12.69920402],
       [13.05534857],
       [14.67111915],
       [12.7846298 ],
       [14.45120732],
       [11.67365969],
       [11.2566398 ],
       [11.10900855],
       [12.36760447],
       [14.7956271 ],
       [12.25965015],
       [11.27881007],
       [12.39934097],
       [13.70644115]]))
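The effect_intervals printed above is a tuple (lower, upper) of arrays with one row per target unit. To show the idea behind bootstrap intervals, here is a minimal percentile-bootstrap sketch around the simple interaction-term regression: resample units with replacement, refit, and take percentiles of the per-unit CATE across refits. Assumptions: simulated data, OLS instead of double ML, and the percentile method, which may differ from EconML's exact bootstrap procedure. It uses 100 resamples, mirroring the 100 tasks in the log above. The function names are hypothetical.

```python
import numpy as np

def fit_linear_cate(y, v, x):
    """OLS with a v*x interaction; returns (b_v, b_vx) so that CATE(x) = b_v + b_vx * x."""
    A = np.column_stack([np.ones(len(y)), v, x, v * x])
    coef, *_ = np.linalg.lstsq(A, y, rcond=None)
    return coef[1], coef[3]

def bootstrap_cate_intervals(y, v, x, x_eval, n_boot=100, alpha=0.05, seed=0):
    """Percentile-bootstrap (lower, upper) arrays of shape (len(x_eval), 1)."""
    rng = np.random.default_rng(seed)
    n = len(y)
    draws = np.empty((n_boot, len(x_eval)))
    for b in range(n_boot):
        idx = rng.integers(0, n, size=n)            # resample units with replacement
        b_v, b_vx = fit_linear_cate(y[idx], v[idx], x[idx])
        draws[b] = b_v + b_vx * x_eval              # CATE at each evaluation point
    lower = np.percentile(draws, 100 * alpha / 2, axis=0)
    upper = np.percentile(draws, 100 * (1 - alpha / 2), axis=0)
    return lower[:, None], upper[:, None]           # column vectors, like effect_intervals

# Simulated data (assumption): true CATE(x) = 10 + 3 * x
rng = np.random.default_rng(0)
n = 2000
x = rng.normal(size=n)
v = rng.normal(size=n)
y = (10.0 + 3.0 * x) * v + rng.normal(size=n)
x_eval = np.array([0.0, 1.0, 2.0])
lower, upper = bootstrap_cate_intervals(y, v, x, x_eval)
print(np.hstack([lower, upper]))
```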

Causal effects on new inputs

We can also provide new inputs as target units and estimate the CATE on them.
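With a degree-1 featurizer that includes a bias column, a linear final stage evaluates the CATE at a new unit as θ·[1, X0, X1], so only the effect modifiers' values are needed. The numpy sketch below builds new inputs with the same shape and distribution as the test_df in the next cell (10 rows, one uniform [0, 1) column per effect modifier) and evaluates such a model. The coefficients in `theta` are made up for illustration, not taken from the fitted estimator.

```python
import numpy as np

rng = np.random.default_rng(42)
effect_modifier_names = ["X0", "X1"]

# 10 new units, each effect modifier drawn uniformly from [0, 1)
X_new = rng.uniform(0, 1, size=(10, len(effect_modifier_names)))

theta = np.array([10.0, 3.0, -1.0])                      # hypothetical coefficients for [1, X0, X1]
phi_new = np.column_stack([np.ones(len(X_new)), X_new])  # degree-1 features with a bias column
cate_new = (phi_new @ theta)[:, None]                    # shape (10, 1), like cate_estimates
print(cate_new)
```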

[14]:
test_cols = data['effect_modifier_names']  # only the effect modifiers' values are needed
test_arr = [np.random.uniform(0, 1, 10) for _ in range(len(test_cols))]  # each effect modifier sampled uniformly on [0, 1); 10 new units
test_df = pd.DataFrame(np.array(test_arr).transpose(), columns=test_cols)
dml_estimate = model.estimate_effect(identified_estimand,
                                     method_name="backdoor.econml.dml.DMLCateEstimator",
                                     target_units=test_df,
                                     confidence_intervals=False,
                                     method_params={"init_params": {"model_y": GradientBoostingRegressor(),
                                                                    "model_t": GradientBoostingRegressor(),
                                                                    "model_final": LassoCV(),
                                                                    "featurizer": PolynomialFeatures(degree=1, include_bias=True)},
                                                    "fit_params": {}})
print(dml_estimate.cate_estimates)
INFO:dowhy.causal_estimator:INFO: Using EconML Estimator
INFO:dowhy.causal_estimator:b: y~v0+W0+W1+W3+W2
[[12.16905407]
 [10.42165624]
 [10.67733797]
 [11.62091416]
 [10.98568175]
 [12.5880907 ]
 [10.96991444]
 [10.46498503]
 [10.6584332 ]
 [11.35387743]]
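With `featurizer=PolynomialFeatures(degree=1, include_bias=True)` and a linear final model, the final stage treats the CATE as linear in the featurized effect modifiers, θ(x) = β0 + β1·x0 + β2·x1. Estimating CATE on new target units then amounts to evaluating that function at each new row. The numpy sketch below illustrates only this evaluation step: `beta` is a made-up placeholder (not the coefficients learned in the cell above), and `featurize_degree1` and `cate_on_new_units` are hypothetical helpers, not EconML functions.

```python
import numpy as np

def featurize_degree1(X):
    """Mimic a degree-1 polynomial featurizer with a bias column: [1, x0, x1, ...]."""
    X = np.asarray(X, dtype=float)
    return np.hstack([np.ones((X.shape[0], 1)), X])

def cate_on_new_units(beta, X_new, t0=0.0, t1=1.0):
    """Hypothetical helper: effect of moving the treatment from t0 to t1 at each row,
    assuming theta(x) = featurize(x) @ beta (a linear final stage)."""
    theta = featurize_degree1(X_new) @ beta
    return theta * (t1 - t0)

rng = np.random.default_rng(0)
# 10 new target units; two effect modifiers sampled uniformly on [0, 1) as above
X_new = rng.uniform(0, 1, size=(10, 2))
beta = np.array([10.0, 1.5, 0.5])   # placeholder coefficients for [bias, X0, X1]
cate = cate_on_new_units(beta, X_new)
print(cate.shape)   # (10,): one estimate per new unit
```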

We can retrieve the underlying EconML estimator object for any further operations.

[16]:
print(dml_estimate._estimator_object)
dml_estimate
<econml.dml.DMLCateEstimator object at 0x1c29696390>
[16]:
<dowhy.causal_estimator.CausalEstimate at 0x1c299532d0>

Works with any EconML method

In addition to double machine learning, below are example analyses using orthogonal forests, DRLearner (which still has a bug to fix), and neural-network-based instrumental variables.
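For intuition about what the double machine learning estimator does, here is a minimal numpy sketch of the partialling-out idea with 2-fold cross-fitting: residualize the outcome and the treatment on (X, W), then regress the outcome residual on the treatment residual interacted with [1, X]. It is not EconML's implementation. Ordinary least squares stands in for the `GradientBoostingRegressor` nuisance models and the `LassoCV` final model, and the simulated data (true effect θ(x) = 10 + 2·x) is hypothetical.

```python
import numpy as np

def ols(A, b):
    """Least-squares coefficients for b ~ A."""
    return np.linalg.lstsq(A, b, rcond=None)[0]

def add_bias(M):
    return np.hstack([np.ones((M.shape[0], 1)), M])

def dml_linear_cate(Y, T, X, W, n_folds=2, seed=0):
    """Partialling-out DML sketch: residualize Y and T on (X, W) with cross-fitting,
    then regress Y_res on T_res * [1, X] so that theta(x) = [1, x] @ coef."""
    n = len(Y)
    XW = add_bias(np.hstack([X, W]))
    folds = np.random.default_rng(seed).integers(0, n_folds, size=n)
    Y_res = np.empty(n)
    T_res = np.empty(n)
    for k in range(n_folds):
        train, test = folds != k, folds == k
        # nuisance models are fit on the other folds and evaluated on fold k
        Y_res[test] = Y[test] - XW[test] @ ols(XW[train], Y[train])
        T_res[test] = T[test] - XW[test] @ ols(XW[train], T[train])
    # final stage: Y_res ~ theta(X) * T_res with theta linear in [1, X]
    return ols(add_bias(X) * T_res[:, None], Y_res)

rng = np.random.default_rng(42)
n = 20_000
W = rng.normal(size=(n, 2))                            # observed common causes
X = rng.normal(size=(n, 1))                            # effect modifier
T = W @ np.array([1.0, -0.5]) + rng.normal(size=n)     # continuous treatment
Y = (10.0 + 2.0 * X[:, 0]) * T + W @ np.array([3.0, 1.0]) + rng.normal(size=n)

coef = dml_linear_cate(Y, T, X, W)
print(np.round(coef, 2))   # expected to be close to the true coefficients [10, 2]
```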


Continuous treatment, Continuous outcome

Orthogonal forests

[17]:
from sklearn.linear_model import LogisticRegression
orthoforest_estimate = model.estimate_effect(identified_estimand, method_name="backdoor.econml.ortho_forest.ContinuousTreatmentOrthoForest",
                                 target_units = lambda df: df["X0"]>1,
                                 confidence_intervals=False,
                                method_params={"init_params":{
                                                    'n_trees':2, # not ideal, just as an example to speed up computation
                                                    },
                                               "fit_params":{}
                                              })
print(orthoforest_estimate)
INFO:dowhy.causal_estimator:INFO: Using EconML Estimator
INFO:dowhy.causal_estimator:b: y~v0+W0+W1+W3+W2
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=-1)]: Done   2 out of   2 | elapsed:   14.3s remaining:    0.0s
[Parallel(n_jobs=-1)]: Done   2 out of   2 | elapsed:   14.3s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=-1)]: Done   2 out of   2 | elapsed:   13.9s remaining:    0.0s
[Parallel(n_jobs=-1)]: Done   2 out of   2 | elapsed:   13.9s finished
[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 4 concurrent workers.
[Parallel(n_jobs=-1)]: Done  24 tasks      | elapsed:    5.6s
[Parallel(n_jobs=-1)]: Done 120 tasks      | elapsed:   28.4s
[Parallel(n_jobs=-1)]: Done 280 tasks      | elapsed:  1.1min
*** Causal Estimate ***

## Target estimand
Estimand type: nonparametric-ate
### Estimand : 1
Estimand name: backdoor
Estimand expression:
  d
─────(Expectation(y|W0,W1,W3,W2))
d[v₀]
Estimand assumption 1, Unconfoundedness: If U→{v0} and U→y then P(y|v0,W0,W1,W3,W2,U) = P(y|v0,W0,W1,W3,W2)
### Estimand : 2
Estimand name: iv
Estimand expression:
Expectation(Derivative(y, [Z0, Z1])*Derivative([v0], [Z0, Z1])**(-1))
Estimand assumption 1, As-if-random: If U→→y then ¬(U →→{Z0,Z1})
Estimand assumption 2, Exclusion: If we remove {Z0,Z1}→{v0}, then ¬({Z0,Z1}→y)

## Realized estimand
b: y~v0+W0+W1+W3+W2
## Estimate
Value: 12.077425164170238

[Parallel(n_jobs=-1)]: Done 465 out of 465 | elapsed:  1.8min finished
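In the cell above, `target_units` is a function that selects the rows with `X0 > 1`. As far as we can tell from the DoWhy wrapper, the reported `Value` is then the mean of the per-unit CATE estimates over those rows. The numpy sketch below mimics only that averaging step; `cate_fn` is a hypothetical linear CATE and `average_cate_over_targets` is an illustrative helper, not DoWhy code.

```python
import numpy as np

def average_cate_over_targets(X, cate_fn, target_mask_fn):
    """Average a per-unit CATE over the rows selected by a boolean filter."""
    mask = target_mask_fn(X)
    return float(np.mean(cate_fn(X[mask]))), int(mask.sum())

rng = np.random.default_rng(1)
X = rng.normal(size=(10_000, 2))                               # columns play the role of X0, X1
cate_fn = lambda x: 10.0 + 1.5 * x[:, 0] + 0.5 * x[:, 1]       # hypothetical linear CATE
avg, n_targets = average_cate_over_targets(X, cate_fn, lambda x: x[:, 0] > 1)
print(n_targets, round(avg, 2))
```

Because the CATE here grows with X0, the average over the `X0 > 1` subgroup is larger than the average over all rows; the target filter changes which population the single reported number describes.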

Binary treatment, Binary outcome

DRLearner estimator

[18]:
data_binary = dowhy.datasets.linear_dataset(10, num_common_causes=4, num_samples=10000,
                                    num_instruments=2, num_effect_modifiers=2,
                                    treatment_is_binary=True, outcome_is_binary=True)
# convert boolean values to {0,1} numeric
data_binary['df'].v0 = data_binary['df'].v0.astype(int)
data_binary['df'].y = data_binary['df'].y.astype(int)
print(data_binary['df'])

model_binary = CausalModel(data=data_binary["df"],
                    treatment=data_binary["treatment_name"], outcome=data_binary["outcome_name"],
                    graph=data_binary["gml_graph"])
identified_estimand_binary = model_binary.identify_effect(proceed_when_unidentifiable=True)
INFO:dowhy.causal_model:Model to find the causal effect of treatment ['v0'] on outcome ['y']
INFO:dowhy.causal_identifier:Common causes of treatment and outcome:['Unobserved Confounders', 'W0', 'W1', 'W3', 'W2']
WARNING:dowhy.causal_identifier:If this is observed data (not from a randomized experiment), there might always be missing confounders. Causal effect cannot be identified perfectly.
INFO:dowhy.causal_identifier:Continuing by ignoring these unobserved confounders because proceed_when_unidentifiable flag is True.
INFO:dowhy.causal_identifier:Instrumental variables for treatment and outcome:['Z0', 'Z1']
            X0        X1   Z0        Z1        W0        W1        W2  \
0     0.143737  0.895689  1.0  0.633470  1.228916  1.122003 -0.103196
1     0.027461 -0.842702  1.0  0.395361  1.830941  0.664096 -0.069751
2    -1.092524  0.863230  1.0  0.575361  2.094211  0.111609 -0.890925
3     0.226557  2.270647  0.0  0.779255  0.270405  0.359455  0.600549
4    -1.279422  0.610456  0.0  0.920709  1.630068  3.144059 -1.046812
...        ...       ...  ...       ...       ...       ...       ...
9995  1.920812  0.445518  1.0  0.423858  1.369760  0.553453  0.824119
9996  0.867758  0.790538  1.0  0.937978  2.061503  1.310532 -0.221242
9997  1.106002 -0.406379  0.0  0.609452  0.639231 -1.214904  1.988916
9998  0.899491 -0.195329  0.0  0.872710  1.686182  0.795478 -2.904364
9999  1.099352  0.960848  1.0  0.248887  0.465099  1.360711 -2.091084

            W3  v0  y
0    -0.759881   1  1
1    -1.512547   1  0
2    -0.388786   1  1
3     1.460291   1  1
4    -0.522064   1  0
...        ...  .. ..
9995  0.964942   1  1
9996  1.333964   1  0
9997  0.860463   1  1
9998  1.887952   1  1
9999  0.994419   1  1

[10000 rows x 10 columns]

Using the DRLearner estimator

[19]:
from sklearn.linear_model import LogisticRegressionCV
# DRLearner needs a discrete treatment; here both v0 and y are binary
drlearner_estimate = model_binary.estimate_effect(identified_estimand_binary,
                                method_name="backdoor.econml.drlearner.LinearDRLearner",
                                target_units = lambda df: df["X0"]>1,
                                confidence_intervals=False,
                                method_params={"init_params":{
                                                    'model_propensity': LogisticRegressionCV(cv=3, solver='lbfgs', multi_class='auto')
                                                    },
                                               "fit_params":{}
                                              })
print(drlearner_estimate)
INFO:dowhy.causal_estimator:INFO: Using EconML Estimator
INFO:dowhy.causal_estimator:b: y~v0+W0+W1+W3+W2
*** Causal Estimate ***

## Target estimand
Estimand type: nonparametric-ate
### Estimand : 1
Estimand name: backdoor
Estimand expression:
  d
─────(Expectation(y|W0,W1,W3,W2))
d[v₀]
Estimand assumption 1, Unconfoundedness: If U→{v0} and U→y then P(y|v0,W0,W1,W3,W2,U) = P(y|v0,W0,W1,W3,W2)
### Estimand : 2
Estimand name: iv
Estimand expression:
Expectation(Derivative(y, [Z0, Z1])*Derivative([v0], [Z0, Z1])**(-1))
Estimand assumption 1, As-if-random: If U→→y then ¬(U →→{Z0,Z1})
Estimand assumption 2, Exclusion: If we remove {Z0,Z1}→{v0}, then ¬({Z0,Z1}→y)

## Realized estimand
b: y~v0+W0+W1+W3+W2
## Estimate
Value: 0.17287242432807604
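The DRLearner combines an outcome model for each treatment arm with a propensity model into a doubly robust pseudo-outcome, then regresses that pseudo-outcome on the effect modifiers (a linear regression in `LinearDRLearner`). The numpy sketch below shows the pseudo-outcome construction on hypothetical simulated data with a continuous outcome. For brevity it plugs in the true propensity score instead of fitting `LogisticRegressionCV`, uses OLS outcome models, and skips cross-fitting, so treat it as an illustration of the formula rather than EconML's procedure.

```python
import numpy as np

def ols(A, b):
    return np.linalg.lstsq(A, b, rcond=None)[0]

rng = np.random.default_rng(7)
n = 20_000
W = rng.normal(size=n)                            # confounder
X = rng.normal(size=n)                            # effect modifier
e = 1.0 / (1.0 + np.exp(-0.8 * W))                # true propensity P(T=1 | W)
T = rng.binomial(1, e)                            # binary treatment
Y = (1.0 + 0.5 * X) * T + W + rng.normal(size=n)  # true CATE: 1 + 0.5 * X

Z = np.column_stack([np.ones(n), X, W])           # design matrix for the outcome models
mu1 = Z @ ols(Z[T == 1], Y[T == 1])               # estimate of E[Y | T=1, X, W]
mu0 = Z @ ols(Z[T == 0], Y[T == 0])               # estimate of E[Y | T=0, X, W]

# doubly robust pseudo-outcome: outcome-model difference plus IPW-weighted residual corrections
psi = mu1 - mu0 + T * (Y - mu1) / e - (1 - T) * (Y - mu0) / (1 - e)

# final stage of a linear DR learner: regress psi on [1, X]
coef = ols(np.column_stack([np.ones(n), X]), psi)
print(np.round(coef, 2))   # should be near the true [1.0, 0.5]
```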

Instrumental variables (DeepIV)

[20]:
import keras
from econml.deepiv import DeepIVEstimator
dims_zx = len(model._instruments)+len(model._effect_modifiers)
dims_tx = len(model._treatment)+len(model._effect_modifiers)
treatment_model = keras.Sequential([keras.layers.Dense(128, activation='relu', input_shape=(dims_zx,)), # sum of dims of Z and X
                                    keras.layers.Dropout(0.17),
                                    keras.layers.Dense(64, activation='relu'),
                                    keras.layers.Dropout(0.17),
                                    keras.layers.Dense(32, activation='relu'),
                                    keras.layers.Dropout(0.17)])
response_model = keras.Sequential([keras.layers.Dense(128, activation='relu', input_shape=(dims_tx,)), # sum of dims of T and X
                                    keras.layers.Dropout(0.17),
                                    keras.layers.Dense(64, activation='relu'),
                                    keras.layers.Dropout(0.17),
                                    keras.layers.Dense(32, activation='relu'),
                                    keras.layers.Dropout(0.17),
                                    keras.layers.Dense(1)])

deepiv_estimate = model.estimate_effect(identified_estimand,
                                        method_name="iv.econml.deepiv.DeepIVEstimator",
                                        target_units = lambda df: df["X0"]>-1,
                                        confidence_intervals=False,
                                method_params={"init_params":{'n_components': 10, # Number of gaussians in the mixture density networks
                                                              'm': lambda z, x: treatment_model(keras.layers.concatenate([z, x])), # Treatment model,
                                                              "h": lambda t, x: response_model(keras.layers.concatenate([t, x])), # Response model
                                                              'n_samples': 1, # Number of samples used to estimate the response
                                                              'first_stage_options': {'epochs':25},
                                                              'second_stage_options': {'epochs':25}
                                                             },
                                               "fit_params":{}})
print(deepiv_estimate)
Using TensorFlow backend.
WARNING:tensorflow:From /Users/gong/opt/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
INFO:dowhy.causal_estimator:INFO: Using EconML Estimator
INFO:dowhy.causal_estimator:b: y~v0+W0+W1+W3+W2
WARNING:tensorflow:From /Users/gong/opt/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/ops/math_ops.py:2509: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
WARNING:tensorflow:From /Users/gong/opt/anaconda3/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:422: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_variables instead.

Epoch 1/25
10000/10000 [==============================] - 1s 136us/step - loss: 4.9846
Epoch 2/25
10000/10000 [==============================] - 1s 74us/step - loss: 2.8016
Epoch 3/25
10000/10000 [==============================] - 1s 74us/step - loss: 2.6807
Epoch 4/25
10000/10000 [==============================] - 1s 75us/step - loss: 2.6352
Epoch 5/25
10000/10000 [==============================] - 1s 86us/step - loss: 2.6195
Epoch 6/25
10000/10000 [==============================] - 1s 77us/step - loss: 2.5961
Epoch 7/25
10000/10000 [==============================] - 1s 78us/step - loss: 2.5861
Epoch 8/25
10000/10000 [==============================] - 1s 85us/step - loss: 2.5722
Epoch 9/25
10000/10000 [==============================] - 1s 75us/step - loss: 2.5690
Epoch 10/25
10000/10000 [==============================] - 1s 97us/step - loss: 2.5555
Epoch 11/25
10000/10000 [==============================] - 1s 89us/step - loss: 2.5486
Epoch 12/25
10000/10000 [==============================] - 1s 86us/step - loss: 2.5424
Epoch 13/25
10000/10000 [==============================] - 1s 92us/step - loss: 2.5449
Epoch 14/25
10000/10000 [==============================] - 1s 78us/step - loss: 2.5388
Epoch 15/25
10000/10000 [==============================] - 1s 75us/step - loss: 2.5307
Epoch 16/25
10000/10000 [==============================] - 1s 83us/step - loss: 2.5297
Epoch 17/25
10000/10000 [==============================] - 1s 86us/step - loss: 2.5225
Epoch 18/25
10000/10000 [==============================] - 1s 85us/step - loss: 2.5296
Epoch 19/25
10000/10000 [==============================] - 1s 105us/step - loss: 2.5237
Epoch 20/25
10000/10000 [==============================] - 1s 95us/step - loss: 2.5246
Epoch 21/25
10000/10000 [==============================] - 1s 95us/step - loss: 2.5184
Epoch 22/25
10000/10000 [==============================] - 1s 94us/step - loss: 2.5190
Epoch 23/25
10000/10000 [==============================] - 1s 79us/step - loss: 2.5183
Epoch 24/25
10000/10000 [==============================] - 1s 90us/step - loss: 2.5209
Epoch 25/25
10000/10000 [==============================] - 1s 102us/step - loss: 2.5137
Epoch 1/25
10000/10000 [==============================] - 2s 183us/step - loss: 10739.1212
Epoch 2/25
10000/10000 [==============================] - 1s 113us/step - loss: 9044.8074
Epoch 3/25
10000/10000 [==============================] - 1s 107us/step - loss: 8899.1196
Epoch 4/25
10000/10000 [==============================] - 1s 103us/step - loss: 8962.9156
Epoch 5/25
10000/10000 [==============================] - 1s 103us/step - loss: 8817.9901
Epoch 6/25
10000/10000 [==============================] - 1s 103us/step - loss: 8978.8951
Epoch 7/25
10000/10000 [==============================] - 1s 102us/step - loss: 8832.8224
Epoch 8/25
10000/10000 [==============================] - 1s 102us/step - loss: 8781.2165
Epoch 9/25
10000/10000 [==============================] - 1s 105us/step - loss: 8968.5057
Epoch 10/25
10000/10000 [==============================] - 1s 108us/step - loss: 8940.7467
Epoch 11/25
10000/10000 [==============================] - 1s 110us/step - loss: 8870.7342
Epoch 12/25
10000/10000 [==============================] - 1s 116us/step - loss: 8842.7259
Epoch 13/25
10000/10000 [==============================] - 1s 107us/step - loss: 8828.7600
Epoch 14/25
10000/10000 [==============================] - 1s 113us/step - loss: 8834.1166
Epoch 15/25
10000/10000 [==============================] - 1s 110us/step - loss: 8760.5785
Epoch 16/25
10000/10000 [==============================] - 1s 111us/step - loss: 8880.9786
Epoch 17/25
10000/10000 [==============================] - 1s 106us/step - loss: 8873.5939
Epoch 18/25
10000/10000 [==============================] - 1s 106us/step - loss: 8654.6864
Epoch 19/25
10000/10000 [==============================] - 1s 104us/step - loss: 8858.9382
Epoch 20/25
10000/10000 [==============================] - 1s 101us/step - loss: 8842.7146
Epoch 21/25
10000/10000 [==============================] - 1s 101us/step - loss: 8821.8297
Epoch 22/25
10000/10000 [==============================] - 1s 105us/step - loss: 8699.4650
Epoch 23/25
10000/10000 [==============================] - 1s 108us/step - loss: 8835.3129
Epoch 24/25
10000/10000 [==============================] - 1s 116us/step - loss: 8813.4877
Epoch 25/25
10000/10000 [==============================] - 1s 113us/step - loss: 8961.5180
*** Causal Estimate ***

## Target estimand
Estimand type: nonparametric-ate
### Estimand : 1
Estimand name: backdoor
Estimand expression:
  d
─────(Expectation(y|W0,W1,W3,W2))
d[v₀]
Estimand assumption 1, Unconfoundedness: If U→{v0} and U→y then P(y|v0,W0,W1,W3,W2,U) = P(y|v0,W0,W1,W3,W2)
### Estimand : 2
Estimand name: iv
Estimand expression:
Expectation(Derivative(y, [Z0, Z1])*Derivative([v0], [Z0, Z1])**(-1))
Estimand assumption 1, As-if-random: If U→→y then ¬(U →→{Z0,Z1})
Estimand assumption 2, Exclusion: If we remove {Z0,Z1}→{v0}, then ¬({Z0,Z1}→y)

## Realized estimand
b: y~v0+W0+W1+W3+W2
## Estimate
Value: 4.3953657150268555
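DeepIV fits a mixture density network for the treatment given the instruments and effect modifiers, then a neural response model. For intuition about the underlying instrumental-variable logic, here is its linear special case, two-stage least squares, in numpy on hypothetical simulated data with an unobserved confounder `U`. This is a swapped-in technique, not DeepIV. With a single instrument, the 2SLS slope equals the ratio cov(Z, Y) / cov(Z, T), which is the sample analogue of the `iv` estimand printed above.

```python
import numpy as np

def ols(A, b):
    return np.linalg.lstsq(A, b, rcond=None)[0]

rng = np.random.default_rng(3)
n = 50_000
U = rng.normal(size=n)                        # unobserved confounder
Z = rng.normal(size=n)                        # instrument: affects T, not Y directly
T = 0.8 * Z + U + rng.normal(size=n)
Y = 2.0 * T + 3.0 * U + rng.normal(size=n)    # true effect of T on Y is 2

ones = np.ones(n)
naive = ols(np.column_stack([ones, T]), Y)[1]     # biased upward by U
# stage 1: predict T from the instrument; stage 2: regress Y on that prediction
T_hat = np.column_stack([ones, Z]) @ ols(np.column_stack([ones, Z]), T)
tsls = ols(np.column_stack([ones, T_hat]), Y)[1]
# Wald ratio: (dY/dZ) / (dT/dZ)
wald = np.cov(Z, Y)[0, 1] / np.cov(Z, T)[0, 1]
print(round(naive, 2), round(tsls, 2), round(wald, 2))
```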

Metalearners

[21]:
data_experiment = dowhy.datasets.linear_dataset(10, num_common_causes=0, num_samples=10000,
                                    num_instruments=2, num_effect_modifiers=4,
                                    treatment_is_binary=True, outcome_is_binary=True)
# convert boolean values to {0,1} numeric
data_experiment['df'].v0 = data_experiment['df'].v0.astype(int)
data_experiment['df'].y = data_experiment['df'].y.astype(int)
print(data_experiment['df'])

model_experiment = CausalModel(data=data_experiment["df"],
                    treatment=data_experiment["treatment_name"], outcome=data_experiment["outcome_name"],
                    graph=data_experiment["gml_graph"])
identified_estimand_experiment = model_experiment.identify_effect(proceed_when_unidentifiable=True)
INFO:dowhy.causal_model:Model to find the causal effect of treatment ['v0'] on outcome ['y']
INFO:dowhy.causal_identifier:Common causes of treatment and outcome:['Unobserved Confounders']
WARNING:dowhy.causal_identifier:If this is observed data (not from a randomized experiment), there might always be missing confounders. Causal effect cannot be identified perfectly.
INFO:dowhy.causal_identifier:Continuing by ignoring these unobserved confounders because proceed_when_unidentifiable flag is True.
INFO:dowhy.causal_identifier:Instrumental variables for treatment and outcome:['Z0', 'Z1']
            X0        X1        X2        X3   Z0        Z1  v0  y
0     0.995605  2.013543  1.910889  0.501417  1.0  0.905808   1  0
1    -1.607226 -0.927225 -0.831809 -2.579346  0.0  0.003871   0  0
2    -1.542529 -2.111406  0.482695  0.669389  1.0  0.135426   1  0
3    -1.833528  0.065794 -1.490086 -1.895976  0.0  0.847716   1  0
4    -1.760129 -1.348228  2.084825  1.365359  0.0  0.248994   1  1
...        ...       ...       ...       ...  ...       ...  .. ..
9995 -0.355503 -1.634357  0.915420  0.189763  0.0  0.956708   1  1
9996 -0.077529 -1.898474 -0.963058  0.126657  1.0  0.794665   1  1
9997 -0.157876 -0.812651  2.003876 -0.850003  1.0  0.708076   1  1
9998 -2.744425 -3.038899 -0.061868  0.242013  0.0  0.792547   1  1
9999 -1.211644  0.422355  0.996183  0.798486  0.0  0.996503   1  0

[10000 rows x 8 columns]
[ ]:
from sklearn.linear_model import LogisticRegressionCV
metalearner_estimate = model_experiment.estimate_effect(identified_estimand_experiment,
                                method_name="backdoor.econml.metalearners.TLearner",
                                target_units = lambda df: df["X0"]>1,
                                confidence_intervals=False,
                                method_params={"init_params":{
                                                    'models': LogisticRegressionCV(cv=3, solver='lbfgs', multi_class='auto')
                                                    },
                                               "fit_params":{}
                                              })
print(metalearner_estimate)
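A T-learner fits one outcome model per treatment arm and takes their difference as the CATE, τ(x) = μ1(x) - μ0(x). The sketch below shows the idea with OLS on hypothetical randomized data (no common causes, matching `num_common_causes=0` above) and a continuous outcome for simplicity; EconML's `TLearner` instead uses whatever `models` you pass it, and `t_learner` here is an illustrative helper.

```python
import numpy as np

def ols(A, b):
    return np.linalg.lstsq(A, b, rcond=None)[0]

def t_learner(X, T, Y):
    """Fit separate linear outcome models for T=1 and T=0; return a CATE function."""
    design = lambda x: np.column_stack([np.ones(len(x)), x])
    beta1 = ols(design(X[T == 1]), Y[T == 1])
    beta0 = ols(design(X[T == 0]), Y[T == 0])
    return lambda x: design(x) @ (beta1 - beta0)

rng = np.random.default_rng(5)
n = 20_000
X = rng.normal(size=(n, 2))
T = rng.binomial(1, 0.5, size=n)                  # randomized treatment
Y = X @ np.array([1.0, -1.0]) + T * (0.5 + 2.0 * X[:, 0]) + rng.normal(size=n)

cate = t_learner(X, T, Y)
X_new = np.array([[0.0, 0.0], [1.0, 0.0]])
print(np.round(cate(X_new), 2))   # true CATE at these points: 0.5 and 2.5
```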

Refuting the estimate

Adding a random common cause

[ ]:
res_random=model.refute_estimate(identified_estimand, dml_estimate, method_name="random_common_cause")
print(res_random)
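The random-common-cause refuter adds an independently generated covariate as an extra common cause and re-estimates. Since that covariate has no real relation to treatment or outcome, a trustworthy estimate should barely move. A numpy sketch of the check follows; OLS backdoor adjustment stands in for the DML estimator, and the data are hypothetical.

```python
import numpy as np

def adjusted_effect(Y, T, covariates):
    """Coefficient on T from an OLS of Y on [1, T, covariates] (a simple backdoor estimator)."""
    A = np.column_stack([np.ones(len(Y)), T, covariates])
    return np.linalg.lstsq(A, Y, rcond=None)[0][1]

rng = np.random.default_rng(11)
n = 10_000
W = rng.normal(size=(n, 2))
T = W @ np.array([1.0, 0.5]) + rng.normal(size=n)
Y = 10.0 * T + W @ np.array([2.0, -1.0]) + rng.normal(size=n)

original = adjusted_effect(Y, T, W)
random_cause = rng.normal(size=(n, 1))            # unrelated "common cause"
refuted = adjusted_effect(Y, T, np.hstack([W, random_cause]))
print(round(original, 3), round(refuted, 3))
```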

Adding an unobserved common cause variable

[ ]:
res_unobserved=model.refute_estimate(identified_estimand, dml_estimate, method_name="add_unobserved_common_cause",
                                     confounders_effect_on_treatment="linear", confounders_effect_on_outcome="linear",
                                    effect_strength_on_treatment=0.01, effect_strength_on_outcome=0.02)
print(res_unobserved)
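The `add_unobserved_common_cause` refuter simulates a confounder that linearly affects both treatment and outcome, with strengths set by `effect_strength_on_treatment` and `effect_strength_on_outcome`, and reports the resulting estimate, which shows how sensitive the result is to confounding you could not adjust for. The sketch below imitates that idea in numpy and is not DoWhy's implementation: `with_simulated_confounder` is a hypothetical helper, OLS adjustment stands in for DML, and the larger strength pair is chosen only to make the bias visible.

```python
import numpy as np

def adjusted_effect(Y, T, W):
    A = np.column_stack([np.ones(len(Y)), T, W])
    return np.linalg.lstsq(A, Y, rcond=None)[0][1]

def with_simulated_confounder(Y, T, strength_t, strength_y, rng):
    """Hypothetical helper: add an unobserved confounder U linearly to T and Y."""
    U = rng.normal(size=len(Y))
    return Y + strength_y * U, T + strength_t * U

rng = np.random.default_rng(13)
n = 10_000
W = rng.normal(size=(n, 2))
T = W @ np.array([1.0, 0.5]) + rng.normal(size=n)
Y = 10.0 * T + W @ np.array([2.0, -1.0]) + rng.normal(size=n)

original = adjusted_effect(Y, T, W)
results = {}
for s_t, s_y in [(0.01, 0.02), (0.5, 2.0)]:
    Y_new, T_new = with_simulated_confounder(Y, T, s_t, s_y, rng)
    # U is not among the adjusted covariates, so it can bias the new estimate
    results[(s_t, s_y)] = adjusted_effect(Y_new, T_new, W)
print(round(original, 2), {k: round(v, 2) for k, v in results.items()})
```

With the weak strengths used in the cell above the estimate is essentially unchanged; with the stronger pair it shifts by roughly 1.2 in this simulation.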

Replacing treatment with a random (placebo) variable

[ ]:
res_placebo=model.refute_estimate(identified_estimand, dml_estimate,
        method_name="placebo_treatment_refuter", placebo_type="permute")
print(res_placebo)
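The placebo refuter with `placebo_type="permute"` replaces the treatment with a random permutation of itself. This breaks any link between treatment and outcome, so the re-estimated effect should be close to zero. A numpy sketch with OLS adjustment standing in for DML on hypothetical data:

```python
import numpy as np

def adjusted_effect(Y, T, W):
    A = np.column_stack([np.ones(len(Y)), T, W])
    return np.linalg.lstsq(A, Y, rcond=None)[0][1]

rng = np.random.default_rng(17)
n = 10_000
W = rng.normal(size=(n, 2))
T = W @ np.array([1.0, 0.5]) + rng.normal(size=n)
Y = 10.0 * T + W @ np.array([2.0, -1.0]) + rng.normal(size=n)

effect_true = adjusted_effect(Y, T, W)
placebo_T = rng.permutation(T)            # same marginal distribution, no link to Y
effect_placebo = adjusted_effect(Y, placebo_T, W)
print(round(effect_true, 2), round(effect_placebo, 2))
```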

Removing a random subset of the data

[ ]:
res_subset=model.refute_estimate(identified_estimand, dml_estimate,
        method_name="data_subset_refuter", subset_fraction=0.8)
print(res_subset)
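The data-subset refuter re-estimates on a random subset (here 80% of the rows) and checks that the estimate stays close to the original; a large swing suggests the result hinges on a few observations. A numpy sketch with OLS adjustment standing in for DML on hypothetical data; repeating the draw a few times gives a rough sense of the variability.

```python
import numpy as np

def adjusted_effect(Y, T, W):
    A = np.column_stack([np.ones(len(Y)), T, W])
    return np.linalg.lstsq(A, Y, rcond=None)[0][1]

rng = np.random.default_rng(19)
n = 10_000
W = rng.normal(size=(n, 2))
T = W @ np.array([1.0, 0.5]) + rng.normal(size=n)
Y = 10.0 * T + W @ np.array([2.0, -1.0]) + rng.normal(size=n)

original = adjusted_effect(Y, T, W)
subset_estimates = []
for _ in range(5):
    idx = rng.choice(n, size=int(0.8 * n), replace=False)   # 80% subset without replacement
    subset_estimates.append(adjusted_effect(Y[idx], T[idx], W[idx]))
print(round(original, 3), np.round(subset_estimates, 3))
```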

More refutation methods are to come, especially ones specific to CATE estimators.