Skip to content

Commit 533f118

Browse files
Expand validate tests (#49)
1 parent 9036686 commit 533f118

File tree

3 files changed

+273
-0
lines changed

3 files changed

+273
-0
lines changed

tests/test_cli.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,96 @@ def fake_generate(model: str, n: int):
5353
assert out_file.exists()
5454
content = out_file.read_text()
5555
assert "time,status,X0,X1" in content
56+
57+
58+
def test_dataset_fallback(monkeypatch):
59+
"""If generate fails with additional kwargs, dataset retries with minimal args."""
60+
calls = []
61+
62+
def fake_generate(**kwargs):
63+
calls.append(kwargs)
64+
if len(calls) == 1:
65+
raise TypeError("bad args")
66+
return pd.DataFrame({"time": [0], "status": [1]})
67+
68+
monkeypatch.setattr("gen_surv.cli.generate", fake_generate)
69+
dataset(model="cphm", n=2, output=None)
70+
# first call has many parameters, second only model and n
71+
assert calls[-1] == {"model": "cphm", "n": 2}
72+
assert len(calls) == 2
73+
74+
75+
def test_dataset_weibull_parameters(monkeypatch):
76+
"""Parameters for aft_weibull model are forwarded correctly."""
77+
captured = {}
78+
79+
def fake_generate(**kwargs):
80+
captured.update(kwargs)
81+
return pd.DataFrame({"time": [1], "status": [0]})
82+
83+
monkeypatch.setattr("gen_surv.cli.generate", fake_generate)
84+
dataset(model="aft_weibull", n=3, beta=[0.1, 0.2], shape=1.1, scale=2.2, output=None)
85+
assert captured["model"] == "aft_weibull"
86+
assert captured["beta"] == [0.1, 0.2]
87+
assert captured["shape"] == 1.1
88+
assert captured["scale"] == 2.2
89+
90+
91+
def test_dataset_aft_ln(monkeypatch):
92+
"""aft_ln model should forward beta list and sigma."""
93+
captured = {}
94+
95+
def fake_generate(**kwargs):
96+
captured.update(kwargs)
97+
return pd.DataFrame({"time": [1], "status": [1]})
98+
99+
monkeypatch.setattr("gen_surv.cli.generate", fake_generate)
100+
dataset(model="aft_ln", n=1, beta=[0.3, 0.4], sigma=1.2, output=None)
101+
assert captured["beta"] == [0.3, 0.4]
102+
assert captured["sigma"] == 1.2
103+
104+
105+
def test_dataset_competing_risks(monkeypatch):
106+
"""competing_risks expands betas and passes hazards."""
107+
captured = {}
108+
109+
def fake_generate(**kwargs):
110+
captured.update(kwargs)
111+
return pd.DataFrame({"time": [1], "status": [1]})
112+
113+
monkeypatch.setattr("gen_surv.cli.generate", fake_generate)
114+
dataset(
115+
model="competing_risks",
116+
n=1,
117+
n_risks=2,
118+
baseline_hazards=[0.1, 0.2],
119+
beta=0.5,
120+
output=None,
121+
)
122+
assert captured["n_risks"] == 2
123+
assert captured["baseline_hazards"] == [0.1, 0.2]
124+
assert captured["betas"] == [0.5, 0.5]
125+
126+
127+
def test_dataset_mixture_cure(monkeypatch):
128+
"""mixture_cure passes cure and baseline parameters."""
129+
captured = {}
130+
131+
def fake_generate(**kwargs):
132+
captured.update(kwargs)
133+
return pd.DataFrame({"time": [1], "status": [1]})
134+
135+
monkeypatch.setattr("gen_surv.cli.generate", fake_generate)
136+
dataset(
137+
model="mixture_cure",
138+
n=1,
139+
cure_fraction=0.2,
140+
baseline_hazard=0.1,
141+
beta=[0.4],
142+
output=None,
143+
)
144+
assert captured["cure_fraction"] == 0.2
145+
assert captured["baseline_hazard"] == 0.1
146+
assert captured["betas_survival"] == [0.4]
147+
assert captured["betas_cure"] == [0.4]
148+

tests/test_validate.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,31 @@ def test_validate_gen_cmm_inputs_invalid_beta_length():
4242
)
4343

4444

