Conditional Average Treatment Effects (CATE) with DoWhy and EconML¶
This is an experimental feature where we use EconML methods from DoWhy. Using EconML allows CATE estimation with different methods.
All four steps of causal inference in DoWhy remain the same: model, identify, estimate, and refute. The key difference is that we now call EconML methods in the estimation step. There is also a simple example using linear regression, to understand the intuition behind CATE estimators.
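As a quick orientation, here is a minimal sketch of the four steps end to end (the method and parameter choices are illustrative and mirror the cells below):
[ ]:
import dowhy.datasets
from dowhy import CausalModel

# 1. Model: encode assumptions as a causal graph
toy = dowhy.datasets.linear_dataset(10, num_common_causes=2, num_samples=1000,
                                    num_instruments=1, treatment_is_binary=False)
toy_model = CausalModel(data=toy["df"], treatment=toy["treatment_name"],
                        outcome=toy["outcome_name"], graph=toy["gml_graph"])
# 2. Identify: derive an estimand from the graph
toy_estimand = toy_model.identify_effect(proceed_when_unidentifiable=True)
# 3. Estimate: any supported estimator, including the EconML ones shown later
toy_estimate = toy_model.estimate_effect(toy_estimand, method_name="backdoor.linear_regression")
# 4. Refute: stress-test the estimate
toy_refutation = toy_model.refute_estimate(toy_estimand, toy_estimate,
                                           method_name="data_subset_refuter")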
[1]:
import os, sys
sys.path.insert(1, os.path.abspath("../../../")) # for dowhy source code
[1]:
import numpy as np
import pandas as pd
import logging
import dowhy
from dowhy import CausalModel
import dowhy.datasets
import econml
import warnings
warnings.filterwarnings('ignore')
[5]:
data = dowhy.datasets.linear_dataset(10, num_common_causes=4, num_samples=10000,
num_instruments=2, num_effect_modifiers=2,
num_treatments=1,
treatment_is_binary=False,
# num_discrete_common_causes=2,
# num_discrete_effect_modifiers=0,
# one_hot_encode=False
)
df=data['df']
df.head()
[5]:
|   | X0 | X1 | Z0 | Z1 | W0 | W1 | W2 | W3 | v0 | y |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | -2.521464 | 0.155409 | 1.0 | 0.559138 | 2.007452 | 0.194488 | -0.737834 | 0.949250 | 18.744430 | 144.319845 |
| 1 | -0.272124 | -0.594038 | 0.0 | 0.562656 | 2.412827 | -0.966005 | 1.208103 | -0.916247 | 5.594197 | 56.033465 |
| 2 | 1.471747 | 0.949394 | 0.0 | 0.537858 | 0.078511 | -0.334681 | 1.886087 | -0.414642 | 10.707244 | 143.527164 |
| 3 | -0.714296 | -1.157574 | 1.0 | 0.650554 | -0.396513 | 2.081661 | -1.689673 | 0.317751 | 16.975189 | 130.548563 |
| 4 | -0.895241 | 0.463742 | 1.0 | 0.773794 | -0.099374 | 0.707962 | 1.917918 | -2.526865 | 16.929208 | 165.811436 |
[7]:
model = CausalModel(data=data["df"],
treatment=data["treatment_name"], outcome=data["outcome_name"],
graph=data["gml_graph"])
model.view_model()
from IPython.display import Image, display
display(Image(filename="causal_model.png"))
INFO:dowhy.causal_model:Model to find the causal effect of treatment ['v0'] on outcome ['y']
[8]:
identified_estimand= model.identify_effect(proceed_when_unidentifiable=True)
print(identified_estimand)
INFO:dowhy.causal_identifier:Common causes of treatment and outcome:['Unobserved Confounders', 'W0', 'W1', 'W3', 'W2']
WARNING:dowhy.causal_identifier:If this is observed data (not from a randomized experiment), there might always be missing confounders. Causal effect cannot be identified perfectly.
INFO:dowhy.causal_identifier:Continuing by ignoring these unobserved confounders because proceed_when_unidentifiable flag is True.
INFO:dowhy.causal_identifier:Instrumental variables for treatment and outcome:['Z0', 'Z1']
Estimand type: nonparametric-ate
### Estimand : 1
Estimand name: backdoor
Estimand expression:
d
─────(Expectation(y|W0,W1,W3,W2))
d[v₀]
Estimand assumption 1, Unconfoundedness: If U→{v0} and U→y then P(y|v0,W0,W1,W3,W2,U) = P(y|v0,W0,W1,W3,W2)
### Estimand : 2
Estimand name: iv
Estimand expression:
Expectation(Derivative(y, [Z0, Z1])*Derivative([v0], [Z0, Z1])**(-1))
Estimand assumption 1, As-if-random: If U→→y then ¬(U →→{Z0,Z1})
Estimand assumption 2, Exclusion: If we remove {Z0,Z1}→{v0}, then ¬({Z0,Z1}→y)
Linear Model¶
First, let us build some intuition using a linear model for estimating CATE. Effect modifiers (which lead to a heterogeneous treatment effect) can be modeled as interaction terms with the treatment; their values then modulate the effect of the treatment.
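To make the interaction-term idea concrete, here is a minimal sketch that fits such a model by hand with scikit-learn (column names follow the dataset above; this is an illustration, not DoWhy's internal implementation):
[ ]:
import numpy as np
from sklearn.linear_model import LinearRegression

df = data["df"]
# y ~ v0 + W0..W3 + v0*X0 + v0*X1: the interaction coefficients make the effect heterogeneous
features = np.column_stack([df["v0"], df[["W0", "W1", "W2", "W3"]],
                            df["v0"] * df["X0"], df["v0"] * df["X1"]])
ols = LinearRegression().fit(features, df["y"])
theta, gamma0, gamma1 = ols.coef_[0], ols.coef_[5], ols.coef_[6]
# CATE for a unit with modifiers (X0, X1) when treatment moves from 0 to 1
cate = theta + gamma0 * df["X0"] + gamma1 * df["X1"]
print(cate.head())
The next cell does the same through DoWhy's API, estimating the effect of changing the treatment from 0 to 1.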
[9]:
linear_estimate = model.estimate_effect(identified_estimand,
method_name="backdoor.linear_regression",
control_value=0,
treatment_value=1)
print(linear_estimate)
INFO:dowhy.causal_estimator:INFO: Using Linear Regression Estimator
INFO:dowhy.causal_estimator:b: y~v0+W0+W1+W3+W2+v0*X1+v0*X0
*** Causal Estimate ***
## Target estimand
Estimand type: nonparametric-ate
### Estimand : 1
Estimand name: backdoor
Estimand expression:
d
─────(Expectation(y|W0,W1,W3,W2))
d[v₀]
Estimand assumption 1, Unconfoundedness: If U→{v0} and U→y then P(y|v0,W0,W1,W3,W2,U) = P(y|v0,W0,W1,W3,W2)
### Estimand : 2
Estimand name: iv
Estimand expression:
Expectation(Derivative(y, [Z0, Z1])*Derivative([v0], [Z0, Z1])**(-1))
Estimand assumption 1, As-if-random: If U→→y then ¬(U →→{Z0,Z1})
Estimand assumption 2, Exclusion: If we remove {Z0,Z1}→{v0}, then ¬({Z0,Z1}→y)
## Realized estimand
b: y~v0+W0+W1+W3+W2+v0*X1+v0*X0
## Estimate
Value: 10.00000000000001
EconML methods¶
We now move to the more advanced methods from the EconML package for estimating CATE.
First, let us look at the double machine learning estimator. method_name corresponds to the fully qualified name of the class that we want to use. For double ML, it is "backdoor.econml.dml.DMLCateEstimator".
target_units defines the units for which the causal estimate is computed. It can be a lambda function that filters the original dataframe, a new pandas DataFrame, or a string naming one of the three main kinds of target units ("ate", "att" and "atc").
method_params are passed directly to EconML. For details on the allowed parameters, refer to the EconML documentation.
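Here is a sketch of the three target_units forms side by side (the estimator configuration is illustrative; EconML's default featurizer is used):
[ ]:
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.linear_model import LassoCV

econml_params = {"init_params": {"model_y": GradientBoostingRegressor(),
                                 "model_t": GradientBoostingRegressor(),
                                 "model_final": LassoCV()},
                 "fit_params": {}}
for units in ("ate",                                        # a named target population
              lambda df: df["X0"] > 1,                      # a filter on the dataframe
              df[data["effect_modifier_names"]].head(10)):  # new effect-modifier values
    est = model.estimate_effect(identified_estimand,
                                method_name="backdoor.econml.dml.DMLCateEstimator",
                                target_units=units,
                                confidence_intervals=False,
                                method_params=econml_params)
    print(est.value)
The next cell shows the lambda-function form in full.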
[10]:
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LassoCV
from sklearn.ensemble import GradientBoostingRegressor
dml_estimate = model.estimate_effect(identified_estimand, method_name="backdoor.econml.dml.DMLCateEstimator",
control_value = 0,
treatment_value = 1,
target_units = lambda df: df["X0"]>1, # condition used for CATE
confidence_intervals=False,
method_params={"init_params":{'model_y':GradientBoostingRegressor(),
'model_t': GradientBoostingRegressor(),
"model_final":LassoCV(),
'featurizer':PolynomialFeatures(degree=1, include_bias=True)},
"fit_params":{}})
print(dml_estimate)
INFO:dowhy.causal_estimator:INFO: Using EconML Estimator
INFO:dowhy.causal_estimator:b: y~v0+W0+W1+W3+W2
INFO:numexpr.utils:NumExpr defaulting to 4 threads.
*** Causal Estimate ***
## Target estimand
Estimand type: nonparametric-ate
### Estimand : 1
Estimand name: backdoor
Estimand expression:
d
─────(Expectation(y|W0,W1,W3,W2))
d[v₀]
Estimand assumption 1, Unconfoundedness: If U→{v0} and U→y then P(y|v0,W0,W1,W3,W2,U) = P(y|v0,W0,W1,W3,W2)
### Estimand : 2
Estimand name: iv
Estimand expression:
Expectation(Derivative(y, [Z0, Z1])*Derivative([v0], [Z0, Z1])**(-1))
Estimand assumption 1, As-if-random: If U→→y then ¬(U →→{Z0,Z1})
Estimand assumption 2, Exclusion: If we remove {Z0,Z1}→{v0}, then ¬({Z0,Z1}→y)
## Realized estimand
b: y~v0+W0+W1+W3+W2
## Estimate
Value: 12.377928464975215
[11]:
print("True causal estimate is", data["ate"])
True causal estimate is 9.978706628277308
[12]:
dml_estimate = model.estimate_effect(identified_estimand, method_name="backdoor.econml.dml.DMLCateEstimator",
control_value = 0,
treatment_value = 1,
target_units = "ate",  # string target units: average effect over all units
confidence_intervals=False,
method_params={"init_params":{'model_y':GradientBoostingRegressor(),
'model_t': GradientBoostingRegressor(),
"model_final":LassoCV(),
'featurizer':PolynomialFeatures(degree=1, include_bias=True)},
"fit_params":{}})
print(dml_estimate)
INFO:dowhy.causal_estimator:INFO: Using EconML Estimator
INFO:dowhy.causal_estimator:b: y~v0+W0+W1+W3+W2
*** Causal Estimate ***
## Target estimand
Estimand type: nonparametric-ate
### Estimand : 1
Estimand name: backdoor
Estimand expression:
d
─────(Expectation(y|W0,W1,W3,W2))
d[v₀]
Estimand assumption 1, Unconfoundedness: If U→{v0} and U→y then P(y|v0,W0,W1,W3,W2,U) = P(y|v0,W0,W1,W3,W2)
### Estimand : 2
Estimand name: iv
Estimand expression:
Expectation(Derivative(y, [Z0, Z1])*Derivative([v0], [Z0, Z1])**(-1))
Estimand assumption 1, As-if-random: If U→→y then ¬(U →→{Z0,Z1})
Estimand assumption 2, Exclusion: If we remove {Z0,Z1}→{v0}, then ¬({Z0,Z1}→y)
## Realized estimand
b: y~v0+W0+W1+W3+W2
## Estimate
Value: 9.919618287281304
CATE Object and Confidence Intervals¶
[13]:
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LassoCV
from sklearn.ensemble import GradientBoostingRegressor
dml_estimate = model.estimate_effect(identified_estimand,
method_name="backdoor.econml.dml.DMLCateEstimator",
target_units = lambda df: df["X0"]>1,
confidence_intervals=True,
method_params={"init_params":{'model_y':GradientBoostingRegressor(),
'model_t': GradientBoostingRegressor(),
"model_final":LassoCV(),
'featurizer':PolynomialFeatures(degree=1, include_bias=True)},
"fit_params":{
'inference': 'bootstrap',
}
})
print(dml_estimate)
print(dml_estimate.cate_estimates[:10])
print(dml_estimate.effect_intervals)
INFO:dowhy.causal_estimator:INFO: Using EconML Estimator
INFO:dowhy.causal_estimator:b: y~v0+W0+W1+W3+W2
[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 4 concurrent workers.
[Parallel(n_jobs=-1)]: Done 24 tasks | elapsed: 17.9s
[Parallel(n_jobs=-1)]: Done 100 out of 100 | elapsed: 1.2min finished
[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 4 concurrent workers.
[Parallel(n_jobs=-1)]: Done 24 tasks | elapsed: 0.0s
[Parallel(n_jobs=-1)]: Done 100 out of 100 | elapsed: 0.1s finished
*** Causal Estimate ***
## Target estimand
Estimand type: nonparametric-ate
### Estimand : 1
Estimand name: backdoor
Estimand expression:
d
─────(Expectation(y|W0,W1,W3,W2))
d[v₀]
Estimand assumption 1, Unconfoundedness: If U→{v0} and U→y then P(y|v0,W0,W1,W3,W2,U) = P(y|v0,W0,W1,W3,W2)
### Estimand : 2
Estimand name: iv
Estimand expression:
Expectation(Derivative(y, [Z0, Z1])*Derivative([v0], [Z0, Z1])**(-1))
Estimand assumption 1, As-if-random: If U→→y then ¬(U →→{Z0,Z1})
Estimand assumption 2, Exclusion: If we remove {Z0,Z1}→{v0}, then ¬({Z0,Z1}→y)
## Realized estimand
b: y~v0+W0+W1+W3+W2
## Estimate
Value: 12.29197180689786
[[13.07887647]
[10.77809375]
[13.71470538]
[12.2937543 ]
[11.72053516]
[10.7891811 ]
[12.77571337]
[13.15299296]
[12.73256256]
[ 9.99162377]]
(array([[12.87146448],
       [10.63969365],
       [13.49162136],
       ...,
       [12.15146514],
       [13.4617387 ]]), array([[13.07894792],
       [10.84575642],
       [13.69590258],
       ...,
       [12.39934097],
       [13.70644115]]))
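effect_intervals is a (lower, upper) pair of arrays aligned row by row with cate_estimates, so the two can be combined into a per-unit summary (a small sketch, assuming the estimate above is in scope):
[ ]:
import pandas as pd

lower, upper = dml_estimate.effect_intervals
ci_summary = pd.DataFrame({"cate": dml_estimate.cate_estimates.ravel(),
                           "ci_lower": lower.ravel(),
                           "ci_upper": upper.ravel()})
print(ci_summary.head())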
Causal effect for new inputs¶
We can provide new inputs as target units and estimate CATE for them.
[14]:
test_cols= data['effect_modifier_names'] # only need effect modifiers' values
test_arr = [np.random.uniform(0,1, 10) for _ in range(len(test_cols))] # all variables are sampled uniformly, sample of 10
test_df = pd.DataFrame(np.array(test_arr).transpose(), columns=test_cols)
dml_estimate = model.estimate_effect(identified_estimand,
method_name="backdoor.econml.dml.DMLCateEstimator",
target_units = test_df,
confidence_intervals=False,
method_params={"init_params":{'model_y':GradientBoostingRegressor(),
'model_t': GradientBoostingRegressor(),
"model_final":LassoCV(),
'featurizer':PolynomialFeatures(degree=1, include_bias=True)},
"fit_params":{}
})
print(dml_estimate.cate_estimates)
INFO:dowhy.causal_estimator:INFO: Using EconML Estimator
INFO:dowhy.causal_estimator:b: y~v0+W0+W1+W3+W2
[[12.16905407]
[10.42165624]
[10.67733797]
[11.62091416]
[10.98568175]
[12.5880907 ]
[10.96991444]
[10.46498503]
[10.6584332 ]
[11.35387743]]
We can retrieve the raw EconML estimator object for any further operations.
[16]:
print(dml_estimate._estimator_object)
dml_estimate
<econml.dml.DMLCateEstimator object at 0x1c29696390>
[16]:
<dowhy.causal_estimator.CausalEstimate at 0x1c299532d0>
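For example, the underlying estimator exposes EconML's native interface; the sketch below queries effects at a few effect-modifier values directly (illustrative, and it relies on the private _estimator_object attribute shown above):
[ ]:
econml_est = dml_estimate._estimator_object
X_test = df[data["effect_modifier_names"]].head(5).values
# EconML's effect(): treatment moving from T0=0 to T1=1 at the given X
print(econml_est.effect(X_test, T0=0, T1=1))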
Works with any EconML method¶
In addition to double machine learning, below are example analyses using orthogonal forests, DRLearner (bug to fix), and neural-network-based instrumental variables.
Continuous treatment, Continuous outcome¶
Using the method of orthogonal forests.
[17]:
from sklearn.linear_model import LogisticRegression
orthoforest_estimate = model.estimate_effect(identified_estimand, method_name="backdoor.econml.ortho_forest.ContinuousTreatmentOrthoForest",
target_units = lambda df: df["X0"]>1,
confidence_intervals=False,
method_params={"init_params":{
'n_trees':2, # not ideal, just as an example to speed up computation
},
"fit_params":{}
})
print(orthoforest_estimate)
INFO:dowhy.causal_estimator:INFO: Using EconML Estimator
INFO:dowhy.causal_estimator:b: y~v0+W0+W1+W3+W2
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=-1)]: Done 2 out of 2 | elapsed: 14.3s remaining: 0.0s
[Parallel(n_jobs=-1)]: Done 2 out of 2 | elapsed: 14.3s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=-1)]: Done 2 out of 2 | elapsed: 13.9s remaining: 0.0s
[Parallel(n_jobs=-1)]: Done 2 out of 2 | elapsed: 13.9s finished
[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 4 concurrent workers.
[Parallel(n_jobs=-1)]: Done 24 tasks | elapsed: 5.6s
[Parallel(n_jobs=-1)]: Done 120 tasks | elapsed: 28.4s
[Parallel(n_jobs=-1)]: Done 280 tasks | elapsed: 1.1min
*** Causal Estimate ***
## Target estimand
Estimand type: nonparametric-ate
### Estimand : 1
Estimand name: backdoor
Estimand expression:
d
─────(Expectation(y|W0,W1,W3,W2))
d[v₀]
Estimand assumption 1, Unconfoundedness: If U→{v0} and U→y then P(y|v0,W0,W1,W3,W2,U) = P(y|v0,W0,W1,W3,W2)
### Estimand : 2
Estimand name: iv
Estimand expression:
Expectation(Derivative(y, [Z0, Z1])*Derivative([v0], [Z0, Z1])**(-1))
Estimand assumption 1, As-if-random: If U→→y then ¬(U →→{Z0,Z1})
Estimand assumption 2, Exclusion: If we remove {Z0,Z1}→{v0}, then ¬({Z0,Z1}→y)
## Realized estimand
b: y~v0+W0+W1+W3+W2
## Estimate
Value: 12.077425164170238
[Parallel(n_jobs=-1)]: Done 465 out of 465 | elapsed: 1.8min finished
Binary treatment, Binary outcome¶
[18]:
data_binary = dowhy.datasets.linear_dataset(10, num_common_causes=4, num_samples=10000,
num_instruments=2, num_effect_modifiers=2,
treatment_is_binary=True, outcome_is_binary=True)
# convert boolean values to {0,1} numeric
data_binary['df'].v0 = data_binary['df'].v0.astype(int)
data_binary['df'].y = data_binary['df'].y.astype(int)
print(data_binary['df'])
model_binary = CausalModel(data=data_binary["df"],
treatment=data_binary["treatment_name"], outcome=data_binary["outcome_name"],
graph=data_binary["gml_graph"])
identified_estimand_binary = model_binary.identify_effect(proceed_when_unidentifiable=True)
INFO:dowhy.causal_model:Model to find the causal effect of treatment ['v0'] on outcome ['y']
INFO:dowhy.causal_identifier:Common causes of treatment and outcome:['Unobserved Confounders', 'W0', 'W1', 'W3', 'W2']
WARNING:dowhy.causal_identifier:If this is observed data (not from a randomized experiment), there might always be missing confounders. Causal effect cannot be identified perfectly.
INFO:dowhy.causal_identifier:Continuing by ignoring these unobserved confounders because proceed_when_unidentifiable flag is True.
INFO:dowhy.causal_identifier:Instrumental variables for treatment and outcome:['Z0', 'Z1']
X0 X1 Z0 Z1 W0 W1 W2 \
0 0.143737 0.895689 1.0 0.633470 1.228916 1.122003 -0.103196
1 0.027461 -0.842702 1.0 0.395361 1.830941 0.664096 -0.069751
2 -1.092524 0.863230 1.0 0.575361 2.094211 0.111609 -0.890925
3 0.226557 2.270647 0.0 0.779255 0.270405 0.359455 0.600549
4 -1.279422 0.610456 0.0 0.920709 1.630068 3.144059 -1.046812
... ... ... ... ... ... ... ...
9995 1.920812 0.445518 1.0 0.423858 1.369760 0.553453 0.824119
9996 0.867758 0.790538 1.0 0.937978 2.061503 1.310532 -0.221242
9997 1.106002 -0.406379 0.0 0.609452 0.639231 -1.214904 1.988916
9998 0.899491 -0.195329 0.0 0.872710 1.686182 0.795478 -2.904364
9999 1.099352 0.960848 1.0 0.248887 0.465099 1.360711 -2.091084
W3 v0 y
0 -0.759881 1 1
1 -1.512547 1 0
2 -0.388786 1 1
3 1.460291 1 1
4 -0.522064 1 0
... ... .. ..
9995 0.964942 1 1
9996 1.333964 1 0
9997 0.860463 1 1
9998 1.887952 1 1
9999 0.994419 1 1
[10000 rows x 10 columns]
Using the DRLearner estimator.
[19]:
from sklearn.linear_model import LogisticRegressionCV
#todo needs binary y
drlearner_estimate = model_binary.estimate_effect(identified_estimand_binary,
method_name="backdoor.econml.drlearner.LinearDRLearner",
target_units = lambda df: df["X0"]>1,
confidence_intervals=False,
method_params={"init_params":{
'model_propensity': LogisticRegressionCV(cv=3, solver='lbfgs', multi_class='auto')
},
"fit_params":{}
})
print(drlearner_estimate)
INFO:dowhy.causal_estimator:INFO: Using EconML Estimator
INFO:dowhy.causal_estimator:b: y~v0+W0+W1+W3+W2
*** Causal Estimate ***
## Target estimand
Estimand type: nonparametric-ate
### Estimand : 1
Estimand name: backdoor
Estimand expression:
d
─────(Expectation(y|W0,W1,W3,W2))
d[v₀]
Estimand assumption 1, Unconfoundedness: If U→{v0} and U→y then P(y|v0,W0,W1,W3,W2,U) = P(y|v0,W0,W1,W3,W2)
### Estimand : 2
Estimand name: iv
Estimand expression:
Expectation(Derivative(y, [Z0, Z1])*Derivative([v0], [Z0, Z1])**(-1))
Estimand assumption 1, As-if-random: If U→→y then ¬(U →→{Z0,Z1})
Estimand assumption 2, Exclusion: If we remove {Z0,Z1}→{v0}, then ¬({Z0,Z1}→y)
## Realized estimand
b: y~v0+W0+W1+W3+W2
## Estimate
Value: 0.17287242432807604
Instrumental Variable Method¶
Using EconML's DeepIV estimator, which fits neural network models for the treatment and the response.
[20]:
import keras
from econml.deepiv import DeepIVEstimator
dims_zx = len(model._instruments)+len(model._effect_modifiers)
dims_tx = len(model._treatment)+len(model._effect_modifiers)
treatment_model = keras.Sequential([keras.layers.Dense(128, activation='relu', input_shape=(dims_zx,)), # sum of dims of Z and X
keras.layers.Dropout(0.17),
keras.layers.Dense(64, activation='relu'),
keras.layers.Dropout(0.17),
keras.layers.Dense(32, activation='relu'),
keras.layers.Dropout(0.17)])
response_model = keras.Sequential([keras.layers.Dense(128, activation='relu', input_shape=(dims_tx,)), # sum of dims of T and X
keras.layers.Dropout(0.17),
keras.layers.Dense(64, activation='relu'),
keras.layers.Dropout(0.17),
keras.layers.Dense(32, activation='relu'),
keras.layers.Dropout(0.17),
keras.layers.Dense(1)])
deepiv_estimate = model.estimate_effect(identified_estimand,
method_name="iv.econml.deepiv.DeepIVEstimator",
target_units = lambda df: df["X0"]>-1,
confidence_intervals=False,
method_params={"init_params":{'n_components': 10, # Number of gaussians in the mixture density networks
'm': lambda z, x: treatment_model(keras.layers.concatenate([z, x])), # Treatment model,
"h": lambda t, x: response_model(keras.layers.concatenate([t, x])), # Response model
'n_samples': 1, # Number of samples used to estimate the response
'first_stage_options': {'epochs':25},
'second_stage_options': {'epochs':25}
},
"fit_params":{}})
print(deepiv_estimate)
Using TensorFlow backend.
INFO:dowhy.causal_estimator:INFO: Using EconML Estimator
INFO:dowhy.causal_estimator:b: y~v0+W0+W1+W3+W2
Epoch 1/25
10000/10000 [==============================] - 1s 136us/step - loss: 4.9846
Epoch 2/25
10000/10000 [==============================] - 1s 74us/step - loss: 2.8016
...
Epoch 25/25
10000/10000 [==============================] - 1s 102us/step - loss: 2.5137
Epoch 1/25
10000/10000 [==============================] - 2s 183us/step - loss: 10739.1212
Epoch 2/25
10000/10000 [==============================] - 1s 113us/step - loss: 9044.8074
...
Epoch 25/25
10000/10000 [==============================] - 1s 113us/step - loss: 8961.5180
*** Causal Estimate ***
## Target estimand
Estimand type: nonparametric-ate
### Estimand : 1
Estimand name: backdoor
Estimand expression:
d
─────(Expectation(y|W0,W1,W3,W2))
d[v₀]
Estimand assumption 1, Unconfoundedness: If U→{v0} and U→y then P(y|v0,W0,W1,W3,W2,U) = P(y|v0,W0,W1,W3,W2)
### Estimand : 2
Estimand name: iv
Estimand expression:
Expectation(Derivative(y, [Z0, Z1])*Derivative([v0], [Z0, Z1])**(-1))
Estimand assumption 1, As-if-random: If U→→y then ¬(U →→{Z0,Z1})
Estimand assumption 2, Exclusion: If we remove {Z0,Z1}→{v0}, then ¬({Z0,Z1}→y)
## Realized estimand
b: y~v0+W0+W1+W3+W2
## Estimate
Value: 4.3953657150268555
Metalearners¶
Using EconML's TLearner. The dataset below has no common causes, mimicking data from a randomized experiment.
[21]:
data_experiment = dowhy.datasets.linear_dataset(10, num_common_causes=0, num_samples=10000,
num_instruments=2, num_effect_modifiers=4,
treatment_is_binary=True, outcome_is_binary=True)
# convert boolean values to {0,1} numeric
data_experiment['df'].v0 = data_experiment['df'].v0.astype(int)
data_experiment['df'].y = data_experiment['df'].y.astype(int)
print(data_experiment['df'])
model_experiment = CausalModel(data=data_experiment["df"],
treatment=data_experiment["treatment_name"], outcome=data_experiment["outcome_name"],
graph=data_experiment["gml_graph"])
identified_estimand_experiment = model_experiment.identify_effect(proceed_when_unidentifiable=True)
INFO:dowhy.causal_model:Model to find the causal effect of treatment ['v0'] on outcome ['y']
INFO:dowhy.causal_identifier:Common causes of treatment and outcome:['Unobserved Confounders']
WARNING:dowhy.causal_identifier:If this is observed data (not from a randomized experiment), there might always be missing confounders. Causal effect cannot be identified perfectly.
INFO:dowhy.causal_identifier:Continuing by ignoring these unobserved confounders because proceed_when_unidentifiable flag is True.
INFO:dowhy.causal_identifier:Instrumental variables for treatment and outcome:['Z0', 'Z1']
X0 X1 X2 X3 Z0 Z1 v0 y
0 0.995605 2.013543 1.910889 0.501417 1.0 0.905808 1 0
1 -1.607226 -0.927225 -0.831809 -2.579346 0.0 0.003871 0 0
2 -1.542529 -2.111406 0.482695 0.669389 1.0 0.135426 1 0
3 -1.833528 0.065794 -1.490086 -1.895976 0.0 0.847716 1 0
4 -1.760129 -1.348228 2.084825 1.365359 0.0 0.248994 1 1
... ... ... ... ... ... ... .. ..
9995 -0.355503 -1.634357 0.915420 0.189763 0.0 0.956708 1 1
9996 -0.077529 -1.898474 -0.963058 0.126657 1.0 0.794665 1 1
9997 -0.157876 -0.812651 2.003876 -0.850003 1.0 0.708076 1 1
9998 -2.744425 -3.038899 -0.061868 0.242013 0.0 0.792547 1 1
9999 -1.211644 0.422355 0.996183 0.798486 0.0 0.996503 1 0
[10000 rows x 8 columns]
[ ]:
from sklearn.linear_model import LogisticRegressionCV
metalearner_estimate = model_experiment.estimate_effect(identified_estimand_experiment,
method_name="backdoor.econml.metalearners.TLearner",
target_units = lambda df: df["X0"]>1,
confidence_intervals=False,
method_params={"init_params":{
'models': LogisticRegressionCV(cv=3, solver='lbfgs', multi_class='auto')
},
"fit_params":{}
})
print(metalearner_estimate)
Refuting the estimate¶
Adding a random common cause variable¶
[ ]:
res_random=model.refute_estimate(identified_estimand, dml_estimate, method_name="random_common_cause")
print(res_random)
Adding an unobserved common cause variable¶
[ ]:
res_unobserved=model.refute_estimate(identified_estimand, dml_estimate, method_name="add_unobserved_common_cause",
confounders_effect_on_treatment="linear", confounders_effect_on_outcome="linear",
effect_strength_on_treatment=0.01, effect_strength_on_outcome=0.02)
print(res_unobserved)
Replacing treatment with a random (placebo) variable¶
[ ]:
res_placebo=model.refute_estimate(identified_estimand, dml_estimate,
method_name="placebo_treatment_refuter", placebo_type="permute")
print(res_placebo)
Removing a random subset of the data¶
[ ]:
res_subset=model.refute_estimate(identified_estimand, dml_estimate,
method_name="data_subset_refuter", subset_fraction=0.8)
print(res_subset)
More refutation methods to come, especially specific to the CATE estimators.