Skip to content

What is the difference between the ATE and the "Estimated effect" reported by the refutation tests? #1351

@juandavidgutier

Description

@juandavidgutier

I am estimating the effect of SST12 (a binary variable) on excess (also a binary variable). I am using overlap weights for both the training and test sets in my estimation, and I estimate the Average Treatment Effect (ATE) with the following line:
`ate_SST12_current = estimate_SST12_current.ate(X_test)`, which gives an ATE of 0.0741.

However, when I apply the refutation test with:

random_SST12_current = estimate_SST12_current.refute_estimate(method_name="random_common_cause", random_state=357, num_simulations=10, sample_weight=overlap_weights_test)

I obtain the following results:

Refute: Add a random common cause
Estimated effect: -0.0005013820368082984
New effect: -0.0006689520975190781
p-value: 0.39873337771733663

My question is: what is the difference between the ATE and the estimated effect reported by the refutation test? Shouldn't they be the same?

Version information:

  • DoWhy version = 0.11.1
  • EconML = 0.15.1

Here is an example of my data (example_data.csv):

Here is my code:
```python
import random
import os
import warnings
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegressionCV
from econml.dr import SparseLinearDRLearner, ForestDRLearner
from sklearn.preprocessing import PolynomialFeatures
from plotnine import ggplot, aes, geom_line, geom_ribbon, ggtitle, labs, geom_point, geom_hline, theme_linedraw, theme, element_rect, theme_light, element_line, element_text
from zepid.graphics import EffectMeasurePlot
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import MinMaxScaler
from scipy.stats import expon
import scipy.stats as stats
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
import statsmodels.api as sm
from dowhy import CausalModel

#%%

# Set seeds for reproducibility

def seed_everything(seed=357):
    """Seed the global Python and NumPy random number generators.

    Also exports ``PYTHONHASHSEED`` and ``TF_DETERMINISTIC_OPS`` to the
    process environment.

    Args:
        seed: Integer seed applied to :mod:`random` and ``numpy.random``.

    Note:
        ``PYTHONHASHSEED`` is only read when a Python interpreter starts, so
        setting it here affects child processes launched afterwards, not the
        running interpreter. ``TF_DETERMINISTIC_OPS`` only matters for
        TensorFlow, which this script does not import. The scikit-learn and
        EconML estimators below also receive explicit ``random_state`` values.
    """
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    os.environ['TF_DETERMINISTIC_OPS'] = '1'


seed = 357
seed_everything(seed)
# Silence all warnings for the rest of the script.
# NOTE(review): this also hides pandas / scikit-learn / EconML warnings that
# may be relevant to the results below.
warnings.filterwarnings('ignore')

#%%

# Set display options for pandas: print floats with 5 decimals
pd.set_option('display.float_format', '{:.5f}'.format)

# Create DataFrame for ATE results.
# N_ATE_ROWS pre-allocated slots; this script only fills row 0.
N_ATE_ROWS = 8
data_ATE = pd.DataFrame({
    'ATE': [0.0] * N_ATE_ROWS,                               # float64 point estimates
    '95% CI': [(0.0, 0.0) for _ in range(N_ATE_ROWS)],       # (lower, upper) tuples, object dtype
})
print(data_ATE)

#%%

# Import data
# NOTE(review): hard-coded local path; the issue text refers to
# example_data.csv, which is a different file name.
DATA_PATH = "D:/clases/UDES/articulo dengue/sst/data_final_07_23.csv"
data = pd.read_csv(DATA_PATH, encoding='latin-1')

# Keep only the columns used in the analysis.
data = data[['SST12', 'SST3', 'SST34', 'SST4', 'NATL', 'SATL', 'TROP', 'SOI', 'ESOI', 'cpolr', 'wpac850', 'cpac850', 'epac850', 'qbo_u30',
             'consensus', 't2m', 'tp', 'MPI', 'pop_density', 'excess', 'Year', 'Month',
             'DANE', 'DANE_year', 'DANE_period']]

# Keep years from 2013 on. .copy() makes `data` an independent frame, so the
# column assignments in the following cells modify it directly rather than a
# possible view of the unfiltered frame (pandas' SettingWithCopy case, whose
# warning is silenced above).
data = data[data['Year'] >= 2013].copy()

#%%

# Convert columns to binary: 1 where the value is strictly above the column's
# median (computed on the rows kept so far), else 0.
# NOTE(review): NaN compares False, so missing values become 0 here instead of
# being removed by the later dropna() — confirm this is intended.
columns_convert = ['SST12', 'SST3', 'SST34', 'SST4', 'NATL', 'SATL', 'TROP', 'SOI', 'ESOI', 'cpolr', 'wpac850', 'cpac850', 'epac850', 'qbo_u30']
for col in columns_convert:
    median = data[col].median()
    data[col] = (data[col] > median).astype(int)

#%%

# Label-encode the DANE identifier columns and min-max scale the codes to [0, 1]


def _label_encode_and_scale(frame, column):
    """Add ``<column>_labeled`` and ``<column>_normalized`` columns to ``frame``.

    ``<column>_labeled`` holds the LabelEncoder integer codes of
    ``frame[column]``; ``<column>_normalized`` holds those codes min-max
    scaled to [0, 1]. ``frame`` is modified in place.

    Returns:
        The fitted ``(LabelEncoder, MinMaxScaler)`` pair.
    """
    labeled = f'{column}_labeled'
    encoder = LabelEncoder()
    frame[labeled] = encoder.fit_transform(frame[column])
    col_scaler = MinMaxScaler()
    frame[f'{column}_normalized'] = col_scaler.fit_transform(frame[[labeled]])
    return encoder, col_scaler


# NOTE(review): these codes are arbitrary integer labels of identifiers
# (presumably territorial units / unit-periods). Scaling them and later using
# them as linear effect modifiers (X) treats the label order as meaningful —
# confirm this is intended.

# 1. DANE
le, scaler = _label_encode_and_scale(data, 'DANE')

# 2. DANE_year
le_year, scaler_year = _label_encode_and_scale(data, 'DANE_year')

# 3. DANE_period
le_period, scaler_period = _label_encode_and_scale(data, 'DANE_period')

#%%

# Transform year and month
# Year becomes years since 2007 (rows start at 2013, so the minimum is 6).
data['Year'] = data['Year'] - 2007
# Cyclic month encoding: December (12) and January (1) end up adjacent.
data["sin_month"] = np.sin(2 * np.pi * data["Month"] / 12)
data["cos_month"] = np.cos(2 * np.pi * data["Month"] / 12)

#%%

# SST in t+1: the value from the next row of the same DANE group
# (NaN for the last row of each group).
# NOTE(review): this treats row order as time order; the frame is never
# sorted by date within DANE here — confirm the CSV is already sorted.
# NOTE(review): SST12_t1 (the next period's treatment) is later used as an
# effect modifier X.
for sst_var in ['SST12', 'SST3', 'SST34', 'SST4']:
    data[f'{sst_var}_t1'] = data.groupby('DANE')[sst_var].shift(-1)

#%%

# Moving-average variables: rolling mean of each variable within each DANE group

variables = ['SST12', 'SST3', 'SST34', 'SST4', 'NATL', 'SATL', 'TROP', 'SOI', 'ESOI', 'cpolr', 'wpac850', 'cpac850', 'epac850', 'qbo_u30', 't2m', 'tp']

# Desired window sizes
windows = [2]

for var in variables:
    for window in windows:
        # New column, e.g. 'SST12_avg2'
        new_col = f'{var}_avg{window}'

        # Moving average over the current and previous (window - 1) rows;
        # min_periods=1 keeps the first row of each group instead of NaN.
        # `w=window` binds the current window size into the lambda explicitly.
        # NOTE(review): like the shift above, this assumes rows are in time
        # order within each DANE group.
        data[new_col] = data.groupby('DANE')[var].transform(
            lambda x, w=window: x.rolling(window=w, min_periods=1, closed='right').mean()
        )

print(data.columns)

#%%

# SST12_current: analysis dataset for the effect of SST12 on excess

data_SST12_current = data[['DANE_normalized', 'DANE_year_normalized', 'DANE_period_normalized', 'Year', 'sin_month', 'cos_month',
                           'SST12', 'SST3', 'SST34', 'SST4', 'NATL', 'SATL', 'TROP',
                           'SOI', 'ESOI', 'cpolr', 'wpac850', 'cpac850', 'epac850', 'qbo_u30',
                           't2m', 'tp', 'MPI', 'pop_density', 'consensus', 'excess', 'SST12_t1']]

# Drop rows with any missing value. This includes the last row of each DANE
# group, where SST12_t1 is NaN after shift(-1).
data_SST12_current = data_SST12_current.dropna()

#%%

# Causal mechanism
#
# The GML graph is generated from the lists below rather than written out edge
# by edge. It contains the same 26 nodes and the same 209 directed edges, in
# the same order, as the original hand-written string.
#
# NOTE(review): this model is only used by identify_effect() below. The EconML
# estimator is later fitted through its own `.dowhy` wrapper, and this graph is
# not passed to it.
# NOTE(review): the nodes DANE, DANE_year and DANE_period are not columns of
# data_SST12_current (it holds the *_normalized versions), and SST12_t1 is a
# column but not a node — confirm how DoWhy should treat these.
# NOTE(review): in this graph MPI is a descendant of SST12
# (SST12 -> t2m/tp -> pop_density -> MPI), yet MPI is later used as an effect
# modifier X. Conditioning on a descendant of the treatment can bias the effect.


def _sst12_causal_graph_gml():
    """Return the GML string of the assumed causal graph for SST12 -> excess."""
    indices = ['SST3', 'SST34', 'SST4', 'NATL', 'SATL', 'TROP', 'SOI', 'ESOI',
               'cpolr', 'wpac850', 'cpac850', 'epac850', 'qbo_u30']
    weather = ['consensus', 't2m', 'tp']
    time_vars = ['Year', 'sin_month', 'cos_month']
    dane_vars = ['DANE', 'DANE_year', 'DANE_period']

    nodes = (['SST12', 'excess'] + indices + weather
             + ['MPI', 'pop_density'] + time_vars + dane_vars)

    edges = []
    # Time terms -> treatment, outcome, every climate index, t2m and tp
    # (no edge to consensus).
    for src in time_vars:
        edges += [(src, dst) for dst in ['SST12', 'excess', *indices, 't2m', 'tp']]
    # Climate indices form a complete DAG in list order: each -> every later one.
    for i, src in enumerate(indices):
        edges += [(src, dst) for dst in indices[i + 1:]]
    # Every climate index -> treatment and outcome.
    for src in indices:
        edges += [(src, 'SST12'), (src, 'excess')]
    # Every climate index -> consensus, t2m, tp.
    for src in indices:
        edges += [(src, dst) for dst in weather]
    # Treatment -> consensus, t2m, tp (the only paths from SST12 to excess).
    edges += [('SST12', dst) for dst in weather]
    edges += [('consensus', 't2m'), ('consensus', 'tp')]
    edges += [('t2m', 'excess'), ('tp', 'excess')]
    edges += [('t2m', 'pop_density'), ('tp', 'pop_density')]
    edges += [('pop_density', 'excess'), ('pop_density', 'MPI'), ('MPI', 'excess')]
    edges += [(src, 'excess') for src in dane_vars]

    lines = ['graph[directed 1']
    lines += [f'node[id "{name}" label "{name}"]' for name in nodes]
    lines += [f'edge[source "{src}" target "{dst}"]' for src, dst in edges]
    lines.append(']')
    return '\n'.join(lines)


model_SST12_current = CausalModel(
    data=data_SST12_current,
    treatment=['SST12'],
    outcome=['excess'],
    graph=_sst12_causal_graph_gml(),
)

#%%

# Identifying effects
# NOTE(review): proceed_when_unidentifiable=None presumably defers to the
# CausalModel-level setting — confirm in the DoWhy identify_effect docs.
identified_estimand_SST12_current = model_SST12_current.identify_effect(proceed_when_unidentifiable=None)
print(identified_estimand_SST12_current)

#%%

# Model inputs

# Outcome and binary (median-split) treatment.
Y = data_SST12_current['excess'].to_numpy()
T = data_SST12_current['SST12'].to_numpy()
# W: controls (climate indices and time terms) used for adjustment.
W = data_SST12_current[['SST3', 'SST34', 'SST4', 'NATL', 'SATL', 'TROP', 'SOI', 'ESOI', 'cpolr', 'wpac850', 'cpac850', 'epac850', 'qbo_u30',
                        'sin_month', 'cos_month', 'Year']].to_numpy()
# X: effect modifiers. MPI must stay in column 0 because the CATE plot reads X[:, 0].
X = data_SST12_current[['MPI', 'DANE_normalized', 'DANE_year_normalized', 'DANE_period_normalized', 'SST12_t1']].to_numpy()

# Split data: 80/20, stratified on treatment so both splits keep the treated share.
X_train, X_test, T_train, T_test, Y_train, Y_test, W_train, W_test = train_test_split(
    X, T, Y, W, test_size=0.2, random_state=357, stratify=T)


def _overlap_weights(treatment, propensity):
    """Return overlap weights: ``1 - e`` for treated units and ``e`` for controls.

    ``e`` is the estimated propensity score P(T=1 | W). The result is clipped
    to [0.01, 100]. Overlap weights already lie in [0, 1], so only the lower
    bound can take effect.
    """
    weights = treatment * (1 - propensity) + (1 - treatment) * propensity
    return np.clip(weights, 0.01, 100)


# Propensity model: L2-penalised logistic regression of T on W, tuned by
# 5-fold CV on log-loss and fitted on the training split only.
# NOTE(review): the X columns are not used here, while the DR learner's own
# propensity model presumably sees both X and W — confirm this is intended.
logit_model = LogisticRegressionCV(
    penalty='l2', cv=5, random_state=357,
    max_iter=30000, solver='liblinear', scoring='neg_log_loss'
)
logit_model.fit(W_train, T_train)

propensity_scores_train = logit_model.predict_proba(W_train)[:, 1]
overlap_weights_train = _overlap_weights(T_train, propensity_scores_train)

propensity_scores_test = logit_model.predict_proba(W_test)[:, 1]
overlap_weights_test = _overlap_weights(T_test, propensity_scores_test)

#%%

# Estimation of the effect

estimate_SST12_current = SparseLinearDRLearner(max_iter=30000, cv=5, random_state=357)

# Wrap the estimator with EconML's DoWhy wrapper; refute_estimate() is called
# on this wrapped object further below.
estimate_SST12_current = estimate_SST12_current.dowhy

# Fit on the training split, weighting samples by the overlap weights.
estimate_SST12_current.fit(
    Y=Y_train,
    T=T_train,
    X=X_train,
    W=W_train,
    inference='auto',
    sample_weight=overlap_weights_train
)

# Predicted effect for each test sample.
te_pred = estimate_SST12_current.effect(X_test)

# ATE: average of the estimated effect over the rows of X_test.
# NOTE(review): the overlap weights only enter through fit(). This average is
# not overlap-weighted, so it is not the overlap-weighted (ATO) estimand.
# NOTE(review): the "Estimated effect" printed by refute_estimate() is the
# estimate the DoWhy wrapper stores during fit() (presumably averaged over the
# training data), not this value — confirm against the EconML DoWhyWrapper docs.
ate_SST12_current = estimate_SST12_current.ate(X_test)
print(ate_SST12_current)

# Confidence interval of the ATE.
ci_SST12_current = estimate_SST12_current.ate_interval(X_test)
print(ci_SST12_current)

# Store the results in row 0 of data_ATE.
data_ATE.at[0, 'ATE'] = ate_SST12_current
data_ATE.at[0, '95% CI'] = ci_SST12_current
print(data_ATE)

data_ATE.to_csv('D:/clases/UDES/articulo dengue/sst/ci/data_ATE.csv', index=False)

#%%

# CATE as a function of MPI

# MPI is column 0 of X.
MPI_train = X_train[:, 0]
MPI_test = X_test[:, 0]

# Grid for MPI: 101 evenly spaced points from the training minimum to maximum,
# endpoints included. This replaces np.arange(min, max + delta - 0.001, delta),
# which dropped the endpoint when the MPI range was below 0.1 and raised
# ZeroDivisionError when MPI was constant (delta == 0).
min_MPI = MPI_train.min()
max_MPI = MPI_train.max()
MPI_grid = np.linspace(min_MPI, max_MPI, 101)

# Hold the other X columns at their training means.
# NOTE(review): the DANE_* columns are scaled label codes, so their mean does
# not correspond to an actual unit — confirm this reference point is intended.
DANE_encoded_mean = np.mean(X_train[:, 1])
DANE_year_encoded_mean = np.mean(X_train[:, 2])
DANE_period_encoded_mean = np.mean(X_train[:, 3])
SST12_t1_mean = np.mean(X_train[:, 4])

# X matrix for the grid, with columns in the same order as X.
X_test_grid = np.column_stack([
    MPI_grid,
    np.full_like(MPI_grid, DANE_encoded_mean),
    np.full_like(MPI_grid, DANE_year_encoded_mean),
    np.full_like(MPI_grid, DANE_period_encoded_mean),
    np.full_like(MPI_grid, SST12_t1_mean)
])

# Conditional effect along the grid, with its confidence band.
treatment_cont_marg = estimate_SST12_current.effect(X_test_grid)
hte_lower2_cons, hte_upper2_cons = estimate_SST12_current.effect_interval(X_test_grid)

# Aliases used for plotting.
MPI_grid_plot = MPI_grid
treatment_cont_marg_plot = treatment_cont_marg

# DataFrame for plotting ('X_test' holds the MPI grid values).
plot_data = pd.DataFrame({
    'X_test': MPI_grid_plot,
    'treatment_cont_marg': treatment_cont_marg_plot,
    'hte_lower2_cons': hte_lower2_cons,
    'hte_upper2_cons': hte_upper2_cons
})

# Figure: CATE of SST12 on excess across MPI, with the interval band and a zero line.
(
    ggplot(plot_data)
    + aes(x='X_test', y='treatment_cont_marg')
    + geom_line(color='blue', size=1)
    + geom_ribbon(aes(ymin='hte_lower2_cons', ymax='hte_upper2_cons'), alpha=0.2, fill='blue')
    + labs(x='MPI %', y='Effect of SST12 on excess dengue cases',
           title='a')
    + geom_hline(yintercept=0, color="red", linetype="dashed", size=0.8)
    + theme(plot_title=element_text(hjust=0.5, size=12),
            axis_title_x=element_text(size=10),
            axis_title_y=element_text(size=10))
)

#%%

# Refutation tests
# With a random common cause.
#
# NOTE(review): the "Estimated effect" in this refuter's output is the estimate
# held by the DoWhy wrapper from fit(), which presumably averages the effect
# over the training data, not over X_test. It therefore need not equal
# ate_SST12_current — confirm against the EconML / DoWhy refuter docs.
#
# The original call also passed sample_weight=overlap_weights_test. That array
# has one entry per *test* row, while the wrapper was fitted only on the
# training split. The call ran without a length-mismatch error, so the argument
# appears to be ignored; it is dropped to avoid implying the refutation is
# weighted by test-set weights.
random_SST12_current = estimate_SST12_current.refute_estimate(
    method_name="random_common_cause", random_state=357, num_simulations=10,
)
print(random_SST12_current)
```

Metadata

Metadata

Assignees

No one assigned

    Labels

    question, Further information is requested, stale

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions