Skip to content

Commit 77b9f06

Browse files
test: broaden competing risks coverage (#67)
1 parent 9961ad2 commit 77b9f06

File tree

3 files changed

+199
-19
lines changed

3 files changed

+199
-19
lines changed

tests/test_cmm.py

Lines changed: 78 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,90 @@
11
import os
22
import sys
33

4+
import numpy as np
5+
import pandas as pd
6+
47
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
5-
from gen_surv.cmm import gen_cmm
8+
from gen_surv.cmm import gen_cmm, generate_event_times
9+
610

11+
def test_generate_event_times_reproducible():
    """Check ``generate_event_times`` against pinned values under a fixed seed."""
    np.random.seed(0)
    event_times = generate_event_times(
        z1=1.0,
        beta=[0.1, 0.2, 0.3],
        rate=[1.0] * 6,
    )
    # Reference values expected after seeding the global NumPy RNG with 0.
    reference = {
        "t12": 0.7201370350469476,
        "t13": 1.0282691393768246,
        "t23": 0.6839405281667484,
    }
    for key, value in reference.items():
        assert np.isclose(event_times[key], value)
721

8-
def test_gen_cmm_shape():
22+
23+
def test_gen_cmm_uniform_reproducible():
    """Compare ``gen_cmm`` output with uniform censoring against a pinned frame."""
    np.random.seed(42)
    observed = gen_cmm(
        n=5,
        model_cens="uniform",
        cens_par=1.0,
        beta=[0.1, 0.2, 0.3],
        covariate_range=2.0,
        rate=[1.0] * 6,
    )

    # Pinned values for seed 42; the second subject is the only one with
    # status 0 and therefore has no transition.
    stop_times = [
        0.019298197410170713,
        0.05808361216819946,
        0.5550989864862181,
        0.2117537394012932,
        0.19451374567187332,
    ]
    covariate_values = [
        0.749080237694725,
        1.9014286128198323,
        1.4639878836228102,
        1.1973169683940732,
        0.31203728088487304,
    ]
    reference = pd.DataFrame(
        {
            "id": list(range(1, 6)),
            "start": [0.0] * 5,
            "stop": stop_times,
            "status": [1, 0, 1, 1, 1],
            "X0": covariate_values,
            "transition": [1.0, float("nan"), 2.0, 1.0, 1.0],
        }
    )
    pd.testing.assert_frame_equal(observed, reference)
56+
57+
58+
def test_gen_cmm_exponential_reproducible():
    """Compare ``gen_cmm`` output with exponential censoring against a pinned frame."""
    np.random.seed(42)
    observed = gen_cmm(
        n=5,
        model_cens="exponential",
        cens_par=1.0,
        beta=[0.1, 0.2, 0.3],
        covariate_range=2.0,
        rate=[1.0] * 6,
    )

    # Pinned values for seed 42; the second subject is the only one with
    # status 0 and therefore has no transition.
    stop_times = [
        0.019298197410170713,
        0.059838768608680676,
        0.5550989864862181,
        0.2117537394012932,
        0.19451374567187332,
    ]
    covariate_values = [
        0.749080237694725,
        1.9014286128198323,
        1.4639878836228102,
        1.1973169683940732,
        0.31203728088487304,
    ]
    reference = pd.DataFrame(
        {
            "id": list(range(1, 6)),
            "start": [0.0] * 5,
            "stop": stop_times,
            "status": [1, 0, 1, 1, 1],
            "X0": covariate_values,
            "transition": [1.0, float("nan"), 2.0, 1.0, 1.0],
        }
    )
    pd.testing.assert_frame_equal(observed, reference)

tests/test_competing_risks.py

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
"""
2-
Tests for Competing Risks models.
3-
"""
1+
"""Tests for Competing Risks models."""
2+
3+
import os
44

55
import numpy as np
66
import pandas as pd
@@ -16,6 +16,8 @@
1616
gen_competing_risks_weibull,
1717
)
1818

19+
os.environ.setdefault("MPLBACKEND", "Agg")
20+
1921

