Skip to content

Commit 7975c30

Browse files
jeongyoonleeclaude
andauthored
Fix CausalRandomForestRegressor predicting inf from division by zero (#589) (#883)
* Add .worktrees/ to .gitignore * Fix CausalRandomForestRegressor predicting inf from division by zero (#589) Guard against zero treatment/control counts in CausalMSE and TTest criterion functions. When a tree split creates a child node with no treatment or no control observations, the variance formula `var/count` produces infinity. Now skips impurity contribution for that treatment group (zero impurity), preventing the splitter from favoring degenerate splits. Affected methods: - CausalMSE.node_impurity() - CausalMSE.children_impurity() - TTest.children_impurity() * Add regression test for inf predictions with sparse groups (#589) Test that CausalRandomForestRegressor.predict() returns finite values when imbalanced data causes zero-count treatment/control nodes. * Add ttest criterion regression test for inf predictions (#589) * Fix ttest criterion name: 'ttest' -> 't_test' --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 123cc1f commit 7975c30

2 files changed

Lines changed: 64 additions & 8 deletions

File tree

causalml/inference/tree/causal/_criterion.pyx

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,8 @@ cdef class CausalMSE(CausalRegressionCriterion):
463463
tr_var = self.state.node.outcome_var(tr_group_idx)
464464
tr_count = self.state.node.count_1d[tr_group_idx]
465465

466-
impurity += (tr_var / tr_count + ct_var / ct_count) - node_tau * node_tau
466+
if tr_count > 0 and ct_count > 0:
467+
impurity += (tr_var / tr_count + ct_var / ct_count) - node_tau * node_tau
467468

468469
impurity /= (self.n_outputs - 1)
469470
impurity += self.get_groups_penalty(self.state.node)
@@ -500,8 +501,10 @@ cdef class CausalMSE(CausalRegressionCriterion):
500501
left_tr_var = self.state.left.outcome_var(tr_group_idx)
501502
left_tr_count = self.state.left.count_1d[tr_group_idx]
502503

503-
impurity_right[0] += (right_tr_var / right_tr_count + right_ct_var / right_ct_count) - right_tau * right_tau
504-
impurity_left[0] += (left_tr_var / left_tr_count + left_ct_var / left_ct_count) - left_tau * left_tau
504+
if right_tr_count > 0 and right_ct_count > 0:
505+
impurity_right[0] += (right_tr_var / right_tr_count + right_ct_var / right_ct_count) - right_tau * right_tau
506+
if left_tr_count > 0 and left_ct_count > 0:
507+
impurity_left[0] += (left_tr_var / left_tr_count + left_ct_var / left_ct_count) - left_tau * left_tau
505508

506509
impurity_right[0] /= (self.n_outputs - 1)
507510
impurity_left[0] /= (self.n_outputs - 1)
@@ -577,16 +580,22 @@ cdef class TTest(CausalRegressionCriterion):
577580
left_tr_var = self.state.left.outcome_var(tr_group_idx)
578581
left_tr_count = self.state.left.count_1d[tr_group_idx]
579582

580-
denom_left = sqrt(left_tr_var / left_tr_count + left_ct_var / left_ct_count)
581-
denom_right = sqrt(right_tr_var / right_tr_count + right_ct_var / right_ct_count)
583+
denom_left = 0.0
584+
denom_right = 0.0
585+
if left_tr_count > 0 and left_ct_count > 0:
586+
denom_left = sqrt(left_tr_var / left_tr_count + left_ct_var / left_ct_count)
587+
if right_tr_count > 0 and right_ct_count > 0:
588+
denom_right = sqrt(right_tr_var / right_tr_count + right_ct_var / right_ct_count)
582589
if denom_left > 0.:
583590
t_left_sum += left_tau / denom_left
584591
if denom_right > 0.:
585592
t_right_sum += right_tau / denom_right
586-
593+
587594
# Per-treatment squared difference in taus between sides
588-
inv_n_sum = (1.0 / right_tr_count + 1.0 / right_ct_count +
589-
1.0 / left_tr_count + 1.0 / left_ct_count)
595+
inv_n_sum = 0.0
596+
if right_tr_count > 0 and right_ct_count > 0 and left_tr_count > 0 and left_ct_count > 0:
597+
inv_n_sum = (1.0 / right_tr_count + 1.0 / right_ct_count +
598+
1.0 / left_tr_count + 1.0 / left_ct_count)
590599

591600
# Pooled variance across four cells (left/right × tr/ct)
592601
pooled_var_t = 0.0

tests/test_causal_trees.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,3 +275,50 @@ def test_unbiased_sampling_error(
275275
crforest_test_var = crforest.calculate_error(X_train=X_train, X_test=X_test)
276276
assert (crforest_test_var > 0).all()
277277
assert crforest_test_var.shape[0] == y_test.shape[0]
278+
279+
280+
def test_CausalRandomForestRegressor_no_inf_predictions():
281+
"""Test that CausalRandomForestRegressor does not predict inf values
282+
when some tree splits have zero-count treatment/control groups (#589)."""
283+
np.random.seed(RANDOM_SEED)
284+
n = 100
285+
X = np.random.randn(n, 5)
286+
# Heavily imbalanced: very few treated samples so tree splits
287+
# can produce nodes with zero treatment count
288+
treatment = np.array([0] * 90 + [1] * 10)
289+
y = np.random.randn(n)
290+
291+
model = CausalRandomForestRegressor(
292+
criterion="causal_mse",
293+
control_name=0,
294+
n_estimators=10,
295+
min_samples_leaf=1,
296+
random_state=RANDOM_SEED,
297+
)
298+
model.fit(X=X, treatment=treatment, y=y)
299+
preds = model.predict(X=X)
300+
301+
assert np.all(np.isfinite(preds)), "Predictions contain inf or NaN values"
302+
303+
304+
def test_CausalRandomForestRegressor_no_inf_predictions_ttest():
305+
"""Test that CausalRandomForestRegressor with criterion='ttest' does not
306+
predict inf values when some tree splits have zero-count
307+
treatment/control groups (#589)."""
308+
np.random.seed(RANDOM_SEED)
309+
n = 100
310+
X = np.random.randn(n, 5)
311+
treatment = np.array([0] * 90 + [1] * 10)
312+
y = np.random.randn(n)
313+
314+
model = CausalRandomForestRegressor(
315+
criterion="t_test",
316+
control_name=0,
317+
n_estimators=10,
318+
min_samples_leaf=1,
319+
random_state=RANDOM_SEED,
320+
)
321+
model.fit(X=X, treatment=treatment, y=y)
322+
preds = model.predict(X=X)
323+
324+
assert np.all(np.isfinite(preds)), "Predictions contain inf or NaN values"

0 commit comments

Comments
 (0)