-
Notifications
You must be signed in to change notification settings - Fork 992
Description
I am estimating the effect of `SST12` (a binary treatment) on `excess` (a binary outcome). I compute overlap weights for both the training and the test set and pass them as `sample_weight`. I estimate the Average Treatment Effect (ATE) with the following line:
`ate_SST12_current = estimate_SST12_current.ate(X_test)`, which gives an ATE of 0.0741.
However, when I apply the refutation test with:
random_SST12_current = estimate_SST12_current.refute_estimate(method_name="random_common_cause", random_state=357, num_simulations=10, sample_weight=overlap_weights_test)
I obtain the following results:
Refute: Add a random common cause
Estimated effect: -0.0005013820368082984
New effect: -0.0006689520975190781
p-value: 0.39873337771733663
My question is: what is the difference between the ATE returned by `ate(X_test)` (0.0741) and the "Estimated effect" reported by the refutation test (-0.0005)? Shouldn't they be the same?
Version information:
- DoWhy version = 0.11.1
- EconML = 0.15.1
Here is an example of my example_data.csv
Here is my code:
`
import random
import os
import warnings
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegressionCV
from econml.dr import SparseLinearDRLearner, ForestDRLearner
from sklearn.preprocessing import PolynomialFeatures
from plotnine import ggplot, aes, geom_line, geom_ribbon, ggtitle, labs, geom_point, geom_hline, theme_linedraw, theme, element_rect, theme_light, element_line, element_text
from zepid.graphics import EffectMeasurePlot
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import MinMaxScaler
from scipy.stats import expon
import scipy.stats as stats
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
import statsmodels.api as sm
from dowhy import CausalModel
#%%
# Set seeds for reproducibility
def seed_everything(seed=357):
    """Seed the random number generators this script relies on.

    Seeds Python's ``random`` module and NumPy's global RNG, and exports
    ``PYTHONHASHSEED`` and ``TF_DETERMINISTIC_OPS`` to the environment.

    Args:
        seed: Integer seed applied to both generators and exported as
            ``PYTHONHASHSEED``.
    """
    random.seed(seed)
    np.random.seed(seed)
    # NOTE: PYTHONHASHSEED is read when the interpreter starts, so setting it
    # here only affects child processes, not hashing in the current process.
    os.environ['PYTHONHASHSEED'] = str(seed)
    os.environ['TF_DETERMINISTIC_OPS'] = '1'


seed = 357
seed_everything(seed)
# Silence every warning for the rest of the run (this also hides pandas and
# scikit-learn warnings that may be worth reading while debugging).
warnings.filterwarnings('ignore')
#%%
# Set display options for pandas: print floats with five decimals
pd.set_option('display.float_format', '{:.5f}'.format)

# Create DataFrame for ATE results: eight placeholder rows, each holding a
# point estimate and a (lower, upper) confidence-interval tuple. Both columns
# are built directly so that '95% CI' holds tuples from the start.
data_ATE = pd.DataFrame({
    'ATE': [0.0] * 8,
    '95% CI': [(0.0, 0.0)] * 8,
})
print(data_ATE)
#%%
# Import data
data = pd.read_csv("D:/clases/UDES/articulo dengue/sst/data_final_07_23.csv", encoding='latin-1')
# Keep only the columns used by the analysis
data = data[['SST12', 'SST3', 'SST34', 'SST4', 'NATL', 'SATL', 'TROP', 'SOI', 'ESOI', 'cpolr', 'wpac850', 'cpac850', 'epac850', 'qbo_u30',
             'consensus', 't2m', 'tp', 'MPI', 'pop_density', 'excess', 'Year', 'Month',
             'DANE', 'DANE_year', 'DANE_period']]
# Keep observations from 2013 onwards. Take an explicit copy so the column
# assignments made on `data` afterwards operate on an independent frame
# instead of a filtered selection (avoids pandas' SettingWithCopyWarning).
data = data[data['Year'] >= 2013].copy()
#%%
# Convert columns to binary: 1 when the value is strictly above the column's
# median, 0 otherwise.
columns_convert = ['SST12', 'SST3', 'SST34', 'SST4', 'NATL', 'SATL', 'TROP', 'SOI', 'ESOI', 'cpolr', 'wpac850', 'cpac850', 'epac850', 'qbo_u30']
for col in columns_convert:
    median = data[col].median()
    data[col] = (data[col] > median).astype(int)
#%%
# Each identifier column is label-encoded to integers and the integer labels
# are then min-max rescaled.
# NOTE(review): this gives the identifiers an ordering based on label order;
# confirm that treating them as a numeric feature is intended.

# 1. Label Encoding DANE
le = LabelEncoder()
data['DANE_labeled'] = le.fit_transform(data['DANE'])
scaler = MinMaxScaler()
data['DANE_normalized'] = scaler.fit_transform(data[['DANE_labeled']])

# 2. Label Encoding DANE_year
le_year = LabelEncoder()
data['DANE_year_labeled'] = le_year.fit_transform(data['DANE_year'])
scaler_year = MinMaxScaler()
data['DANE_year_normalized'] = scaler_year.fit_transform(data[['DANE_year_labeled']])

# 3. Label Encoding DANE_period
le_period = LabelEncoder()
data['DANE_period_labeled'] = le_period.fit_transform(data['DANE_period'])
scaler_period = MinMaxScaler()
data['DANE_period_normalized'] = scaler_period.fit_transform(data[['DANE_period_labeled']])
#%%
# Transform year and month
# Re-base Year by subtracting 2007 (item assignment rather than attribute
# assignment, which is the unambiguous way to write to a column).
data['Year'] = data['Year'] - 2007
# Encode Month as a sine/cosine pair with period 12
data["sin_month"] = np.sin(2 * np.pi * data["Month"] / 12)
data["cos_month"] = np.cos(2 * np.pi * data["Month"] / 12)
#%%
# SST in t+1: within each DANE group, pull the next row's value into the
# current row. The last row of each group gets NaN.
# NOTE(review): this assumes rows are already in chronological order within
# each DANE group — the script does not sort them; confirm against the CSV.
for sst_col in ('SST12', 'SST3', 'SST34', 'SST4'):
    data[f'{sst_col}_t1'] = data.groupby('DANE')[sst_col].shift(-1)
#%%
# Moving-average variables
variables = ['SST12', 'SST3', 'SST34', 'SST4', 'NATL', 'SATL', 'TROP', 'SOI', 'ESOI', 'cpolr', 'wpac850', 'cpac850', 'epac850', 'qbo_u30', 't2m', 'tp']
# Desired window sizes
windows = [2]
for var in variables:
    for window in windows:
        # new column
        nueva_col = f'{var}_avg{window}'
        # moving average within each DANE group; min_periods=1 means the first
        # row of a group is averaged over a single observation instead of NaN
        data[nueva_col] = data.groupby('DANE')[var].transform(
            lambda x: x.rolling(window=window, min_periods=1, closed='right').mean()
        )
print(data.columns)
#%%
# SST12_current: columns used for the SST12 analysis
data_SST12_current = data[['DANE_normalized', 'DANE_year_normalized', 'DANE_period_normalized', 'Year', 'sin_month', 'cos_month',
                           'SST12', 'SST3', 'SST34', 'SST4', 'NATL', 'SATL', 'TROP',
                           'SOI', 'ESOI', 'cpolr', 'wpac850', 'cpac850', 'epac850', 'qbo_u30',
                           't2m', 'tp', 'MPI', 'pop_density', 'consensus', 'excess', 'SST12_t1']]
# Drop every row that has a missing value in any of the selected columns
data_SST12_current = data_SST12_current.dropna()
#%%
#Causal mechanism
# Build the DoWhy causal model for the effect of SST12 (treatment) on excess
# (outcome). The causal graph is passed inline as a GML string: the node
# declarations come first, followed by the directed edges.
# NOTE(review): the graph declares nodes "DANE", "DANE_year" and "DANE_period",
# whereas data_SST12_current holds DANE_normalized, DANE_year_normalized and
# DANE_period_normalized (plus SST12_t1, which has no node) — confirm that
# this naming mismatch is intended.
model_SST12_current = CausalModel(
data = data_SST12_current,
treatment=['SST12'],
outcome=['excess'],
graph= """graph[directed 1
node[id "SST12" label "SST12"]
node[id "excess" label "excess"]
node[id "SST3" label "SST3"]
node[id "SST34" label "SST34"]
node[id "SST4" label "SST4"]
node[id "NATL" label "NATL"]
node[id "SATL" label "SATL"]
node[id "TROP" label "TROP"]
node[id "SOI" label "SOI"]
node[id "ESOI" label "ESOI"]
node[id "cpolr" label "cpolr"]
node[id "wpac850" label "wpac850"]
node[id "cpac850" label "cpac850"]
node[id "epac850" label "epac850"]
node[id "qbo_u30" label "qbo_u30"]
node[id "consensus" label "consensus"]
node[id "t2m" label "t2m"]
node[id "tp" label "tp"]
node[id "MPI" label "MPI"]
node[id "pop_density" label "pop_density"]
node[id "Year" label "Year"]
node[id "sin_month" label "sin_month"]
node[id "cos_month" label "cos_month"]
node[id "DANE" label "DANE"]
node[id "DANE_year" label "DANE_year"]
node[id "DANE_period" label "DANE_period"]
edge[source "Year" target "SST12"]
edge[source "Year" target "excess"]
edge[source "Year" target "SST3"]
edge[source "Year" target "SST34"]
edge[source "Year" target "SST4"]
edge[source "Year" target "NATL"]
edge[source "Year" target "SATL"]
edge[source "Year" target "TROP"]
edge[source "Year" target "SOI"]
edge[source "Year" target "ESOI"]
edge[source "Year" target "cpolr"]
edge[source "Year" target "wpac850"]
edge[source "Year" target "cpac850"]
edge[source "Year" target "epac850"]
edge[source "Year" target "qbo_u30"]
edge[source "Year" target "t2m"]
edge[source "Year" target "tp"]
edge[source "sin_month" target "SST12"]
edge[source "sin_month" target "excess"]
edge[source "sin_month" target "SST3"]
edge[source "sin_month" target "SST34"]
edge[source "sin_month" target "SST4"]
edge[source "sin_month" target "NATL"]
edge[source "sin_month" target "SATL"]
edge[source "sin_month" target "TROP"]
edge[source "sin_month" target "SOI"]
edge[source "sin_month" target "ESOI"]
edge[source "sin_month" target "cpolr"]
edge[source "sin_month" target "wpac850"]
edge[source "sin_month" target "cpac850"]
edge[source "sin_month" target "epac850"]
edge[source "sin_month" target "qbo_u30"]
edge[source "sin_month" target "t2m"]
edge[source "sin_month" target "tp"]
edge[source "cos_month" target "SST12"]
edge[source "cos_month" target "excess"]
edge[source "cos_month" target "SST3"]
edge[source "cos_month" target "SST34"]
edge[source "cos_month" target "SST4"]
edge[source "cos_month" target "NATL"]
edge[source "cos_month" target "SATL"]
edge[source "cos_month" target "TROP"]
edge[source "cos_month" target "SOI"]
edge[source "cos_month" target "ESOI"]
edge[source "cos_month" target "cpolr"]
edge[source "cos_month" target "wpac850"]
edge[source "cos_month" target "cpac850"]
edge[source "cos_month" target "epac850"]
edge[source "cos_month" target "qbo_u30"]
edge[source "cos_month" target "t2m"]
edge[source "cos_month" target "tp"]
edge[source "SST3" target "SST34"]
edge[source "SST3" target "SST4"]
edge[source "SST3" target "NATL"]
edge[source "SST3" target "SATL"]
edge[source "SST3" target "TROP"]
edge[source "SST3" target "SOI"]
edge[source "SST3" target "ESOI"]
edge[source "SST3" target "cpolr"]
edge[source "SST3" target "wpac850"]
edge[source "SST3" target "cpac850"]
edge[source "SST3" target "epac850"]
edge[source "SST3" target "qbo_u30"]
edge[source "SST34" target "SST4"]
edge[source "SST34" target "NATL"]
edge[source "SST34" target "SATL"]
edge[source "SST34" target "TROP"]
edge[source "SST34" target "SOI"]
edge[source "SST34" target "ESOI"]
edge[source "SST34" target "cpolr"]
edge[source "SST34" target "wpac850"]
edge[source "SST34" target "cpac850"]
edge[source "SST34" target "epac850"]
edge[source "SST34" target "qbo_u30"]
edge[source "SST4" target "NATL"]
edge[source "SST4" target "SATL"]
edge[source "SST4" target "TROP"]
edge[source "SST4" target "SOI"]
edge[source "SST4" target "ESOI"]
edge[source "SST4" target "cpolr"]
edge[source "SST4" target "wpac850"]
edge[source "SST4" target "cpac850"]
edge[source "SST4" target "epac850"]
edge[source "SST4" target "qbo_u30"]
edge[source "NATL" target "SATL"]
edge[source "NATL" target "TROP"]
edge[source "NATL" target "SOI"]
edge[source "NATL" target "ESOI"]
edge[source "NATL" target "cpolr"]
edge[source "NATL" target "wpac850"]
edge[source "NATL" target "cpac850"]
edge[source "NATL" target "epac850"]
edge[source "NATL" target "qbo_u30"]
edge[source "SATL" target "TROP"]
edge[source "SATL" target "SOI"]
edge[source "SATL" target "ESOI"]
edge[source "SATL" target "cpolr"]
edge[source "SATL" target "wpac850"]
edge[source "SATL" target "cpac850"]
edge[source "SATL" target "epac850"]
edge[source "SATL" target "qbo_u30"]
edge[source "TROP" target "SOI"]
edge[source "TROP" target "ESOI"]
edge[source "TROP" target "cpolr"]
edge[source "TROP" target "wpac850"]
edge[source "TROP" target "cpac850"]
edge[source "TROP" target "epac850"]
edge[source "TROP" target "qbo_u30"]
edge[source "SOI" target "ESOI"]
edge[source "SOI" target "cpolr"]
edge[source "SOI" target "wpac850"]
edge[source "SOI" target "cpac850"]
edge[source "SOI" target "epac850"]
edge[source "SOI" target "qbo_u30"]
edge[source "ESOI" target "cpolr"]
edge[source "ESOI" target "wpac850"]
edge[source "ESOI" target "cpac850"]
edge[source "ESOI" target "epac850"]
edge[source "ESOI" target "qbo_u30"]
edge[source "cpolr" target "wpac850"]
edge[source "cpolr" target "cpac850"]
edge[source "cpolr" target "epac850"]
edge[source "cpolr" target "qbo_u30"]
edge[source "wpac850" target "cpac850"]
edge[source "wpac850" target "epac850"]
edge[source "wpac850" target "qbo_u30"]
edge[source "cpac850" target "epac850"]
edge[source "cpac850" target "qbo_u30"]
edge[source "epac850" target "qbo_u30"]
edge[source "SST3" target "SST12"]
edge[source "SST3" target "excess"]
edge[source "SST34" target "SST12"]
edge[source "SST34" target "excess"]
edge[source "SST4" target "SST12"]
edge[source "SST4" target "excess"]
edge[source "NATL" target "SST12"]
edge[source "NATL" target "excess"]
edge[source "SATL" target "SST12"]
edge[source "SATL" target "excess"]
edge[source "TROP" target "SST12"]
edge[source "TROP" target "excess"]
edge[source "SOI" target "SST12"]
edge[source "SOI" target "excess"]
edge[source "ESOI" target "SST12"]
edge[source "ESOI" target "excess"]
edge[source "cpolr" target "SST12"]
edge[source "cpolr" target "excess"]
edge[source "wpac850" target "SST12"]
edge[source "wpac850" target "excess"]
edge[source "cpac850" target "SST12"]
edge[source "cpac850" target "excess"]
edge[source "epac850" target "SST12"]
edge[source "epac850" target "excess"]
edge[source "qbo_u30" target "SST12"]
edge[source "qbo_u30" target "excess"]
edge[source "SST3" target "consensus"]
edge[source "SST3" target "t2m"]
edge[source "SST3" target "tp"]
edge[source "SST34" target "consensus"]
edge[source "SST34" target "t2m"]
edge[source "SST34" target "tp"]
edge[source "SST4" target "consensus"]
edge[source "SST4" target "t2m"]
edge[source "SST4" target "tp"]
edge[source "NATL" target "consensus"]
edge[source "NATL" target "t2m"]
edge[source "NATL" target "tp"]
edge[source "SATL" target "consensus"]
edge[source "SATL" target "t2m"]
edge[source "SATL" target "tp"]
edge[source "TROP" target "consensus"]
edge[source "TROP" target "t2m"]
edge[source "TROP" target "tp"]
edge[source "SOI" target "consensus"]
edge[source "SOI" target "t2m"]
edge[source "SOI" target "tp"]
edge[source "ESOI" target "consensus"]
edge[source "ESOI" target "t2m"]
edge[source "ESOI" target "tp"]
edge[source "cpolr" target "consensus"]
edge[source "cpolr" target "t2m"]
edge[source "cpolr" target "tp"]
edge[source "wpac850" target "consensus"]
edge[source "wpac850" target "t2m"]
edge[source "wpac850" target "tp"]
edge[source "cpac850" target "consensus"]
edge[source "cpac850" target "t2m"]
edge[source "cpac850" target "tp"]
edge[source "epac850" target "consensus"]
edge[source "epac850" target "t2m"]
edge[source "epac850" target "tp"]
edge[source "qbo_u30" target "consensus"]
edge[source "qbo_u30" target "t2m"]
edge[source "qbo_u30" target "tp"]
edge[source "SST12" target "consensus"]
edge[source "SST12" target "t2m"]
edge[source "SST12" target "tp"]
edge[source "consensus" target "t2m"]
edge[source "consensus" target "tp"]
edge[source "t2m" target "excess"]
edge[source "tp" target "excess"]
edge[source "t2m" target "pop_density"]
edge[source "tp" target "pop_density"]
edge[source "pop_density" target "excess"]
edge[source "pop_density" target "MPI"]
edge[source "MPI" target "excess"]
edge[source "DANE" target "excess"]
edge[source "DANE_year" target "excess"]
edge[source "DANE_period" target "excess"]
]"""
)
#%%
# Identifying effects: ask DoWhy for the estimand implied by the causal graph
# and print it for inspection.
identified_estimand_SST12_current = model_SST12_current.identify_effect(proceed_when_unidentifiable=None)
print(identified_estimand_SST12_current)
#%%
# Model
# Outcome (Y), treatment (T), controls (W) and effect modifiers (X) as arrays
Y = data_SST12_current['excess'].to_numpy()
T = data_SST12_current['SST12'].to_numpy()
W = data_SST12_current[['SST3', 'SST34', 'SST4', 'NATL', 'SATL', 'TROP', 'SOI', 'ESOI', 'cpolr', 'wpac850', 'cpac850', 'epac850', 'qbo_u30',
                        'sin_month', 'cos_month', 'Year']].to_numpy()
X = data_SST12_current[['MPI', 'DANE_normalized', 'DANE_year_normalized', 'DANE_period_normalized', 'SST12_t1']].to_numpy()

# Split data: 80% train / 20% test, stratified on the treatment
X_train, X_test, T_train, T_test, Y_train, Y_test, W_train, W_test = train_test_split(
    X, T, Y, W, test_size=0.2, random_state=357, stratify=T)


def _overlap_weights(treatment, propensity):
    """Return clipped overlap weights for a 0/1 treatment.

    Treated units get ``1 - propensity`` and control units get ``propensity``.
    The result is clipped to [0.01, 100]; because the weights are built from
    probabilities, only the lower bound can ever take effect.

    Args:
        treatment: Array of 0/1 treatment indicators.
        propensity: Array of estimated P(T = 1 | W), same length as
            ``treatment``.

    Returns:
        Array of overlap weights, same length as the inputs.
    """
    weights = treatment * (1 - propensity) + (1 - treatment) * propensity
    return np.clip(weights, 0.01, 100)


# Calculate propensity scores and overlap weights
# The propensity model is fitted on the training controls only and then
# applied to both splits.
logit_model = LogisticRegressionCV(
    penalty='l2', cv=5, random_state=357,
    max_iter=30000, solver='liblinear', scoring='neg_log_loss'
)
logit_model.fit(W_train, T_train)

propensity_scores_train = logit_model.predict_proba(W_train)[:, 1]
overlap_weights_train = _overlap_weights(T_train, propensity_scores_train)

propensity_scores_test = logit_model.predict_proba(W_test)[:, 1]
overlap_weights_test = _overlap_weights(T_test, propensity_scores_test)
#%%
# Estimation of the effect
estimate_SST12_current = SparseLinearDRLearner(max_iter=30000, cv=5, random_state=357)
# Switch to the estimator's DoWhy wrapper; the script later calls
# refute_estimate on this object.
estimate_SST12_current = estimate_SST12_current.dowhy

# Fit the model on the training split with the training overlap weights
estimate_SST12_current.fit(
    Y=Y_train,
    T=T_train,
    X=X_train,
    W=W_train,
    inference='auto',
    sample_weight=overlap_weights_train
)

# Predict effect for each test sample
te_pred = estimate_SST12_current.effect(X_test)

# Calculate ATE over the test-set effect modifiers
# NOTE(review): overlap_weights_test is not passed to ate(), so this appears
# to be an unweighted average over X_test — confirm against the EconML docs
# when comparing it with weighted or training-set estimates.
ate_SST12_current = estimate_SST12_current.ate(X_test)
print(ate_SST12_current)

# Confidence interval of ATE
ci_SST12_current = estimate_SST12_current.ate_interval(X_test)
print(ci_SST12_current)

# Store the results in row 0 of data_ATE and write the table to disk
data_ATE.at[0, 'ATE'] = ate_SST12_current
data_ATE.at[0, '95% CI'] = ci_SST12_current
print(data_ATE)
data_ATE.to_csv('D:/clases/UDES/articulo dengue/sst/ci/data_ATE.csv', index=False)
#%%
# CATE MPI
# Extract MPI (column 0 of the effect-modifier matrix) for the marginal effect
MPI_train = X_train[:, 0]
MPI_test = X_test[:, 0]

# Grid for MPI: 101 evenly spaced points covering the training range with both
# endpoints included. np.linspace is used instead of an np.arange stop value
# nudged by a fixed 0.001, which drops the top of the range whenever the step
# is 0.001 or smaller and fails outright when max equals min.
min_MPI = MPI_train.min()
max_MPI = MPI_train.max()
delta = (max_MPI - min_MPI) / 100
MPI_grid = np.linspace(min_MPI, max_MPI, 101)

# Means of the other variables in X (held fixed while MPI varies)
DANE_encoded_mean = np.mean(X_train[:, 1])
DANE_year_encoded_mean = np.mean(X_train[:, 2])
DANE_period_encoded_mean = np.mean(X_train[:, 3])
SST12_t1_mean = np.mean(X_train[:, 4])

# Matrix of X: one row per grid point, columns in the same order as X
X_test_grid = np.column_stack([
    MPI_grid,
    np.full_like(MPI_grid, DANE_encoded_mean),
    np.full_like(MPI_grid, DANE_year_encoded_mean),
    np.full_like(MPI_grid, DANE_period_encoded_mean),
    np.full_like(MPI_grid, SST12_t1_mean)
])

# Conditional effect and its interval along the MPI grid
treatment_cont_marg = estimate_SST12_current.effect(X_test_grid)
hte_lower2_cons, hte_upper2_cons = estimate_SST12_current.effect_interval(X_test_grid)

# Aliases used for plotting
MPI_grid_plot = MPI_grid
treatment_cont_marg_plot = treatment_cont_marg

# DataFrame for plotting
plot_data = pd.DataFrame({
    'X_test': MPI_grid_plot,
    'treatment_cont_marg': treatment_cont_marg_plot,
    'hte_lower2_cons': hte_lower2_cons,
    'hte_upper2_cons': hte_upper2_cons
})

# Figure CATE: effect line with its interval ribbon and a dashed zero line
# NOTE(review): the plot is a bare expression, so it is only displayed when
# the environment echoes expression values (e.g. a notebook cell) — confirm
# this is how the script is run.
(
    ggplot(plot_data)
    + aes(x='X_test', y='treatment_cont_marg')
    + geom_line(color='blue', size=1)
    + geom_ribbon(aes(ymin='hte_lower2_cons', ymax='hte_upper2_cons'), alpha=0.2, fill='blue')
    + labs(x='MPI %', y='Effect of SST12 on excess dengue cases',
           title='a')
    + geom_hline(yintercept=0, color="red", linetype="dashed", size=0.8)
    + theme(plot_title=element_text(hjust=0.5, size=12),
            axis_title_x=element_text(size=10),
            axis_title_y=element_text(size=10))
)
#%%
# Refutation tests
# Random common cause refuter, run through the fitted DoWhy wrapper.
# NOTE(review): the wrapper was fitted on the training split, while the
# sample_weight passed here is the test-set weight vector — confirm how (and
# whether) the refuter uses it, and on which data its "Estimated effect" is
# computed, before comparing that number with ate(X_test).
refute_options = {
    "method_name": "random_common_cause",
    "random_state": 357,
    "num_simulations": 10,
    "sample_weight": overlap_weights_test,
}
random_SST12_current = estimate_SST12_current.refute_estimate(**refute_options)
print(random_SST12_current)
`