|
1 | 1 | import os |
2 | 2 | import runpy |
3 | 3 | import sys |
| 4 | +import pytest |
4 | 5 |
|
5 | 6 | import pandas as pd |
6 | 7 |
|
7 | 8 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) |
8 | | -from gen_surv.cli import dataset |
| 9 | +from gen_surv.cli import dataset, visualize |
9 | 10 |
|
10 | 11 |
|
11 | 12 | def test_cli_dataset_stdout(monkeypatch, capsys): |
@@ -81,7 +82,9 @@ def fake_generate(**kwargs): |
81 | 82 | return pd.DataFrame({"time": [1], "status": [0]}) |
82 | 83 |
|
83 | 84 | monkeypatch.setattr("gen_surv.cli.generate", fake_generate) |
84 | | - dataset(model="aft_weibull", n=3, beta=[0.1, 0.2], shape=1.1, scale=2.2, output=None) |
| 85 | + dataset( |
| 86 | + model="aft_weibull", n=3, beta=[0.1, 0.2], shape=1.1, scale=2.2, output=None |
| 87 | + ) |
85 | 88 | assert captured["model"] == "aft_weibull" |
86 | 89 | assert captured["beta"] == [0.1, 0.2] |
87 | 90 | assert captured["shape"] == 1.1 |
@@ -146,3 +149,109 @@ def fake_generate(**kwargs): |
146 | 149 | assert captured["betas_survival"] == [0.4] |
147 | 150 | assert captured["betas_cure"] == [0.4] |
148 | 151 |
|
| 152 | + |
| 153 | +def test_dataset_invalid_model(monkeypatch): |
| 154 | + def fake_generate(**kwargs): |
| 155 | + raise ValueError("bad model") |
| 156 | + |
| 157 | + monkeypatch.setattr("gen_surv.cli.generate", fake_generate) |
| 158 | + with pytest.raises(ValueError): |
| 159 | + dataset(model="nope", n=1, output=None) |
| 160 | + |
| 161 | + |
| 162 | +def test_cli_visualize_basic(monkeypatch, tmp_path): |
| 163 | + csv = tmp_path / "data.csv" |
| 164 | + pd.DataFrame({"time": [1, 2], "status": [1, 0]}).to_csv(csv, index=False) |
| 165 | + |
| 166 | + def fake_plot_survival_curve(**kwargs): |
| 167 | + import matplotlib.pyplot as plt |
| 168 | + |
| 169 | + fig, ax = plt.subplots() |
| 170 | + ax.plot([0, 1], [1, 0]) |
| 171 | + return fig, ax |
| 172 | + |
| 173 | + monkeypatch.setattr( |
| 174 | + "gen_surv.visualization.plot_survival_curve", fake_plot_survival_curve |
| 175 | + ) |
| 176 | + |
| 177 | + saved = [] |
| 178 | + |
| 179 | + def fake_savefig(path, *args, **kwargs): |
| 180 | + saved.append(path) |
| 181 | + |
| 182 | + monkeypatch.setattr("matplotlib.pyplot.savefig", fake_savefig) |
| 183 | + |
| 184 | + visualize( |
| 185 | + str(csv), |
| 186 | + time_col="time", |
| 187 | + status_col="status", |
| 188 | + group_col=None, |
| 189 | + output=str(tmp_path / "plot.png"), |
| 190 | + ) |
| 191 | + assert saved and saved[0].endswith("plot.png") |
| 192 | + |
| 193 | + |
| 194 | +def test_dataset_aft_log_logistic(monkeypatch): |
| 195 | + captured = {} |
| 196 | + |
| 197 | + def fake_generate(**kwargs): |
| 198 | + captured.update(kwargs) |
| 199 | + return pd.DataFrame({"time": [1], "status": [1]}) |
| 200 | + |
| 201 | + monkeypatch.setattr("gen_surv.cli.generate", fake_generate) |
| 202 | + dataset( |
| 203 | + model="aft_log_logistic", |
| 204 | + n=1, |
| 205 | + beta=[0.1], |
| 206 | + shape=1.2, |
| 207 | + scale=2.3, |
| 208 | + output=None, |
| 209 | + ) |
| 210 | + assert captured["model"] == "aft_log_logistic" |
| 211 | + assert captured["beta"] == [0.1] |
| 212 | + assert captured["shape"] == 1.2 |
| 213 | + assert captured["scale"] == 2.3 |
| 214 | + |
| 215 | + |
| 216 | +def test_dataset_competing_risks_weibull(monkeypatch): |
| 217 | + captured = {} |
| 218 | + |
| 219 | + def fake_generate(**kwargs): |
| 220 | + captured.update(kwargs) |
| 221 | + return pd.DataFrame({"time": [1], "status": [1]}) |
| 222 | + |
| 223 | + monkeypatch.setattr("gen_surv.cli.generate", fake_generate) |
| 224 | + dataset( |
| 225 | + model="competing_risks_weibull", |
| 226 | + n=1, |
| 227 | + n_risks=2, |
| 228 | + shape_params=[0.7, 1.2], |
| 229 | + scale_params=[2.0, 2.0], |
| 230 | + beta=0.3, |
| 231 | + output=None, |
| 232 | + ) |
| 233 | + assert captured["n_risks"] == 2 |
| 234 | + assert captured["shape_params"] == [0.7, 1.2] |
| 235 | + assert captured["scale_params"] == [2.0, 2.0] |
| 236 | + assert captured["betas"] == [0.3, 0.3] |
| 237 | + |
| 238 | + |
| 239 | +def test_dataset_piecewise(monkeypatch): |
| 240 | + captured = {} |
| 241 | + |
| 242 | + def fake_generate(**kwargs): |
| 243 | + captured.update(kwargs) |
| 244 | + return pd.DataFrame({"time": [1], "status": [1]}) |
| 245 | + |
| 246 | + monkeypatch.setattr("gen_surv.cli.generate", fake_generate) |
| 247 | + dataset( |
| 248 | + model="piecewise_exponential", |
| 249 | + n=1, |
| 250 | + breakpoints=[1.0], |
| 251 | + hazard_rates=[0.2, 0.3], |
| 252 | + beta=[0.4], |
| 253 | + output=None, |
| 254 | + ) |
| 255 | + assert captured["breakpoints"] == [1.0] |
| 256 | + assert captured["hazard_rates"] == [0.2, 0.3] |
| 257 | + assert captured["betas"] == [0.4] |
0 commit comments