Skip to content

Commit 9036686

Browse files
Add visualization CLI tests (#47)
1 parent 1671a02 commit 9036686

File tree

2 files changed

+106
-1
lines changed

2 files changed

+106
-1
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ pandas = "^2.2.3"
3030
typer = "^0.12.3"
3131
matplotlib = "^3.10"
3232
lifelines = "^0.30"
33+
pyarrow = "^14"
3334

3435
[tool.poetry.group.dev.dependencies]
3536
pytest = "^8.3.5"

tests/test_visualization.py

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
1+
import pandas as pd
2+
import pytest
3+
import typer
4+
15
from gen_surv import generate
2-
from gen_surv.visualization import plot_survival_curve
6+
from gen_surv.cli import visualize
7+
from gen_surv.visualization import (
8+
describe_survival,
9+
plot_covariate_effect,
10+
plot_hazard_comparison,
11+
plot_survival_curve,
12+
)
313

414

515
def test_plot_survival_curve_runs():
@@ -14,3 +24,97 @@ def test_plot_survival_curve_runs():
1424
fig, ax = plot_survival_curve(df)
1525
assert fig is not None
1626
assert ax is not None
27+
28+
29+
def test_plot_hazard_comparison_runs():
30+
df1 = generate(
31+
model="cphm",
32+
n=5,
33+
model_cens="uniform",
34+
cens_par=1.0,
35+
beta=0.5,
36+
covariate_range=1.0,
37+
)
38+
df2 = generate(
39+
model="aft_weibull",
40+
n=5,
41+
beta=[0.5],
42+
shape=1.5,
43+
scale=2.0,
44+
model_cens="uniform",
45+
cens_par=1.0,
46+
)
47+
models = {"cphm": df1, "aft_weibull": df2}
48+
fig, ax = plot_hazard_comparison(models)
49+
assert fig is not None
50+
assert ax is not None
51+
52+
53+
def test_plot_covariate_effect_runs():
54+
df = generate(
55+
model="cphm",
56+
n=10,
57+
model_cens="uniform",
58+
cens_par=1.0,
59+
beta=0.5,
60+
covariate_range=2.0,
61+
)
62+
fig, ax = plot_covariate_effect(df, covariate_col="X0", n_groups=2)
63+
assert fig is not None
64+
assert ax is not None
65+
66+
67+
def test_describe_survival_summary():
68+
df = generate(
69+
model="cphm",
70+
n=10,
71+
model_cens="uniform",
72+
cens_par=1.0,
73+
beta=0.5,
74+
covariate_range=2.0,
75+
)
76+
summary = describe_survival(df)
77+
expected_metrics = [
78+
"Total Observations",
79+
"Number of Events",
80+
"Number Censored",
81+
"Event Rate",
82+
"Median Survival Time",
83+
"Min Time",
84+
"Max Time",
85+
"Mean Time",
86+
]
87+
assert list(summary["Metric"]) == expected_metrics
88+
assert summary.shape[0] == len(expected_metrics)
89+
90+
91+
def test_cli_visualize(tmp_path, capsys):
92+
df = pd.DataFrame({"time": [1, 2, 3], "status": [1, 0, 1]})
93+
csv_path = tmp_path / "d.csv"
94+
df.to_csv(csv_path, index=False)
95+
out_file = tmp_path / "out.png"
96+
visualize(
97+
str(csv_path),
98+
time_col="time",
99+
status_col="status",
100+
group_col=None,
101+
output=str(out_file),
102+
)
103+
assert out_file.exists()
104+
captured = capsys.readouterr()
105+
assert "Plot saved to" in captured.out
106+
107+
108+
def test_cli_visualize_missing_column(tmp_path, capsys):
109+
df = pd.DataFrame({"time": [1, 2], "event": [1, 0]})
110+
csv_path = tmp_path / "bad.csv"
111+
df.to_csv(csv_path, index=False)
112+
with pytest.raises(typer.Exit):
113+
visualize(
114+
str(csv_path),
115+
time_col="time",
116+
status_col="status",
117+
group_col=None,
118+
)
119+
captured = capsys.readouterr()
120+
assert "Status column 'status' not found in data" in captured.out

0 commit comments

Comments
 (0)