Feat: Implement Kernel Riesz ATT Estimator with CV and DML#7
Feat: Implement Kernel Riesz ATT Estimator with CV and DML#7apoorvalal wants to merge 2 commits intomasterfrom
Conversation
This commit introduces a flexible `KernelRieszATT` estimator for Average Treatment Effect on the Treated (ATT) estimation using kernel-based Riesz representer methods. Key features include: - Support for various kernel functions (e.g., RBF, linear). - Automatic hyperparameter selection for kernel parameters (e.g., gamma) and regularization strength via cross-validation, using Maximum Mean Discrepancy (MMD) as the default scoring metric. - Optional Double Machine Learning (DML) through cross-fitting for more robust ATT estimates. - Refactored existing ATT estimation methods into a common `ATTWeightEstimator` base class for better modularity. - Updated simulation script to demonstrate and test the new estimator with its various configurations (CV, DML, different kernels). - Added comprehensive docstrings and API refinements for all new and modified components.
There was a problem hiding this comment.
Pull Request Overview
This PR adds a new kernel-based Riesz ATT estimator with hyperparameter CV and optional DML, refactors existing ATT estimators into a shared base class, and provides an extended example notebook demonstrating all estimators.
- Introduce
KernelRieszATTwith kernel selection, CV-based tuning, and DML cross-fitting. - Refactor penalized, constrained, and regression-form ATT estimators under
ATTWeightEstimator. - Provide an example notebook showing usage, parameter grids, and results comparison.
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| examples/riesz_representer.ipynb | Added a demo notebook showcasing KernelRieszATT and other estimators |
| cbpys/riesz_estimators.py | Implemented PenalizedFormEstimator, ConstrainedFormEstimator, and RieszRegressionFormEstimator |
| cbpys/kernel.py | Implemented KernelRieszATT with CV hyperparameter search and DML support |
| cbpys/baseclass.py | Added ATTWeightEstimator abstract base class |
Comments suppressed due to low confidence (2)
cbpys/kernel.py:89
- [nitpick] The estimator currently only warns about unexpected kwargs like 'kernel_args' and 'reg_param'. Consider officially accepting these legacy aliases or raising a clear
TypeErrorso users know the correct argument names.
)
cbpys/kernel.py:9
- [nitpick] Add unit tests covering the CV parameter search (including edge cases), fallback to defaults, and DML cross-fitting logic to ensure all branches behave as intended.
class KernelRieszATT(ATTWeightEstimator):
| # Fit weights FOR THE VALIDATION FOLD using current params | ||
| # These weights are specific to X_c_val_fold to balance X_t_val_fold | ||
| weights_val_fold = self._fit_single_config( | ||
| X_t_val_fold, | ||
| X_c_val_fold, |
There was a problem hiding this comment.
In the CV loop you fit weights on the validation splits instead of the training splits. To properly evaluate hyperparameters, use the training folds to fit (train_idx) and then score on the corresponding validation folds (val_idx).
| # Fit weights FOR THE VALIDATION FOLD using current params | |
| # These weights are specific to X_c_val_fold to balance X_t_val_fold | |
| weights_val_fold = self._fit_single_config( | |
| X_t_val_fold, | |
| X_c_val_fold, | |
| # Fit weights using training data | |
| weights_train_fold = self._fit_single_config( | |
| X_t_train_fold, | |
| X_c_train_fold, |
| self.y_control = y_control | ||
| self.sum_weights = np.sum(self.weights) | ||
| if abs(self.sum_weights) < 1e-9: # Check for very small sum of weights | ||
| print( |
There was a problem hiding this comment.
Use warnings.warn instead of print for warning messages so users can filter or redirect warnings appropriately.
| print( | |
| warnings.warn( |
This commit introduces a flexible
KernelRieszATTestimator for Average Treatment Effect on the Treated (ATT) estimation using kernel-based Riesz representer methods.Key features include:
ATTWeightEstimatorbase class for better modularity.