
Commit cca0cb8

Feat/add unit tests (#88)
* tests: remove regression baselines
* tests: add deterministic generation and visualize CLI integration
1 parent 9111394 commit cca0cb8

10 files changed: +342, -2 lines changed

pyproject.toml

Lines changed: 9 additions & 1 deletion
@@ -28,6 +28,7 @@ python = ">=3.10,<3.13"
 numpy = "^1.26"
 pandas = "^2.2.3"
 typer = "^0.12.3"
+click = ">=8.1,<8.2"
 matplotlib = "~3.8"
 lifelines = "^0.30"
 pyarrow = "^14"
@@ -36,10 +37,11 @@ pyreadr = "^0.5"
 [tool.poetry.group.dev.dependencies]
 pytest = "^8.3.5"
 pytest-cov = "^6.1.1"
+pytest-benchmark = "^4.0"
 python-semantic-release = "^9.21.0"
 mypy = "^1.15.0"
 invoke = "^2.2.0"
-hypothesis = "^6.98"
+hypothesis = "^6.108"
 tomli = "^2.2.1"
 black = "^24.1.0"
 isort = "^5.13.2"
@@ -51,6 +53,7 @@ pre-commit = "^3.8"
 dev = [
     "pytest",
     "pytest-cov",
+    "pytest-benchmark",
     "python-semantic-release",
     "mypy",
     "invoke",
@@ -72,6 +75,11 @@ sphinx-copybutton = "^0.5.2"
 sphinx-design = "^0.5.0"
 linkify-it-py = ">=2.0"
 
+[tool.pytest.ini_options]
+markers = [
+    "slow: slow performance tests",
+]
+
 [tool.poetry.scripts]
 gen_surv = "gen_surv.cli:app"

tests/conftest.py

Lines changed: 66 additions & 0 deletions
from __future__ import annotations

from pathlib import Path
from typing import Callable

import numpy as np
import pandas as pd
import pytest

BASELINE_DIR = Path(__file__).parent / "baselines"
BASELINE_DIR.mkdir(exist_ok=True)


def pytest_addoption(parser: pytest.Parser) -> None:
    parser.addoption(
        "--update-baselines",
        action="store_true",
        default=False,
        help="Refresh stored baselines for regression tests.",
    )


@pytest.fixture(scope="session")
def rng() -> np.random.Generator:
    return np.random.default_rng(seed=42)


@pytest.fixture(scope="session")
def save_baseline() -> Callable[[pd.DataFrame, str], None]:
    def _save(df: pd.DataFrame, name: str) -> None:
        (BASELINE_DIR / f"{name}.parquet").write_bytes(df.to_parquet(index=False))

    return _save


@pytest.fixture(scope="session")
def load_baseline() -> Callable[[str], pd.DataFrame]:
    def _load(name: str) -> pd.DataFrame:
        path = BASELINE_DIR / f"{name}.parquet"
        if not path.exists():
            pytest.skip(
                f"Missing baseline {path}; run with --update-baselines to refresh."
            )
        return pd.read_parquet(path)

    return _load


def assert_frame_numeric_equal(
    got: pd.DataFrame,
    expected: pd.DataFrame,
    *,
    rtol: float = 1e-6,
    atol: float = 1e-8,
) -> None:
    assert list(got.columns) == list(expected.columns), "Column order/name changed."
    assert got.shape == expected.shape, "Shape changed."
    for col in got.columns:
        g = pd.to_numeric(got[col], errors="coerce")
        e = pd.to_numeric(expected[col], errors="coerce")
        if g.notna().all() and e.notna().all():
            np.testing.assert_allclose(g.to_numpy(), e.to_numpy(), rtol=rtol, atol=atol)
        else:
            assert (
                got[col].astype(str).values == expected[col].astype(str).values
            ).all(), f"Mismatch in column {col!r}"

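For reference, a minimal sketch of how the tolerance-based helper above behaves (plain pandas plus the fixture module itself; this snippet is not part of the commit):

import pandas as pd

from tests.conftest import assert_frame_numeric_equal

# Numeric columns are compared with np.testing.assert_allclose, so tiny
# floating-point differences within rtol/atol do not fail the comparison.
a = pd.DataFrame({"time": [1.0, 2.0], "status": [1, 0]})
b = pd.DataFrame({"time": [1.0 + 1e-9, 2.0], "status": [1, 0]})
assert_frame_numeric_equal(a, b)  # passes: difference is within default tolerances
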
tests/test_api_contract.py

Lines changed: 10 additions & 0 deletions
from __future__ import annotations

import inspect

import gen_surv


def test_generate_signature_stable() -> None:
    sig = inspect.signature(gen_surv.generate)
    assert "model:" in str(sig) and "**kwargs" in str(sig)

tests/test_cli_snapshot.py

Lines changed: 18 additions & 0 deletions
from __future__ import annotations

import subprocess
import sys


def run_cli(args: list[str]) -> subprocess.CompletedProcess[str]:
    return subprocess.run(
        [sys.executable, "-m", "gen_surv", *args],
        text=True,
        capture_output=True,
    )


def test_cli_help_shows_usage() -> None:
    cp = run_cli(["--help"])
    assert cp.returncode == 0, cp.stderr
    assert "Generate synthetic survival datasets." in cp.stdout
Lines changed: 27 additions & 0 deletions
from __future__ import annotations

from typer.testing import CliRunner

from gen_surv import generate
from gen_surv.cli import app


def test_visualize_cli_generates_plot(tmp_path) -> None:
    csv_path = tmp_path / "data.csv"
    plot_path = tmp_path / "plot.png"
    df = generate(
        model="cphm",
        n=10,
        beta=0.5,
        covariate_range=1.0,
        model_cens="uniform",
        cens_par=0.7,
        seed=1234,
    )
    df.to_csv(csv_path, index=False)
    runner = CliRunner()
    result = runner.invoke(
        app, ["visualize", str(csv_path), "--output", str(plot_path)]
    )
    assert result.exit_code == 0
    assert plot_path.exists()
Lines changed: 19 additions & 0 deletions
from __future__ import annotations

from gen_surv import generate
from tests.conftest import assert_frame_numeric_equal


def test_generate_reproducible_with_seed() -> None:
    cfg = dict(
        model="cphm",
        n=32,
        beta=0.5,
        covariate_range=2.0,
        model_cens="uniform",
        cens_par=0.7,
        seed=1234,
    )
    df1 = generate(**cfg)
    df2 = generate(**cfg)
    assert_frame_numeric_equal(df1, df2)
Lines changed: 42 additions & 0 deletions
import pandas as pd
import pytest

from gen_surv import generate
from gen_surv.export import export_dataset
from gen_surv.integration import to_sksurv


def test_generate_export_roundtrip(tmp_path):
    """Integration test for generate and export_dataset."""
    df = generate(
        model="cphm",
        n=10,
        model_cens="uniform",
        cens_par=1.0,
        beta=0.5,
        covariate_range=1.0,
        seed=42,
    )
    out = tmp_path / "data.json"
    export_dataset(df, out)
    loaded = pd.read_json(out, orient="table")
    pd.testing.assert_frame_equal(df, loaded)


def test_generate_export_to_sksurv_roundtrip(tmp_path):
    """Full pipeline from generation to scikit-survival array."""
    pytest.importorskip("sksurv.util")
    df = generate(
        model="cphm",
        n=8,
        model_cens="uniform",
        cens_par=1.0,
        beta=0.5,
        covariate_range=1.0,
        seed=0,
    )
    out = tmp_path / "data.json"
    export_dataset(df, out)
    loaded = pd.read_json(out, orient="table")
    arr = to_sksurv(loaded)
    assert arr.shape[0] == 8

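Both round-trips read the exported file back with pandas' "table" orient, which suggests export_dataset writes JSON in that format (an inference from the tests, not confirmed elsewhere in this diff). The pandas behaviour the assertions lean on can be sketched independently of gen_surv:

import pandas as pd

# The "table" orient stores a schema alongside the data, so column names and
# dtypes survive the round trip. "data.json" is written to the working
# directory purely for illustration.
df = pd.DataFrame({"time": [1.0, 2.0], "status": [1, 0]})
df.to_json("data.json", orient="table", index=False)
loaded = pd.read_json("data.json", orient="table")
pd.testing.assert_frame_equal(df, loaded)
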
tests/test_generate_regression.py

Lines changed: 73 additions & 0 deletions
from __future__ import annotations

from typing import Any, Dict

import pandas as pd
import pytest

from gen_surv import generate
from tests.conftest import assert_frame_numeric_equal

MODEL_CONFIGS: Dict[str, Dict[str, Any]] = {
    "cphm": dict(
        model="cphm",
        n=256,
        beta=0.5,
        covariate_range=2.0,
        model_cens="uniform",
        cens_par=0.7,
        seed=1234,
    ),
    "aft_ln": dict(
        model="aft_ln",
        n=256,
        beta=[0.5],
        sigma=0.8,
        model_cens="uniform",
        cens_par=0.8,
        seed=1234,
    ),
    "aft_log_logistic": dict(
        model="aft_log_logistic",
        n=256,
        beta=[0.5],
        shape=1.3,
        scale=1.7,
        model_cens="uniform",
        cens_par=0.8,
        seed=1234,
    ),
    "aft_weibull": dict(
        model="aft_weibull",
        n=256,
        beta=[0.5],
        shape=1.4,
        scale=1.1,
        model_cens="uniform",
        cens_par=0.8,
        seed=1234,
    ),
}


@pytest.mark.parametrize("model_key", sorted(MODEL_CONFIGS.keys()))
def test_generate_matches_baseline(
    model_key: str,
    request: pytest.FixtureRequest,
    load_baseline,
    save_baseline,
) -> None:
    cfg = MODEL_CONFIGS[model_key]
    df: pd.DataFrame = generate(**cfg)
    assert "time" in df.columns and (
        "event" in df.columns or "status" in df.columns
    ), "Missing core survival columns."
    baseline_name = f"gen_{model_key}"
    if request.config.getoption("--update-baselines"):
        save_baseline(df, baseline_name)
        pytest.skip(
            f"Baseline {baseline_name} updated; re-run without --update-baselines."
        )
    expected = load_baseline(baseline_name)
    assert_frame_numeric_equal(df[expected.columns], expected)

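The baseline workflow registered in tests/conftest.py and used above is two-pass: a run with --update-baselines writes the parquet files under tests/baselines/ and skips the comparison, and an ordinary run then asserts against them. A minimal sketch of that flow (pytest.main simply mirrors the equivalent command line and is not something introduced by this commit):

import pytest

# First pass: write/refresh the stored baselines (the parametrized test
# saves each frame and then skips itself).
pytest.main(["tests/test_generate_regression.py", "--update-baselines"])

# Second pass: compare freshly generated frames against the stored baselines.
pytest.main(["tests/test_generate_regression.py"])
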
tests/test_integration_sksurv.py

Lines changed: 56 additions & 1 deletion
@@ -1,13 +1,68 @@
+import sys
+import types
+
 import pandas as pd
 import pytest
 
 from gen_surv.integration import to_sksurv
+from gen_surv.interface import generate
 
 
 def test_to_sksurv():
-    # Optional integration test; skipped when scikit-survival is not installed.
+    """Basic conversion with default column names."""
     pytest.importorskip("sksurv.util")
     df = pd.DataFrame({"time": [1.0, 2.0], "status": [1, 0]})
     arr = to_sksurv(df)
     assert arr.dtype.names == ("status", "time")
     assert arr.shape[0] == 2
+
+
+def test_to_sksurv_custom_columns():
+    """Unit test for custom time/event column names."""
+    pytest.importorskip("sksurv.util")
+    df = pd.DataFrame({"T": [1.0, 2.0], "E": [1, 0]})
+    arr = to_sksurv(df, time_col="T", event_col="E")
+    assert arr.dtype.names == ("E", "T")
+
+
+def test_to_sksurv_missing_dependency(monkeypatch):
+    """Regression test ensuring a helpful ImportError is raised."""
+    fake_mod = types.ModuleType("sksurv")
+    monkeypatch.setitem(sys.modules, "sksurv", fake_mod)
+    monkeypatch.delitem(sys.modules, "sksurv.util", raising=False)
+    df = pd.DataFrame({"time": [1.0], "status": [1]})
+    with pytest.raises(ImportError, match="scikit-survival is required"):
+        to_sksurv(df)
+
+
+def test_to_sksurv_missing_columns():
+    """Regression test: missing required columns should raise KeyError."""
+    pytest.importorskip("sksurv.util")
+    df = pd.DataFrame({"status": [1, 0]})
+    with pytest.raises(KeyError):
+        to_sksurv(df)
+
+
+def test_to_sksurv_empty_dataframe():
+    """Unit test for handling empty DataFrames."""
+    pytest.importorskip("sksurv.util")
+    df = pd.DataFrame({"time": [], "status": []})
+    arr = to_sksurv(df)
+    assert arr.shape == (0,)
+    assert arr.dtype.names == ("status", "time")
+
+
+def test_generate_to_sksurv_pipeline():
+    """Integration test covering generation and conversion."""
+    pytest.importorskip("sksurv.util")
+    df = generate(
+        model="cphm",
+        n=5,
+        model_cens="uniform",
+        cens_par=1.0,
+        beta=0.5,
+        covariate_range=1.0,
+        seed=0,
+    )
+    arr = to_sksurv(df)
+    assert arr.shape[0] == 5

tests/test_perf_smoke.py

Lines changed: 22 additions & 0 deletions
from __future__ import annotations

import pytest

from gen_surv import generate


@pytest.mark.slow
def test_generate_perf_smoke(benchmark) -> None:
    def _run():
        return generate(
            model="cphm",
            n=50_000,
            beta=0.5,
            covariate_range=2.0,
            model_cens="uniform",
            cens_par=0.7,
            seed=123,
        )

    df = benchmark(_run)
    assert len(df) == 50_000

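Registering the slow marker in pyproject.toml silences pytest's unknown-marker warning, but the benchmark only stays out of everyday runs when it is explicitly deselected. A minimal sketch using pytest's standard marker selection (nothing here is specific to this commit):

import pytest

# Everyday runs: skip everything marked "slow"; mirrors `pytest -m "not slow"`.
pytest.main(["-m", "not slow"])

# Performance runs: execute only the slow benchmarks (needs pytest-benchmark).
pytest.main(["-m", "slow"])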