Skip to content

Commit 74b4021

Browse files
Increase test coverage (#52)
1 parent 9106058 commit 74b4021

File tree

6 files changed

+238
-9
lines changed

6 files changed

+238
-9
lines changed

.github/workflows/docs.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@ jobs:
2121
- name: Build documentation
2222
run: poetry run sphinx-build -W -b html docs/source docs/build
2323
- name: Upload artifacts
24-
uses: actions/upload-artifact@v3
24+
uses: actions/upload-artifact@v4
2525
with:
2626
name: documentation
2727
path: docs/build/

tests/test_cli.py

Lines changed: 111 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,12 @@
11
import os
22
import runpy
33
import sys
4+
import pytest
45

56
import pandas as pd
67

78
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
8-
from gen_surv.cli import dataset
9+
from gen_surv.cli import dataset, visualize
910

1011

1112
def test_cli_dataset_stdout(monkeypatch, capsys):
@@ -81,7 +82,9 @@ def fake_generate(**kwargs):
8182
return pd.DataFrame({"time": [1], "status": [0]})
8283

8384
monkeypatch.setattr("gen_surv.cli.generate", fake_generate)
84-
dataset(model="aft_weibull", n=3, beta=[0.1, 0.2], shape=1.1, scale=2.2, output=None)
85+
dataset(
86+
model="aft_weibull", n=3, beta=[0.1, 0.2], shape=1.1, scale=2.2, output=None
87+
)
8588
assert captured["model"] == "aft_weibull"
8689
assert captured["beta"] == [0.1, 0.2]
8790
assert captured["shape"] == 1.1
@@ -146,3 +149,109 @@ def fake_generate(**kwargs):
146149
assert captured["betas_survival"] == [0.4]
147150
assert captured["betas_cure"] == [0.4]
148151

152+
153+
def test_dataset_invalid_model(monkeypatch):
    """``dataset`` must surface a ``ValueError`` when ``generate`` raises one.

    ``gen_surv.cli.generate`` is replaced by a stub that always raises, so the
    test only exercises how the CLI command reacts to the failure.
    """

    def raising_generate(**_kwargs):
        raise ValueError("bad model")

    monkeypatch.setattr("gen_surv.cli.generate", raising_generate)

    with pytest.raises(ValueError):
        dataset(model="nope", n=1, output=None)
160+
161+
162+
def test_cli_visualize_basic(monkeypatch, tmp_path):
    """``visualize`` reads a CSV and saves the plot to the requested path.

    ``gen_surv.visualization.plot_survival_curve`` and
    ``matplotlib.pyplot.savefig`` are both stubbed out: the former returns a
    trivial figure, the latter records the path it was asked to write to
    instead of touching the disk.
    """
    import matplotlib.pyplot as plt

    csv = tmp_path / "data.csv"
    pd.DataFrame({"time": [1, 2], "status": [1, 0]}).to_csv(csv, index=False)

    # Every figure the stub creates is tracked so it can be closed afterwards.
    figures = []

    def fake_plot_survival_curve(**kwargs):
        fig, ax = plt.subplots()
        ax.plot([0, 1], [1, 0])
        figures.append(fig)
        return fig, ax

    monkeypatch.setattr(
        "gen_surv.visualization.plot_survival_curve", fake_plot_survival_curve
    )

    saved = []

    def fake_savefig(path, *args, **kwargs):
        saved.append(path)

    monkeypatch.setattr("matplotlib.pyplot.savefig", fake_savefig)

    try:
        visualize(
            str(csv),
            time_col="time",
            status_col="status",
            group_col=None,
            output=str(tmp_path / "plot.png"),
        )
    finally:
        # Close the figures even if ``visualize`` fails, so open figures do
        # not accumulate across the test session.
        for fig in figures:
            plt.close(fig)

    # ``str()`` keeps the check valid whether the CLI hands savefig a ``str``
    # or a ``pathlib.Path``.
    assert saved and str(saved[0]).endswith("plot.png")
192+
193+
194+
def test_dataset_aft_log_logistic(monkeypatch):
    """``dataset`` forwards the AFT log-logistic options to ``generate``.

    ``gen_surv.cli.generate`` is replaced by a recorder, and the keyword
    arguments it receives are compared with the values given to the command.
    """
    seen = {}

    def recording_generate(**kwargs):
        seen.update(kwargs)
        return pd.DataFrame({"time": [1], "status": [1]})

    monkeypatch.setattr("gen_surv.cli.generate", recording_generate)

    dataset(
        model="aft_log_logistic",
        n=1,
        beta=[0.1],
        shape=1.2,
        scale=2.3,
        output=None,
    )

    expected = {
        "model": "aft_log_logistic",
        "beta": [0.1],
        "shape": 1.2,
        "scale": 2.3,
    }
    for key, value in expected.items():
        assert seen[key] == value
214+
215+
216+
def test_dataset_competing_risks_weibull(monkeypatch):
    """``dataset`` forwards the competing-risks Weibull options to ``generate``.

    The command is given a scalar ``beta=0.3`` with ``n_risks=2``; the test
    expects ``generate`` to receive ``betas == [0.3, 0.3]`` together with the
    untouched shape and scale parameter lists.
    """
    seen = {}

    def recording_generate(**kwargs):
        seen.update(kwargs)
        return pd.DataFrame({"time": [1], "status": [1]})

    monkeypatch.setattr("gen_surv.cli.generate", recording_generate)

    dataset(
        model="competing_risks_weibull",
        n=1,
        n_risks=2,
        shape_params=[0.7, 1.2],
        scale_params=[2.0, 2.0],
        beta=0.3,
        output=None,
    )

    expected = {
        "n_risks": 2,
        "shape_params": [0.7, 1.2],
        "scale_params": [2.0, 2.0],
        "betas": [0.3, 0.3],
    }
    for key, value in expected.items():
        assert seen[key] == value
237+
238+
239+
def test_dataset_piecewise(monkeypatch):
    """``dataset`` forwards the piecewise-exponential options to ``generate``.

    The command receives ``beta=[0.4]``; the test expects the recorder to see
    it under the ``betas`` keyword, alongside the breakpoints and hazard rates.
    """
    seen = {}

    def recording_generate(**kwargs):
        seen.update(kwargs)
        return pd.DataFrame({"time": [1], "status": [1]})

    monkeypatch.setattr("gen_surv.cli.generate", recording_generate)

    dataset(
        model="piecewise_exponential",
        n=1,
        breakpoints=[1.0],
        hazard_rates=[0.2, 0.3],
        beta=[0.4],
        output=None,
    )

    expected = {
        "breakpoints": [1.0],
        "hazard_rates": [0.2, 0.3],
        "betas": [0.4],
    }
    for key, value in expected.items():
        assert seen[key] == value

tests/test_competing_risks.py

Lines changed: 31 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
from hypothesis import given
99
from hypothesis import strategies as st
1010

11+
import gen_surv.competing_risks as cr
1112
from gen_surv.competing_risks import (
1213
cause_specific_cumulative_incidence,
1314
gen_competing_risks,
@@ -183,3 +184,33 @@ def test_reproducibility():
183184

184185
with pytest.raises(AssertionError):
185186
pd.testing.assert_frame_equal(df1, df3)
187+
188+
189+
def test_competing_risks_summary_basic():
    """``competing_risks_summary`` reports the basic counts for a small sample.

    Ten subjects with two risks are simulated; the summary must echo those
    sizes, list only causes 1 and 2, and include a ``time_stats`` entry.
    """
    data = gen_competing_risks(n=10, n_risks=2, seed=1)
    result = cr.competing_risks_summary(data)

    assert result["n_subjects"] == 10
    assert result["n_causes"] == 2
    # A cause may have no events in so small a sample, hence subset, not equality.
    assert set(result["events_by_cause"]).issubset({1, 2})
    assert "time_stats" in result
196+
197+
198+
def test_competing_risks_summary_with_categorical():
    """``competing_risks_summary`` describes a string-valued covariate.

    A two-level ``group`` column is attached to the simulated data; its stats
    entry must report two categories and carry a ``distribution`` field.
    """
    data = gen_competing_risks(n=8, n_risks=2, seed=2)
    data["group"] = ["A", "B"] * 4

    result = cr.competing_risks_summary(data, covariate_cols=["X0", "group"])
    covariate_stats = result["covariate_stats"]

    assert covariate_stats["group"]["categories"] == 2
    assert "distribution" in covariate_stats["group"]
204+
205+
206+
import matplotlib
207+
208+
matplotlib.use("Agg")
209+
210+
211+
def test_plot_cause_specific_hazards_runs():
    """Smoke-test ``plot_cause_specific_hazards`` on a small simulated dataset.

    Checks that the returned figure exposes ``savefig`` and that at least one
    line was drawn on the returned axes.
    """
    # Import pyplot explicitly: a bare ``import matplotlib`` does not guarantee
    # that the ``matplotlib.pyplot`` submodule attribute is available.
    import matplotlib.pyplot as plt

    df = gen_competing_risks(n=30, n_risks=2, seed=3)
    fig, ax = cr.plot_cause_specific_hazards(df, time_points=np.linspace(0, 5, 5))
    try:
        assert hasattr(fig, "savefig")
        assert len(ax.get_lines()) >= 1
    finally:
        # Release the figure even when an assertion above fails.
        plt.close(fig)

tests/test_piecewise.py

Lines changed: 36 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@ def test_piecewise_invalid_lengths():
1919
n=5, breakpoints=[1.0, 2.0], hazard_rates=[0.5], seed=42
2020
)
2121

22+
2223
def test_piecewise_invalid_hazard_and_breakpoints():
2324
with pytest.raises(ValueError):
2425
gen_piecewise_exponential(
@@ -34,3 +35,38 @@ def test_piecewise_invalid_hazard_and_breakpoints():
3435
hazard_rates=[0.5, -1.0],
3536
seed=42,
3637
)
38+
39+
40+
def test_piecewise_covariate_distributions():
    """``gen_piecewise_exponential`` accepts uniform and binary covariates.

    For each distribution the result must have the requested number of rows
    and contain the ``X0`` and ``X1`` covariate columns.
    """
    cases = {
        "uniform": {"low": 0.0, "high": 1.0},
        "binary": {"p": 0.7},
    }
    required_columns = {"X0", "X1"}

    for dist_name, dist_params in cases.items():
        result = gen_piecewise_exponential(
            n=5,
            breakpoints=[1.0],
            hazard_rates=[0.2, 0.4],
            covariate_dist=dist_name,
            covariate_params=dist_params,
            seed=1,
        )
        assert len(result) == 5
        assert required_columns.issubset(result.columns)
55+
56+
57+
def test_piecewise_custom_betas_reproducible():
    """Two runs with identical arguments and seed must give equal frames."""

    def simulate():
        # The argument lists are rebuilt on every call, so the two runs share
        # no objects, just like two independent literal calls.
        return gen_piecewise_exponential(
            n=5,
            breakpoints=[1.0],
            hazard_rates=[0.1, 0.2],
            betas=[0.5, -0.2],
            seed=2,
        )

    first = simulate()
    second = simulate()
    pd.testing.assert_frame_equal(first, second)

tests/test_summary_extra.py

Lines changed: 56 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,10 @@
11
import pandas as pd
22
import pytest
3-
from gen_surv.summary import check_survival_data_quality, compare_survival_datasets
3+
from gen_surv.summary import (
4+
check_survival_data_quality,
5+
compare_survival_datasets,
6+
_print_summary,
7+
)
48
from gen_surv import generate
59

610

@@ -35,17 +39,65 @@ def test_check_survival_data_quality_no_fix():
3539

3640

3741
def test_compare_survival_datasets_basic():
    """``compare_survival_datasets`` yields one column per named dataset.

    Two CPHM datasets that differ only in ``beta`` are compared; the result
    must have columns ``A`` and ``B`` and an ``n_subjects`` row.
    """
    datasets = {
        label: generate(
            model="cphm",
            n=5,
            model_cens="uniform",
            cens_par=1.0,
            beta=beta,
            covariate_range=1.0,
        )
        for label, beta in (("A", 0.5), ("B", 1.0))
    }

    comparison = compare_survival_datasets(datasets)

    assert {"A", "B"}.issubset(comparison.columns)
    assert "n_subjects" in comparison.index
4361

4462

4563
def test_compare_survival_datasets_with_covariates_and_empty_error():
    """Covariate rows are reported, and an empty mapping is rejected.

    With ``covariate_cols=["X0"]`` the comparison must contain an ``X0_mean``
    row; calling the function with ``{}`` must raise ``ValueError``.
    """
    single = generate(
        model="cphm",
        n=3,
        model_cens="uniform",
        cens_par=1.0,
        beta=0.5,
        covariate_range=1.0,
    )

    table = compare_survival_datasets({"only": single}, covariate_cols=["X0"])
    assert "only" in table.columns
    assert "X0_mean" in table.index

    with pytest.raises(ValueError):
        compare_survival_datasets({})
77+
78+
79+
def test_check_survival_data_quality_min_and_max():
    """``check_survival_data_quality`` fixes times outside ``[min, max]``.

    The input holds one time below ``min_time`` and one above ``max_time``.
    With ``fix_issues=True`` no returned time may exceed the maximum, and the
    issue report must count at least one fixed value.
    """
    frame = pd.DataFrame({"time": [-1.0, 3.0], "status": [1, 1]})

    fixed, issues = check_survival_data_quality(
        frame,
        min_time=0.0,
        max_time=2.0,
        fix_issues=True,
    )

    within_max = fixed["time"] <= 2.0
    assert within_max.all()
    assert issues["modifications"]["values_fixed"] > 0
86+
87+
88+
def test_print_summary_with_issues(capsys):
    """``_print_summary`` prints an issues notice for a flagged summary.

    The hand-built summary reports one negative time; the captured stdout
    must contain the text ``Issues detected``.
    """
    quality = {
        "missing_time": 0,
        "missing_status": 0,
        "negative_time": 1,
        "invalid_status": 0,
    }
    summary = {
        "dataset_info": {"n_subjects": 2, "n_unique_ids": 2, "n_covariates": 0},
        "event_info": {"n_events": 1, "n_censored": 1, "event_rate": 0.5},
        "time_info": {"min": 0.0, "max": 2.0, "mean": 1.0, "median": 1.0},
        "data_quality": quality,
        "covariates": {},
    }

    _print_summary(summary, "time", "status", None, [])

    printed = capsys.readouterr().out
    assert "Issues detected" in printed

tests/test_visualization.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -162,11 +162,12 @@ def fake_import(name, *args, **kwargs):
162162

163163
def test_cli_visualize_read_error(monkeypatch, tmp_path, capsys):
    """``visualize`` exits via ``typer.Exit`` and reports a CSV read failure.

    ``pandas.read_csv`` is replaced by a stub that always raises, so the
    command is expected to print its load-error message and exit.
    """

    def failing_read_csv(*_args, **_kwargs):
        raise Exception("boom")

    monkeypatch.setattr("pandas.read_csv", failing_read_csv)

    csv_path = tmp_path / "x.csv"
    csv_path.write_text("time,status\n1,1\n")

    with pytest.raises(typer.Exit):
        visualize(str(csv_path))

    printed = capsys.readouterr().out
    assert "Error loading CSV file" in printed
172-

0 commit comments

Comments
 (0)