Skip to content

Commit 77906cd

Browse files
authored
Better docs: updated the method signature and docstrings for estimator classes (#389)
* updated the method signature and docstrings for estimator classes * bug fixes * updated the args passing * fixed the bug in refutation test calls for num_ci_simulations * used std init args * updated num simulations to pass weighting test
1 parent 5ba0a4a commit 77906cd

27 files changed

+411
-195
lines changed

docs/source/conf.py

+4
Original file line numberDiff line numberDiff line change
@@ -197,3 +197,7 @@
197197

198198
# If true, `todo` and `todoList` produce output, else they produce nothing.
199199
todo_include_todos = True
200+
201+
# init docstrings should also be included in class
202+
autoclass_content = "both"
203+

docs/source/dowhy.causal_refuters.rst

+8
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,14 @@ dowhy.causal\_refuters.dummy\_outcome\_refuter module
3636
:undoc-members:
3737
:show-inheritance:
3838

39+
dowhy.causal\_refuters.graph\_refuter module
40+
--------------------------------------------
41+
42+
.. automodule:: dowhy.causal_refuters.graph_refuter
43+
:members:
44+
:undoc-members:
45+
:show-inheritance:
46+
3947
dowhy.causal\_refuters.placebo\_treatment\_refuter module
4048
---------------------------------------------------------
4149

docs/source/dowhy.utils.rst

+8
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,14 @@ dowhy.utils.api module
1212
:undoc-members:
1313
:show-inheritance:
1414

15+
dowhy.utils.cit module
16+
----------------------
17+
18+
.. automodule:: dowhy.utils.cit
19+
:members:
20+
:undoc-members:
21+
:show-inheritance:
22+
1523
dowhy.utils.cli\_helpers module
1624
-------------------------------
1725

dowhy/causal_estimator.py

+43-33
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,21 @@ class CausalEstimator:
3838

3939
DEFAULT_INTERPRET_METHOD = ["textual_effect_interpreter"]
4040

41+
# std args to be removed from locals() before being passed to args_dict
42+
_STD_INIT_ARGS = ('self', '__class__', 'args', 'kwargs')
43+
4144
def __init__(self, data, identified_estimand, treatment, outcome,
4245
control_value=0, treatment_value=1,
4346
test_significance=False, evaluate_effect_strength=False,
4447
confidence_intervals=False,
4548
target_units=None, effect_modifiers=None,
46-
params=None):
49+
num_null_simulations=DEFAULT_NUMBER_OF_SIMULATIONS_STAT_TEST,
50+
num_simulations=DEFAULT_NUMBER_OF_SIMULATIONS_CI,
51+
sample_size_fraction=DEFAULT_SAMPLE_SIZE_FRACTION,
52+
confidence_level=DEFAULT_CONFIDENCE_LEVEL,
53+
need_conditional_estimates='auto',
54+
num_quantiles_to_discretize_cont_cols=NUM_QUANTILES_TO_DISCRETIZE_CONT_COLS,
55+
**kwargs):
4756
"""Initializes an estimator with data and names of relevant variables.
4857
4958
This method is called from the constructors of its child classes.
@@ -55,19 +64,29 @@ def __init__(self, data, identified_estimand, treatment, outcome,
5564
:param outcome: name of the outcome variable
5665
:param control_value: Value of the treatment in the control group, for effect estimation. If treatment is multi-variate, this can be a list.
5766
:param treatment_value: Value of the treatment in the treated group, for effect estimation. If treatment is multi-variate, this can be a list.
58-
:param test_significance: Binary flag or a string indicating whether to test significance and by which method. All estimators support test_significance="bootstrap" that estimates a p-value for the obtained estimate using the bootstrap method. Individual estimators can override this to support custom testing methods. The bootstrap method supports an optional parameter, num_null_simulations that can be specified through the params dictionary. If False, no testing is done. If True, significance of the estimate is tested using the custom method if available, otherwise by bootstrap.
67+
:param test_significance: Binary flag or a string indicating whether to test significance and by which method. All estimators support test_significance="bootstrap" that estimates a p-value for the obtained estimate using the bootstrap method. Individual estimators can override this to support custom testing methods. The bootstrap method supports an optional parameter, num_null_simulations. If False, no testing is done. If True, significance of the estimate is tested using the custom method if available, otherwise by bootstrap.
5968
:param evaluate_effect_strength: (Experimental) whether to evaluate the strength of effect
6069
:param confidence_intervals: Binary flag or a string indicating whether the confidence intervals should be computed and which method should be used. All methods support estimation of confidence intervals using the bootstrap method by using the parameter confidence_intervals="bootstrap". The bootstrap method takes in two arguments (num_simulations and sample_size_fraction) that can be optionally specified in the params dictionary. Estimators may also override this to implement their own confidence interval method. If this parameter is False, no confidence intervals are computed. If True, confidence intervals are computed by the estimator's specific method if available, otherwise through bootstrap.
6170
:param target_units: The units for which the treatment effect should be estimated. This can be a string for common specifications of target units (namely, "ate", "att" and "atc"). It can also be a lambda function that can be used as an index for the data (pandas DataFrame). Alternatively, it can be a new DataFrame that contains values of the effect_modifiers and effect will be estimated only for this new data.
62-
:param effect_modifiers: Variables on which to compute separate effects, or return a heterogeneous effect function. Not all methods support this currently.
63-
:param params: (optional) Additional method parameters
64-
num_null_simulations: The number of simulations for testing the statistical significance of the estimator
65-
num_simulations: The number of simulations for finding the confidence interval (and/or standard error) for a estimate
66-
sample_size_fraction: The size of the sample for the bootstrap estimator
67-
confidence_level: The confidence level of the confidence interval estimate
68-
num_quantiles_to_discretize_cont_cols: The number of quantiles into which a numeric effect modifier is split, to enable estimation of conditional treatment effect over it.
71+
:param effect_modifiers: Variables on which to compute separate
72+
effects, or return a heterogeneous effect function. Not all
73+
methods support this currently.
74+
:param num_null_simulations: The number of simulations for testing the
75+
statistical significance of the estimator
76+
:param num_simulations: The number of simulations for finding the
77+
confidence interval (and/or standard error) for a estimate
78+
:param sample_size_fraction: The size of the sample for the bootstrap
79+
estimator
80+
:param confidence_level: The confidence level of the confidence
81+
interval estimate
82+
:param need_conditional_estimates: Boolean flag indicating whether
83+
conditional estimates should be computed. Defaults to True if
84+
there are effect modifiers in the graph
85+
:param num_quantiles_to_discretize_cont_cols: The number of quantiles
86+
into which a numeric effect modifier is split, to enable
87+
estimation of conditional treatment effect over it.
88+
:param kwargs: (optional) Additional estimator-specific parameters
6989
:returns: an instance of the estimator class.
70-
7190
"""
7291
self._data = data
7392
self._target_estimand = identified_estimand
@@ -84,14 +103,9 @@ def __init__(self, data, identified_estimand, treatment, outcome,
84103
self._bootstrap_estimates = None # for confidence intervals and std error
85104
self._bootstrap_null_estimates = None # for significance test
86105
self._effect_modifiers = None
87-
self.method_params = params
88-
106+
self.method_params = kwargs
89107
# Setting the default interpret method
90108
self.interpret_method = CausalEstimator.DEFAULT_INTERPRET_METHOD
91-
# Unpacking the keyword arguments
92-
if params is not None:
93-
for key, value in params.items():
94-
setattr(self, key, value)
95109

96110
self.logger = logging.getLogger(__name__)
97111

@@ -114,20 +128,17 @@ def __init__(self, data, identified_estimand, treatment, outcome,
114128
else:
115129
self._effect_modifier_names = None
116130

117-
# Checking if some parameters were set, otherwise setting to default values
118-
if not hasattr(self, 'num_null_simulations'):
119-
self.num_null_simulations = CausalEstimator.DEFAULT_NUMBER_OF_SIMULATIONS_STAT_TEST
120-
if not hasattr(self, 'num_simulations'):
121-
self.num_simulations = CausalEstimator.DEFAULT_NUMBER_OF_SIMULATIONS_CI
122-
if not hasattr(self, 'sample_size_fraction'):
123-
self.sample_size_fraction = CausalEstimator.DEFAULT_SAMPLE_SIZE_FRACTION
124-
if not hasattr(self, 'confidence_level'):
125-
self.confidence_level = CausalEstimator.DEFAULT_CONFIDENCE_LEVEL
126-
if not hasattr(self, 'num_quantiles_to_discretize_cont_cols'):
127-
self.num_quantiles_to_discretize_cont_cols = CausalEstimator.NUM_QUANTILES_TO_DISCRETIZE_CONT_COLS
131+
# Check if some parameters were set, otherwise set to default values
132+
self.num_null_simulations = num_null_simulations
133+
self.num_simulations = num_simulations
134+
self.sample_size_fraction = sample_size_fraction
135+
self.confidence_level = confidence_level
136+
self.num_quantiles_to_discretize_cont_cols = \
137+
num_quantiles_to_discretize_cont_cols
128138
# Estimate conditional estimates by default
129-
if not hasattr(self, 'need_conditional_estimates'):
130-
self.need_conditional_estimates = bool(self._effect_modifier_names)
139+
self.need_conditional_estimates = need_conditional_estimates \
140+
if need_conditional_estimates != 'auto' \
141+
else bool(self._effect_modifier_names)
131142

132143
@staticmethod
133144
def get_estimator_object(new_data, identified_estimand, estimate):
@@ -158,7 +169,7 @@ def get_estimator_object(new_data, identified_estimand, estimate):
158169
confidence_intervals=estimate.params["confidence_intervals"],
159170
target_units=estimate.params["target_units"],
160171
effect_modifiers=estimate.params["effect_modifiers"],
161-
params=estimate.params["method_params"]
172+
**estimate.params["method_params"]
162173
)
163174

164175
return new_estimator
@@ -297,7 +308,6 @@ def _generate_bootstrap_estimates(self, num_bootstrap_simulations,
297308
# Perform the set number of simulations
298309
for index in range(num_bootstrap_simulations):
299310
new_data = resample(self._data, n_samples=sample_size)
300-
301311
new_estimator = type(self)(
302312
new_data,
303313
self._target_estimand,
@@ -310,7 +320,7 @@ def _generate_bootstrap_estimates(self, num_bootstrap_simulations,
310320
confidence_intervals=False,
311321
target_units=self._target_units,
312322
effect_modifiers=self._effect_modifier_names,
313-
params=self.method_params
323+
**self.method_params
314324
)
315325
new_effect = new_estimator.estimate_effect()
316326
simulation_results[index] = new_effect.value
@@ -504,7 +514,7 @@ def _test_significance_with_bootstrap(self, estimate_value, num_null_simulations
504514
confidence_intervals=False,
505515
target_units=self._target_units,
506516
effect_modifiers=self._effect_modifier_names,
507-
params=self.method_params
517+
**self.method_params
508518
)
509519
new_effect = new_estimator.estimate_effect()
510520
null_estimates[i] = new_effect.value

dowhy/causal_estimators/causalml.py

+21-5
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,27 @@
77
import causalml
88

99
class Causalml(CausalEstimator):
10-
11-
def __init__(self, *args, **kwargs):
12-
13-
super().__init__(*args, **kwargs)
14-
10+
""" Wrapper class for estimators from the causalml library.
11+
12+
For a list of standard args and kwargs, see documentation for
13+
:class:`~dowhy.causal_estimator.CausalEstimator`.
14+
15+
Supports additional parameters as listed below. For specific
16+
parameters of each estimator, refer to the CausalML docs.
17+
18+
"""
19+
def __init__(self, *args, causalml_methodname, **kwargs):
20+
"""
21+
:param causalml_methodname: Fully qualified name of causalml estimator
22+
class.
23+
"""
24+
# Required to ensure that self.method_params contains all the information
25+
# to create an object of this class
26+
args_dict = {k: v for k, v in locals().items()
27+
if k not in type(self)._STD_INIT_ARGS}
28+
args_dict.update(kwargs)
29+
super().__init__(*args, **args_dict)
30+
self._causalml_methodname = causalml_methodname
1531
# Add the identification method used in the estimator
1632
self.identifier_method = self._target_estimand.identifier_method
1733
self.logger.debug("The identifier method used {}".format(self.identifier_method))

dowhy/causal_estimators/distance_matching_estimator.py

+30-13
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,35 @@
55
from dowhy.causal_estimator import CausalEstimate, CausalEstimator
66

77
class DistanceMatchingEstimator(CausalEstimator):
8-
""" Simple matching estimator for binary treatments based on a distance metric.
9-
"""
8+
"""Simple matching estimator for binary treatments based on a distance
9+
metric.
10+
11+
For a list of standard args and kwargs, see documentation for
12+
:class:`~dowhy.causal_estimator.CausalEstimator`.
13+
14+
Supports additional parameters as listed below.
1015
16+
"""
17+
# allowed types of distance metric
1118
Valid_Dist_Metric_Params = ['p', 'V', 'VI', 'w']
12-
def __init__(self, *args, **kwargs):
13-
super().__init__(*args, **kwargs)
19+
20+
def __init__(self, *args, num_matches_per_unit=1,
21+
distance_metric="minkowski", exact_match_cols=None, **kwargs):
22+
"""
23+
:param num_matches_per_unit: The number of matches per data point.
24+
Default=1.
25+
:param distance_metric: Distance metric to use. Default="minkowski"
26+
that corresponds to Euclidean distance metric with p=2.
27+
:param exact_match_cols: List of column names whose values should be
28+
exactly matched. Typically used for columns with discrete values.
29+
30+
"""
31+
# Required to ensure that self.method_params contains all the
32+
# parameters to create an object of this class
33+
args_dict = {k: v for k, v in locals().items()
34+
if k not in type(self)._STD_INIT_ARGS}
35+
args_dict.update(kwargs)
36+
super().__init__(*args, **args_dict)
1437
# Check if the treatment is one-dimensional
1538
if len(self._treatment_name) > 1:
1639
error_msg = str(self.__class__) + "cannot handle more than one treatment variable"
@@ -21,15 +44,9 @@ def __init__(self, *args, **kwargs):
2144
self.logger.error(error_msg)
2245
raise Exception(error_msg)
2346

24-
# Setting the number of matches per data point
25-
if getattr(self, 'num_matches_per_unit', None) is None:
26-
self.num_matches_per_unit = 1
27-
# Default distance metric if not provided by the user
28-
if getattr(self, 'distance_metric', None) is None:
29-
self.distance_metric = 'minkowski' # corresponds to euclidean metric with p=2
30-
31-
if getattr(self, 'exact_match_cols', None) is None:
32-
self.exact_match_cols = None
47+
self.num_matches_per_unit = num_matches_per_unit
48+
self.distance_metric = distance_metric
49+
self.exact_match_cols = exact_match_cols
3350

3451
self.logger.debug("Back-door variables used:" +
3552
",".join(self._target_estimand.get_backdoor_variables()))

dowhy/causal_estimators/econml.py

+21-3
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,27 @@
1111

1212

1313
class Econml(CausalEstimator):
14+
"""Wrapper class for estimators from the EconML library.
1415
15-
def __init__(self, *args, **kwargs):
16-
super().__init__(*args, **kwargs)
16+
For a list of standard args and kwargs, see documentation for
17+
:class:`~dowhy.causal_estimator.CausalEstimator`.
18+
19+
Supports additional parameters as listed below. For init and fit
20+
parameters of each estimator, refer to the EconML docs.
21+
22+
"""
23+
def __init__(self, *args, econml_methodname, **kwargs):
24+
"""
25+
:param econml_methodname: Fully qualified name of econml estimator
26+
class. For example, 'econml.dml.DML'
27+
"""
28+
# Required to ensure that self.method_params contains all the
29+
# parameters to create an object of this class
30+
args_dict = {k: v for k, v in locals().items()
31+
if k not in type(self)._STD_INIT_ARGS}
32+
args_dict.update(kwargs)
33+
super().__init__(*args, **args_dict)
34+
self._econml_methodname = econml_methodname
1735
self.logger.info("INFO: Using EconML Estimator")
1836
self.identifier_method = self._target_estimand.identifier_method
1937
self._observed_common_causes_names = self._target_estimand.get_backdoor_variables().copy()
@@ -154,7 +172,7 @@ def construct_symbolic_estimator(self, estimand):
154172
expr += "+".join(var_list)
155173
expr += " | " + ",".join(self._effect_modifier_names)
156174
return expr
157-
175+
158176
def shap_values(self, df: pd.DataFrame, *args, **kwargs):
159177
return self.estimator.shap_values(
160178
df[self._effect_modifier_names].values, *args, **kwargs

dowhy/causal_estimators/generalized_linear_model_estimator.py

+21-9
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
import numpy as np
2-
import pandas as pd
31
import statsmodels.api as sm
42
import itertools
53

64
from dowhy.causal_estimators.regression_estimator import RegressionEstimator
75

6+
87
class GeneralizedLinearModelEstimator(RegressionEstimator):
98
"""Compute effect of treatment using a generalized linear model such as logistic regression.
109
@@ -13,16 +12,29 @@ class GeneralizedLinearModelEstimator(RegressionEstimator):
1312
1413
"""
1514

16-
def __init__(self, *args, **kwargs):
17-
super().__init__(*args, **kwargs)
15+
def __init__(self, *args, glm_family=None, predict_score=True, **kwargs):
16+
"""For a list of args and kwargs, see documentation for
17+
:class:`~dowhy.causal_estimator.CausalEstimator`.
18+
19+
:param glm_family: statsmodels family for the generalized linear model.
20+
For example, use statsmodels.api.families.Binomial() for logistic
21+
regression or statsmodels.api.families.Poisson() for count data.
22+
:param predict_score: For models that have a binary output, whether
23+
to output the model's score or the binary output based on the score.
24+
25+
"""
26+
# Required to ensure that self.method_params contains all the
27+
# parameters needed to create an object of this class
28+
args_dict = {k: v for k, v in locals().items()
29+
if k not in type(self)._STD_INIT_ARGS}
30+
args_dict.update(kwargs)
31+
super().__init__(*args, **args_dict)
1832
self.logger.info("INFO: Using Generalized Linear Model Estimator")
19-
if self.method_params is not None and 'glm_family' in self.method_params:
20-
self.family = self.method_params['glm_family']
33+
if glm_family is not None:
34+
self.family = glm_family
2135
else:
2236
raise ValueError("Need to specify the family for the generalized linear model. Provide a 'glm_family' parameter in method_params, such as statsmodels.api.families.Binomial() for logistic regression.")
23-
self.predict_score = True
24-
if self.method_params is not None and 'predict_score' in self.method_params:
25-
self.predict_score = self.method_params['predict_score']
37+
self.predict_score = predict_score
2638
# Checking if Y is binary
2739
outcome_values = self._data[self._outcome_name].astype(int).unique()
2840
self.outcome_is_binary = all([v in [0,1] for v in outcome_values])

0 commit comments

Comments
 (0)