45+
@pytest.mark.parametrize(
46+
"n, model_cens, cens_par, cov_range, rate",
47+
[
48+
(0, "uniform", 0.5, 1.0, [0.1] * 6),
49+
(1, "bad", 0.5, 1.0, [0.1] * 6),
50+
(1, "uniform", 0.0, 1.0, [0.1] * 6),
51+
(1, "uniform", 0.5, 0.0, [0.1] * 6),
52+
(1, "uniform", 0.5, 1.0, [0.1] * 3),
53+
],
54+
)
55+
def test_validate_gen_cmm_inputs_other_invalid(
56+
n, model_cens, cens_par, cov_range, rate
57+
):
58+
with pytest.raises(ValueError):
59+
v.validate_gen_cmm_inputs(
60+
n, model_cens, cens_par, [0.1, 0.2, 0.3], cov_range, rate
61+
)
62+
63+
64+
def test_validate_gen_cmm_inputs_valid():
65+
v.validate_gen_cmm_inputs(
66+
1, "uniform", 1.0, [0.1, 0.2, 0.3], covariate_range=1.0, rate=[0.1] * 6
67+
)
68+
69+
4570
def test_validate_gen_tdcm_inputs_invalid_lambda():
4671
"""Lambda <= 0 should raise a ValueError."""
4772
with pytest.raises(ValueError):
@@ -57,6 +82,44 @@ def test_validate_gen_tdcm_inputs_invalid_lambda():
5782
)
5883

5984

85+
@pytest.mark.parametrize(
86+
"dist,corr,dist_par",
87+
[
88+
("bad", 0.5, [1, 2]),
89+
("weibull", 0.0, [1, 2, 3, 4]),
90+
("weibull", 0.5, [1, 2, -1, 2]),
91+
("weibull", 0.5, [1, 2, 3]),
92+
("exponential", 2.0, [1, 1]),
93+
("exponential", 0.5, [1]),
94+
],
95+
)
96+
def test_validate_gen_tdcm_inputs_invalid_dist(dist, corr, dist_par):
97+
with pytest.raises(ValueError):
98+
v.validate_gen_tdcm_inputs(
99+
1,
100+
dist,
101+
corr,
102+
dist_par,
103+
"uniform",
104+
1.0,
105+
beta=[0.1, 0.2, 0.3],
106+
lam=1.0,
107+
)
108+
109+
110+
def test_validate_gen_tdcm_inputs_valid():
111+
v.validate_gen_tdcm_inputs(
112+
1,
113+
"weibull",
114+
0.5,
115+
[1, 1, 1, 1],
116+
"uniform",
117+
1.0,
118+
beta=[0.1, 0.2, 0.3],
119+
lam=1.0,
120+
)
121+
122+
60123
def test_validate_gen_aft_log_normal_inputs_valid():
61124
"""Valid parameters should not raise an error for AFT log-normal."""
62125
v.validate_gen_aft_log_normal_inputs(
@@ -68,19 +131,84 @@ def test_validate_gen_aft_log_normal_inputs_valid():
68131
)
69132

70133

134+
@pytest.mark.parametrize(
135+
"n,beta,sigma,model_cens,cens_par",
136+
[
137+
(0, [0.1], 1.0, "uniform", 1.0),
138+
(1, "bad", 1.0, "uniform", 1.0),
139+
(1, [0.1], 0.0, "uniform", 1.0),
140+
(1, [0.1], 1.0, "bad", 1.0),
141+
(1, [0.1], 1.0, "uniform", 0.0),
142+
],
143+
)
144+
def test_validate_gen_aft_log_normal_inputs_invalid(
145+
n, beta, sigma, model_cens, cens_par
146+
):
147+
with pytest.raises(ValueError):
148+
v.validate_gen_aft_log_normal_inputs(n, beta, sigma, model_cens, cens_par)
149+
150+
71151
def test_validate_dg_biv_inputs_valid_weibull():
72152
"""Valid parameters for a Weibull distribution should pass."""
73153
v.validate_dg_biv_inputs(5, "weibull", 0.1, [1.0, 1.0, 1.0, 1.0])
74154

75155

156+
def test_validate_dg_biv_inputs_invalid_corr_and_params():
157+
with pytest.raises(ValueError):
158+
v.validate_dg_biv_inputs(1, "exponential", -2.0, [1.0, 1.0])
159+
with pytest.raises(ValueError):
160+
v.validate_dg_biv_inputs(1, "exponential", 0.5, [1.0])
161+
with pytest.raises(ValueError):
162+
v.validate_dg_biv_inputs(1, "weibull", 0.5, [1.0, 1.0])
163+
164+
76165
def test_validate_gen_aft_weibull_inputs_and_log_logistic():
77166
with pytest.raises(ValueError):
78167
v.validate_gen_aft_weibull_inputs(0, [0.1], 1.0, 1.0, "uniform", 1.0)
79168
with pytest.raises(ValueError):
80169
v.validate_gen_aft_log_logistic_inputs(1, [0.1], -1.0, 1.0, "uniform", 1.0)
81170

82171