2022
def test_gen_competing_risks_basic():
2123
"""Test that the competing risks generator runs without errors."""
@@ -83,6 +85,13 @@ def test_competing_risks_parameters():
8385
gen_competing_risks(n=10, n_risks=2, model_cens="invalid", seed=42)
8486

8587

88+
def test_invalid_covariate_dist():
    """Both generators reject an unrecognised ``covariate_dist`` with ChoiceError."""
    for generator in (gen_competing_risks, gen_competing_risks_weibull):
        with pytest.raises(ChoiceError):
            generator(n=5, n_risks=2, covariate_dist="unknown", seed=1)
93+
94+
8695
def test_competing_risks_weibull_parameters():
8796
"""Test parameter validation in Weibull competing risks model."""
8897
# Test with invalid number of shape parameters
@@ -126,6 +135,18 @@ def test_cause_specific_cumulative_incidence():
126135
cause_specific_cumulative_incidence(df, time_points, cause=3)
127136

128137

138+
def test_cause_specific_cumulative_incidence_bounds():
    """Check the incidence curve at time points outside the observed range."""
    data = gen_competing_risks(n=30, n_risks=2, seed=5)
    last_time = data["time"].max()
    curve = cause_specific_cumulative_incidence(
        data, [-1.0, 0.0, last_time + 1], cause=1
    )

    # A time point before the origin must report zero incidence.
    first_row = curve.iloc[0]
    assert first_row["incidence"] == 0.0

    # Beyond the last observed time the value equals the one at that time.
    at_last_time = cause_specific_cumulative_incidence(data, [last_time], cause=1)
    assert curve.iloc[-1]["incidence"] == at_last_time.iloc[0]["incidence"]
