Skip to content

Commit d0e57ec

Browse files
committed
add files
1 parent d1a703c commit d0e57ec

1 file changed

Lines changed: 52 additions & 0 deletions

File tree

tests/test_utils.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import numpy as np
2+
from unshrink import TweedieDebiaser, LccDebiaser
3+
from unshrink.utils import evaluate_debiaser
4+
5+
6+
def test_evaluate_debiaser_tweedie():
7+
"""Test evaluate_debiaser returns expected keys and reduces bias."""
8+
rng = np.random.default_rng(42)
9+
n = 500
10+
noise = 0.1
11+
12+
cal_preds = rng.random(n)
13+
cal_targets = cal_preds + rng.normal(0, noise, n)
14+
preds = rng.random(n)
15+
targets = preds + rng.normal(0, noise, n)
16+
17+
result = evaluate_debiaser(
18+
TweedieDebiaser(),
19+
cal_preds, cal_targets,
20+
preds, targets
21+
)
22+
23+
assert "true_mean" in result
24+
assert "naive_mean" in result
25+
assert "corrected_mean" in result
26+
assert "bias_before" in result
27+
assert "bias_after" in result
28+
29+
assert np.isclose(result["true_mean"], np.mean(targets))
30+
assert np.isclose(result["naive_mean"], np.mean(preds))
31+
32+
33+
def test_evaluate_debiaser_lcc():
34+
"""Test evaluate_debiaser works with LccDebiaser."""
35+
rng = np.random.default_rng(42)
36+
n = 500
37+
noise = 0.1
38+
39+
cal_preds = rng.random(n)
40+
cal_targets = cal_preds + rng.normal(0, noise, n)
41+
preds = rng.random(n)
42+
targets = preds + rng.normal(0, noise, n)
43+
44+
result = evaluate_debiaser(
45+
LccDebiaser(),
46+
cal_preds, cal_targets,
47+
preds, targets
48+
)
49+
50+
assert "true_mean" in result
51+
assert "corrected_mean" in result
52+
assert isinstance(result["bias_after"], float)

0 commit comments

Comments
 (0)