Skip to content

Commit 7d99262

Browse files
committed
Add Typer dependency and expand tests to improve coverage
1 parent e75be6e commit 7d99262

File tree

5 files changed

+84
-0
lines changed

5 files changed

+84
-0
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ numpy = "^1.26"
1616
pandas = "^2.2.3"
1717
pytest-cov = "^6.1.1"
1818
invoke = "^2.2.0"
19+
typer = "^0.12.3"
1920

2021
[tool.poetry.group.dev.dependencies]
2122
pytest = "^8.3.5"

tests/test_bivariate.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import numpy as np
2+
from gen_surv.bivariate import sample_bivariate_distribution
3+
import pytest
4+
5+
6+
def test_sample_bivariate_exponential_shape():
7+
"""Exponential distribution should return an array of shape (n, 2)."""
8+
result = sample_bivariate_distribution(5, "exponential", 0.0, [1.0, 1.0])
9+
assert isinstance(result, np.ndarray)
10+
assert result.shape == (5, 2)
11+
12+
13+
def test_sample_bivariate_invalid_dist():
14+
"""Unsupported distributions should raise ValueError."""
15+
with pytest.raises(ValueError):
16+
sample_bivariate_distribution(10, "invalid", 0.0, [1, 1])

tests/test_cli.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import pandas as pd
2+
from gen_surv.cli import dataset
3+
import runpy
4+
5+
6+
def test_cli_dataset_stdout(monkeypatch, capsys):
7+
"""Dataset command prints CSV to stdout when no output file is given."""
8+
9+
def fake_generate(model: str, n: int):
10+
return pd.DataFrame({"time": [1.0], "status": [1], "X0": [0.1], "X1": [0.2]})
11+
12+
# Patch the generate function used in the CLI to avoid heavy computation.
13+
monkeypatch.setattr("gen_surv.cli.generate", fake_generate)
14+
# Call the command function directly to sidestep Click argument parsing
15+
dataset(model="cphm", n=1, output=None)
16+
captured = capsys.readouterr()
17+
assert "time,status,X0,X1" in captured.out
18+
19+
20+
def test_main_entry_point(monkeypatch):
21+
"""Running the module as a script should invoke the CLI app."""
22+
23+
called = []
24+
25+
def fake_app():
26+
called.append(True)
27+
28+
# Patch the CLI app before the module is executed
29+
monkeypatch.setattr("gen_surv.cli.app", fake_app)
30+
monkeypatch.setattr("sys.argv", ["gen_surv", "dataset", "cphm"])
31+
runpy.run_module("gen_surv.__main__", run_name="__main__")
32+
assert called

tests/test_validate.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import pytest
2+
import gen_surv.validate as v
3+
4+
5+
def test_validate_gen_cphm_inputs_valid():
6+
"""Ensure valid inputs pass without raising an exception."""
7+
v.validate_gen_cphm_inputs(1, "uniform", 0.5, 1.0)
8+
9+
10+
@pytest.mark.parametrize(
11+
"n, model_cens, cens_par, covar",
12+
[
13+
(0, "uniform", 0.5, 1.0),
14+
(1, "bad", 0.5, 1.0),
15+
(1, "uniform", -1.0, 1.0),
16+
(1, "uniform", 0.5, -1.0),
17+
],
18+
)
19+
def test_validate_gen_cphm_inputs_invalid(n, model_cens, cens_par, covar):
20+
"""Invalid parameter combinations should raise ValueError."""
21+
with pytest.raises(ValueError):
22+
v.validate_gen_cphm_inputs(n, model_cens, cens_par, covar)
23+
24+
25+
def test_validate_dg_biv_inputs_invalid():
26+
"""Invalid distribution names should raise an error."""
27+
with pytest.raises(ValueError):
28+
v.validate_dg_biv_inputs(10, "normal", 0.1, [1, 1])

tests/test_version.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from importlib.metadata import version
2+
from gen_surv import __version__
3+
4+
5+
def test_package_version_matches_metadata():
6+
"""The exported __version__ should match package metadata."""
7+
assert __version__ == version("gen_surv")

0 commit comments

Comments
 (0)