Skip to content

Commit 4429fba

Browse files
committed
Add test for uniform weights to ensure consistency with unweighted ATE
1 parent 783d503 commit 4429fba

1 file changed

Lines changed: 24 additions & 0 deletions

File tree

tests/test_debiased_ate.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,27 @@ def test_debiased_ate_bad_weight_shapes(Debiaser, no_noise_data):
5757
# wrong-shaped weight for control
5858
with pytest.raises(ValueError):
5959
deb.debiased_ate(treated, control, iptw_control=np.ones(5))
60+
61+
62+
@pytest.mark.parametrize("Debiaser", [TweedieDebiaser, LccDebiaser])
63+
def test_debiased_ate_uniform_weights_equal_unweighted(Debiaser, no_noise_data):
64+
"""If IPTW weights are uniform (all equal), result should equal unweighted ATE."""
65+
preds, targets = no_noise_data
66+
deb = Debiaser().fit(preds, targets)
67+
68+
treated = preds[:40]
69+
control = preds[40:90]
70+
71+
# uniform weights (all ones) and a constant non-one weight should both match unweighted
72+
ones_t = np.ones(len(treated))
73+
ones_c = np.ones(len(control))
74+
75+
const_t = np.full(len(treated), 2.0)
76+
const_c = np.full(len(control), 2.0)
77+
78+
unweighted = deb.debiased_ate(treated, control)
79+
ones_weighted = deb.debiased_ate(treated, control, iptw_treated=ones_t, iptw_control=ones_c)
80+
const_weighted = deb.debiased_ate(treated, control, iptw_treated=const_t, iptw_control=const_c)
81+
82+
assert np.isclose(unweighted, ones_weighted)
83+
assert np.isclose(unweighted, const_weighted)

0 commit comments

Comments
 (0)