172+
@pytest.mark.parametrize(
173+
"shape,scale",
174+
[(-1.0, 1.0), (1.0, -1.0)],
175+
)
176+
def test_validate_gen_aft_weibull_invalid_params(shape, scale):
177+
with pytest.raises(ValueError):
178+
v.validate_gen_aft_weibull_inputs(1, [0.1], shape, scale, "uniform", 1.0)
179+
180+
181+
def test_validate_gen_aft_weibull_valid():
182+
v.validate_gen_aft_weibull_inputs(1, [0.1], 1.0, 1.0, "uniform", 1.0)
183+
184+
185+
def test_validate_gen_aft_log_logistic_valid():
186+
v.validate_gen_aft_log_logistic_inputs(1, [0.1], 1.0, 1.0, "uniform", 1.0)
187+
188+
83189
def test_validate_competing_risks_inputs():
84190
with pytest.raises(ValueError):
85191
v.validate_competing_risks_inputs(1, 2, [0.1], None, "uniform", 1.0)
86192
v.validate_competing_risks_inputs(1, 1, [0.5], [[0.1]], "uniform", 0.5)
193+
194+
195+
@pytest.mark.parametrize(
196+
"n,model_cens,cens_par,beta,cov_range,rate",
197+
[
198+
(0, "uniform", 1.0, [0.1, 0.2, 0.3], 1.0, [0.1, 0.2, 0.3]),
199+
(1, "bad", 1.0, [0.1, 0.2, 0.3], 1.0, [0.1, 0.2, 0.3]),
200+
(1, "uniform", 0.0, [0.1, 0.2, 0.3], 1.0, [0.1, 0.2, 0.3]),
201+
(1, "uniform", 1.0, [0.1, 0.2], 1.0, [0.1, 0.2, 0.3]),
202+
(1, "uniform", 1.0, [0.1, 0.2, 0.3], 0.0, [0.1, 0.2, 0.3]),
203+
(1, "uniform", 1.0, [0.1, 0.2, 0.3], 1.0, [0.1]),
204+
],
205+
)
206+
def test_validate_gen_thmm_inputs_invalid(
207+
n, model_cens, cens_par, beta, cov_range, rate
208+
):
209+
with pytest.raises(ValueError):
210+
v.validate_gen_thmm_inputs(n, model_cens, cens_par, beta, cov_range, rate)
211+
212+
213+
def test_validate_gen_thmm_inputs_valid():
214+
v.validate_gen_thmm_inputs(1, "uniform", 1.0, [0.1, 0.2, 0.3], 1.0, [0.1, 0.2, 0.3])

tests/test_visualization.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,3 +118,55 @@ def test_cli_visualize_missing_column(tmp_path, capsys):
118118
)
119119
captured = capsys.readouterr()
120120
assert "Status column 'status' not found in data" in captured.out
121+
122+
123+
def test_cli_visualize_missing_time(tmp_path, capsys):
124+
df = pd.DataFrame({"t": [1, 2], "status": [1, 0]})
125+
path = tmp_path / "d.csv"
126+
df.to_csv(path, index=False)
127+
with pytest.raises(typer.Exit):
128+
visualize(str(path), time_col="time", status_col="status")
129+
captured = capsys.readouterr()
130+
assert "Time column 'time' not found in data" in captured.out
131+
132+
133+
def test_cli_visualize_missing_group(tmp_path, capsys):
134+
df = pd.DataFrame({"time": [1], "status": [1], "x": [0]})
135+
path = tmp_path / "d2.csv"
136+
df.to_csv(path, index=False)
137+
with pytest.raises(typer.Exit):
138+
visualize(str(path), time_col="time", status_col="status", group_col="group")
139+
captured = capsys.readouterr()
140+
assert "Group column 'group' not found in data" in captured.out
141+
142+
143+
def test_cli_visualize_import_error(monkeypatch, tmp_path, capsys):
144+
"""visualize exits when matplotlib is missing."""
145+
import builtins
146+
147+
real_import = builtins.__import__
148+
149+
def fake_import(name, *args, **kwargs):
150+
if name.startswith("matplotlib"): # simulate missing dependency
151+
raise ImportError("no matplot")
152+
return real_import(name, *args, **kwargs)
153+
154+
monkeypatch.setattr(builtins, "__import__", fake_import)
155+
csv_path = tmp_path / "d.csv"
156+
pd.DataFrame({"time": [1], "status": [1]}).to_csv(csv_path, index=False)
157+
with pytest.raises(typer.Exit):
158+
visualize(str(csv_path))
159+
captured = capsys.readouterr()
160+
assert "Visualization requires matplotlib" in captured.out
161+
162+
163+
def test_cli_visualize_read_error(monkeypatch, tmp_path, capsys):
164+
"""visualize handles CSV read failures gracefully."""
165+
monkeypatch.setattr("pandas.read_csv", lambda *a, **k: (_ for _ in ()).throw(Exception("boom")))
166+
csv_path = tmp_path / "x.csv"
167+
csv_path.write_text("time,status\n1,1\n")
168+
with pytest.raises(typer.Exit):
169+
visualize(str(csv_path))
170+
captured = capsys.readouterr()
171+
assert "Error loading CSV file" in captured.out
172+

0 commit comments

Comments
 (0)