Skip to content

Feat: Implement Kernel Riesz ATT Estimator with CV and DML#7

Open
apoorvalal wants to merge 2 commits intomasterfrom
feat/kernel-riesz-att-dml
Open

Feat: Implement Kernel Riesz ATT Estimator with CV and DML#7
apoorvalal wants to merge 2 commits intomasterfrom
feat/kernel-riesz-att-dml

Conversation

@apoorvalal
Copy link
Copy Markdown
Owner

This commit introduces a flexible KernelRieszATT estimator for Average Treatment Effect on the Treated (ATT) estimation using kernel-based Riesz representer methods.

Key features include:

  • Support for various kernel functions (e.g., RBF, linear).
  • Automatic hyperparameter selection for kernel parameters (e.g., gamma) and regularization strength via cross-validation, using Maximum Mean Discrepancy (MMD) as the default scoring metric.
  • Optional Double Machine Learning (DML) through cross-fitting for more robust ATT estimates.
  • Refactored existing ATT estimation methods into a common ATTWeightEstimator base class for better modularity.
  • Updated simulation script to demonstrate and test the new estimator with its various configurations (CV, DML, different kernels).
  • Added comprehensive docstrings and API refinements for all new and modified components.

google-labs-jules bot and others added 2 commits June 21, 2025 20:22
This commit introduces a flexible `KernelRieszATT` estimator for
Average Treatment Effect on the Treated (ATT) estimation using
kernel-based Riesz representer methods.

Key features include:
- Support for various kernel functions (e.g., RBF, linear).
- Automatic hyperparameter selection for kernel parameters (e.g., gamma)
  and regularization strength via cross-validation, using Maximum Mean
  Discrepancy (MMD) as the default scoring metric.
- Optional Double Machine Learning (DML) through cross-fitting for
  more robust ATT estimates.
- Refactored existing ATT estimation methods into a common
  `ATTWeightEstimator` base class for better modularity.
- Updated simulation script to demonstrate and test the new estimator
  with its various configurations (CV, DML, different kernels).
- Added comprehensive docstrings and API refinements for all new and
  modified components.
@apoorvalal apoorvalal requested a review from Copilot July 13, 2025 03:21
Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR adds a new kernel-based Riesz ATT estimator with hyperparameter CV and optional DML, refactors existing ATT estimators into a shared base class, and provides an extended example notebook demonstrating all estimators.

  • Introduce KernelRieszATT with kernel selection, CV-based tuning, and DML cross-fitting.
  • Refactor penalized, constrained, and regression-form ATT estimators under ATTWeightEstimator.
  • Provide an example notebook showing usage, parameter grids, and results comparison.

Reviewed Changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.

File Description
examples/riesz_representer.ipynb Added a demo notebook showcasing KernelRieszATT and other estimators
cbpys/riesz_estimators.py Implemented PenalizedFormEstimator, ConstrainedFormEstimator, and RieszRegressionFormEstimator
cbpys/kernel.py Implemented KernelRieszATT with CV hyperparameter search and DML support
cbpys/baseclass.py Added ATTWeightEstimator abstract base class
Comments suppressed due to low confidence (2)

cbpys/kernel.py:89

  • [nitpick] The estimator currently only warns about unexpected kwargs like 'kernel_args' and 'reg_param'. Consider officially accepting these legacy aliases or raising a clear TypeError so users know the correct argument names.
            )

cbpys/kernel.py:9

  • [nitpick] Add unit tests covering the CV parameter search (including edge cases), fallback to defaults, and DML cross-fitting logic to ensure all branches behave as intended.
class KernelRieszATT(ATTWeightEstimator):

Comment thread cbpys/kernel.py
Comment on lines +346 to +350
# Fit weights FOR THE VALIDATION FOLD using current params
# These weights are specific to X_c_val_fold to balance X_t_val_fold
weights_val_fold = self._fit_single_config(
X_t_val_fold,
X_c_val_fold,
Copy link

Copilot AI Jul 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the CV loop you fit weights on the validation splits instead of the training splits. To properly evaluate hyperparameters, use the training folds to fit (train_idx) and then score on the corresponding validation folds (val_idx).

Suggested change
# Fit weights FOR THE VALIDATION FOLD using current params
# These weights are specific to X_c_val_fold to balance X_t_val_fold
weights_val_fold = self._fit_single_config(
X_t_val_fold,
X_c_val_fold,
# Fit weights using training data
weights_train_fold = self._fit_single_config(
X_t_train_fold,
X_c_train_fold,

Copilot uses AI. Check for mistakes.
Comment thread cbpys/riesz_estimators.py
self.y_control = y_control
self.sum_weights = np.sum(self.weights)
if abs(self.sum_weights) < 1e-9: # Check for very small sum of weights
print(
Copy link

Copilot AI Jul 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use warnings.warn instead of print for warning messages so users can filter or redirect warnings appropriately.

Suggested change
print(
warnings.warn(

Copilot uses AI. Check for mistakes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants