Skip to content

Commit ac8a3af

Browse files
jeongyoonleeclaude
andcommitted
Address Copilot PR review comments
- Improve docstring completeness with detailed parameter descriptions - Clarify usage of treatment/y parameters for classification metrics - Explain return_components output format with probability details - Separate DR learner imports for better readability - Maintain API consistency with other meta-learners Addresses comments in #844 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 210c2fe commit ac8a3af

2 files changed

Lines changed: 16 additions & 4 deletions

File tree

causalml/inference/meta/drlearner.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -531,11 +531,21 @@ def predict(
531531
532532
Args:
533533
X (np.matrix or np.array or pd.Dataframe): a feature matrix
534-
treatment (np.array or pd.Series, optional): a treatment vector
535-
y (np.array or pd.Series, optional): an outcome vector
536-
verbose (bool, optional): whether to output progress logs
534+
treatment (np.array or pd.Series, optional): a treatment vector. Used for computing
535+
classification metrics when y is also provided.
536+
y (np.array or pd.Series, optional): an outcome vector. Used for computing
537+
classification metrics when treatment is also provided.
538+
p (np.ndarray or pd.Series or dict, optional): an array of propensity scores of float (0,1) in the
539+
single-treatment case; or, a dictionary of treatment groups that map to propensity vectors of
540+
float (0,1). Currently not used in prediction but kept for API consistency.
541+
return_components (bool, optional): whether to return outcome probabilities for treatment and control
542+
groups separately. Defaults to False.
543+
verbose (bool, optional): whether to output progress logs. Defaults to True.
537544
Returns:
538545
(numpy.ndarray): Predictions of treatment effects.
546+
If return_components is True, also returns:
547+
- dict: Predicted probabilities for the control group (yhat_cs).
548+
- dict: Predicted probabilities for the treatment group (yhat_ts).
539549
"""
540550
X, treatment, y = convert_pd_to_np(X, treatment, y)
541551

tests/test_meta_learners.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@
3030
XGBRRegressor,
3131
)
3232
from causalml.inference.meta import TMLELearner
33-
from causalml.inference.meta import BaseDRLearner, BaseDRRegressor, BaseDRClassifier
33+
from causalml.inference.meta import BaseDRLearner
34+
from causalml.inference.meta import BaseDRRegressor
35+
from causalml.inference.meta import BaseDRClassifier
3436
from causalml.metrics import ape, auuc_score
3537

3638
from .const import RANDOM_SEED, N_SAMPLE, ERROR_THRESHOLD, CONTROL_NAME, CONVERSION

0 commit comments

Comments
 (0)