148+
149+
129150
@given(
130151
n=st.integers(min_value=5, max_value=50),
131152
n_risks=st.integers(min_value=2, max_value=4),
@@ -170,6 +191,31 @@ def test_competing_risks_weibull_properties(n, n_risks, seed):
170191
assert len(status_counts) >= 2
171192

172193

194+
def test_gen_competing_risks_forces_event_types():
    """With tiny hazards and tight censoring, both causes still appear in ``status``."""
    settings = {
        "n": 2,
        "n_risks": 2,
        "baseline_hazards": [1e-9, 1e-9],
        "model_cens": "uniform",
        "cens_par": 0.1,
        "seed": 0,
    }
    simulated = gen_competing_risks(**settings)
    observed_statuses = set(simulated["status"])
    assert observed_statuses == {1, 2}
204+
205+
206+
def test_gen_competing_risks_weibull_forces_event_types():
    """With huge Weibull scales and tight censoring, both causes still appear."""
    settings = {
        "n": 2,
        "n_risks": 2,
        "shape_params": [1, 1],
        "scale_params": [1e9, 1e9],
        "model_cens": "uniform",
        "cens_par": 0.1,
        "seed": 0,
    }
    simulated = gen_competing_risks_weibull(**settings)
    observed_statuses = set(simulated["status"])
    assert observed_statuses == {1, 2}
217+
218+
173219
def test_reproducibility():
174220
"""Test that results are reproducible with the same seed."""
175221
df1 = gen_competing_risks(n=20, n_risks=2, seed=42)
@@ -202,14 +248,10 @@ def test_competing_risks_summary_with_categorical():
202248
assert "distribution" in summary["covariate_stats"]["group"]
203249

204250

205-
import matplotlib
206-
207-
matplotlib.use("Agg")
208-
209-
210251
def test_plot_cause_specific_hazards_runs():
    """Smoke-test ``plot_cause_specific_hazards`` on a small simulated data set.

    The test is skipped when ``matplotlib.pyplot`` cannot be imported.
    """
    plt = pytest.importorskip("matplotlib.pyplot")
    df = gen_competing_risks(n=30, n_risks=2, seed=3)
    fig, ax = cr.plot_cause_specific_hazards(df, time_points=np.linspace(0, 5, 5))
    try:
        assert hasattr(fig, "savefig")
        assert len(ax.get_lines()) >= 1
    finally:
        # Close the figure even when an assertion fails so it does not stay
        # open for the remainder of the test session.
        plt.close(fig)

tests/test_export.py

Lines changed: 70 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,80 @@
11
import pandas as pd
22
import pyreadr
3+
import pytest
34

5+
from gen_surv._validation import ChoiceError
46
from gen_surv.export import export_dataset
57

68

7-
def test_export_dataset_rds(tmp_path):
9+
@pytest.mark.parametrize(
    "fmt, reader",
    [
        ("csv", pd.read_csv),
        ("feather", pd.read_feather),
        ("ft", pd.read_feather),
    ],
)
def test_export_dataset_formats(fmt, reader, tmp_path):
    """Round-trip a small frame through each parametrized file extension."""
    original = pd.DataFrame({"time": [1.0, 2.0], "status": [1, 0]})
    target = tmp_path / f"data.{fmt}"
    export_dataset(original, target)
    assert target.exists()

    loaded = reader(target)
    # Cast back to the source dtypes before comparing.
    restored = loaded.astype(original.dtypes.to_dict())
    pd.testing.assert_frame_equal(restored.reset_index(drop=True), original)
24+
25+
26+
def test_export_dataset_json(monkeypatch, tmp_path):
    """A ``.json`` target calls ``DataFrame.to_json`` with the "table" orient."""
    frame = pd.DataFrame({"time": [1.0, 2.0], "status": [1, 0]})
    target = tmp_path / "data.json"
    recorded = {}

    def fake_to_json(self, path, orient="table"):
        # Record the call and leave a placeholder file so the existence
        # check below has something to find.
        recorded["args"] = (path, orient)
        with open(path, "w", encoding="utf-8") as handle:
            handle.write("{}")

    monkeypatch.setattr(pd.DataFrame, "to_json", fake_to_json)
    export_dataset(frame, target)

    assert recorded["args"] == (target, "table")
    assert target.exists()
41+
42+
43+
def test_export_dataset_rds(monkeypatch, tmp_path):
    """A ``.rds`` target hands the frame to ``pyreadr.write_rds``."""
    frame = pd.DataFrame({"time": [1.0, 2.0], "status": [1, 0]})
    target = tmp_path / "data.rds"
    seen = {}

    def fake_write_rds(path, data):
        # Capture what the exporter passed in and create an empty file.
        seen["path"] = path
        seen["data"] = data
        with open(path, "wb"):
            pass

    monkeypatch.setattr(pyreadr, "write_rds", fake_write_rds)
    export_dataset(frame, target)

    assert target.exists()
    pd.testing.assert_frame_equal(seen["data"], frame.reset_index(drop=True))
58+
59+
60+
def test_export_dataset_explicit_fmt(monkeypatch, tmp_path):
    """Passing ``fmt="json"`` uses ``to_json`` even though the file ends in ``.bin``."""
    frame = pd.DataFrame({"time": [1.0, 2.0], "status": [1, 0]})
    target = tmp_path / "data.bin"
    recorded = {}

    def fake_to_json(self, path, orient="table"):
        # Record the call and leave a placeholder file so the existence
        # check below has something to find.
        recorded["args"] = (path, orient)
        with open(path, "w", encoding="utf-8") as handle:
            handle.write("{}")

    monkeypatch.setattr(pd.DataFrame, "to_json", fake_to_json)
    export_dataset(frame, target, fmt="json")

    assert recorded["args"] == (target, "table")
    assert target.exists()
75+
76+
77+
def test_export_dataset_invalid_format(tmp_path):
    """Passing ``fmt="txt"`` raises ChoiceError."""
    frame = pd.DataFrame({"time": [1.0, 2.0], "status": [1, 0]})
    target = tmp_path / "data.xxx"
    with pytest.raises(ChoiceError):
        export_dataset(frame, target, fmt="txt")

0 commit comments

Comments
 (0)