Skip to content

Commit 1ec93b8

Browse files
drawlinsondrawlinson
and
drawlinson
authored
Expose individual interventional outcomes from RegressionEstimator _do operator (#1011) (#1018)
Explose individual interventional outcomes from RegressionEstimator _do operator. Original implementation of _do for RegressionEstimators calculated individual outcomes for all dataframe rows, and then returned the mean(). However, for many use-cases it is helpful to have the individual outcomes for further analysis. This commit moves the existing implementation (without changes) to a new function interventional_outcomes() which returns all outcomes. The implementation of _do() now only calls mean() on the returned values. The behaviour of _do() is unchanged. Signed-off-by: drawlinson <[email protected]> Co-authored-by: drawlinson <[email protected]>
1 parent 7841719 commit 1ec93b8

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

dowhy/causal_estimators/regression_estimator.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,17 @@ def _build_features(self, data_df: pd.DataFrame, treatment_values=None):
181181
features = sm.add_constant(features, has_constant="add") # to add an intercept term
182182
return features
183183

184-
def _do(self, data_df: pd.DataFrame, treatment_val):
184+
def interventional_outcomes(self, data_df: pd.DataFrame, treatment_val):
185+
"""
186+
Applies an intervention treatment_val to all rows in data_df, then uses self.model
187+
to predict outcomes. If data_df is None, will use self._data instead.
188+
If no model exists, one will be created. The outcomes of all samples are returned,
189+
allowing analysis of individual predictions in counterfactual treatment scenarios.
190+
:param data_df: data frame containing the data
191+
:param treatment_val: value for the treatment variable
192+
:returns: A list of outcome predictions.
193+
"""
194+
185195
if data_df is None:
186196
data_df = self._data
187197
if not self.model:
@@ -210,4 +220,8 @@ def _do(self, data_df: pd.DataFrame, treatment_val):
210220

211221
new_features = self._build_features(data_df, treatment_values=interventional_treatment_2d)
212222
interventional_outcomes = self.predict_fn(data_df, self.model, new_features)
223+
return interventional_outcomes
224+
225+
def _do(self, data_df: pd.DataFrame, treatment_val):
226+
interventional_outcomes = self.interventional_outcomes(data_df, treatment_val)
213227
return interventional_outcomes.mean()

0 commit comments

Comments
 (0)