Skip to content

Commit 1671a02

Browse files
Add extensive summary tests (#45)
1 parent c3571dd commit 1671a02

File tree

6 files changed

+209
-0
lines changed

6 files changed

+209
-0
lines changed

tests/test_export.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22

33
import pandas as pd
4+
import pytest
45

56
from gen_surv import export_dataset, generate
67

@@ -35,3 +36,22 @@ def test_export_dataset_json(tmp_path):
3536
assert out_file.exists()
3637
loaded = pd.read_json(out_file, orient="table")
3738
pd.testing.assert_frame_equal(df.reset_index(drop=True), loaded)
39+
40+
41+
def test_export_dataset_feather_and_invalid(tmp_path):
    """Round-trip a dataset through Feather, then expect ``fmt="txt"`` to fail."""
    frame = generate(
        model="cphm",
        n=5,
        model_cens="uniform",
        cens_par=1.0,
        beta=0.5,
        covariate_range=1.0,
    )

    # Write to a Feather file and read it back.
    target = tmp_path / "data.feather"
    export_dataset(frame, str(target))
    assert target.exists()
    reloaded = pd.read_feather(target)
    pd.testing.assert_frame_equal(frame.reset_index(drop=True), reloaded)

    # The "txt" format is expected to be rejected with ValueError.
    bad_target = str(tmp_path / "data.txt")
    with pytest.raises(ValueError):
        export_dataset(frame, bad_target, fmt="txt")

tests/test_piecewise.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,19 @@ def test_piecewise_invalid_lengths():
1818
gen_piecewise_exponential(
1919
n=5, breakpoints=[1.0, 2.0], hazard_rates=[0.5], seed=42
2020
)
21+
22+
def test_piecewise_invalid_hazard_and_breakpoints():
    """Each invalid argument set below is expected to raise ``ValueError``."""
    invalid_inputs = (
        # Breakpoints listed in decreasing order.
        {"breakpoints": [2.0, 1.0], "hazard_rates": [0.5, 1.0, 1.5]},
        # A negative hazard rate.
        {"breakpoints": [1.0], "hazard_rates": [0.5, -1.0]},
    )
    for kwargs in invalid_inputs:
        with pytest.raises(ValueError):
            gen_piecewise_exponential(n=5, seed=42, **kwargs)

tests/test_piecewise_functions.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import numpy as np
2+
from gen_surv.piecewise import piecewise_hazard_function, piecewise_survival_function
3+
4+
5+
def test_piecewise_hazard_function_scalar_and_array():
    """Check the hazard at three time points, as scalars and as one array."""
    cuts = [1.0, 2.0]
    rates = [0.5, 1.0, 1.5]
    # (time, expected hazard) pairs: before, between and after the breakpoints.
    cases = [(0.5, 0.5), (1.5, 1.0), (3.0, 1.5)]

    # Scalar inputs, one at a time.
    for t, expected in cases:
        assert piecewise_hazard_function(t, cuts, rates) == expected

    # The same points passed as a single array.
    times = np.array([t for t, _ in cases])
    np.testing.assert_allclose(
        piecewise_hazard_function(times, cuts, rates),
        np.array([h for _, h in cases]),
    )
18+
19+
20+
def test_piecewise_hazard_function_negative_time():
    """Negative times are expected to yield a hazard of zero."""
    cuts = [1.0, 2.0]
    rates = [0.5, 1.0, 1.5]

    scalar_hazard = piecewise_hazard_function(-1.0, cuts, rates)
    assert scalar_hazard == 0

    array_hazard = piecewise_hazard_function(np.array([-0.5, -2.0]), cuts, rates)
    np.testing.assert_array_equal(array_hazard, np.array([0.0, 0.0]))
29+
30+
31+
def test_piecewise_survival_function():
    """Compare survival values with exp(-H(t)) for hand-computed H(t)."""
    cuts = [1.0, 2.0]
    rates = [0.5, 1.0, 1.5]
    query_times = np.array([0.0, 0.5, 1.5, 3.0])
    # Cumulative hazards the test expects at the query times:
    # 0, 0.5 * 0.5, 0.5 + 1.0 * 0.5, 0.5 + 1.0 + 1.5 * 1.0
    cumulative_hazard = np.array([0.0, 0.25, 1.0, 3.0])

    survival = piecewise_survival_function(query_times, cuts, rates)
    np.testing.assert_allclose(survival, np.exp(-cumulative_hazard))
41+
42+
43+
def test_piecewise_survival_function_scalar_and_negative():
    """A scalar time should give a float; a negative time should give 1."""
    cuts = [1.0, 2.0]
    rates = [0.5, 1.0, 1.5]

    # Scalar in, float out, with the expected value exp(-1).
    survival = piecewise_survival_function(1.5, cuts, rates)
    assert isinstance(survival, float)
    assert np.isclose(survival, np.exp(-1.0))

    # Before time zero the survival probability is expected to be 1.
    assert piecewise_survival_function(-2.0, cuts, rates) == 1

tests/test_summary_extra.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import pandas as pd
2+
import pytest
3+
from gen_surv.summary import check_survival_data_quality, compare_survival_datasets
4+
from gen_surv import generate
5+
6+
7+
def test_check_survival_data_quality_fix_issues():
    """With ``fix_issues=True`` the reported modification counts must match."""
    raw = pd.DataFrame(
        {
            "time": [1.0, -0.5, None, 1.0],
            "status": [1, 2, 0, 1],
            "id": [1, 2, 3, 1],
        }
    )
    cleaned, report = check_survival_data_quality(
        raw, id_col="id", max_time=2.0, fix_issues=True
    )
    changes = report["modifications"]
    assert changes["rows_dropped"] == 2
    assert changes["values_fixed"] == 1
    assert len(cleaned) == 2
24+
25+
26+
def test_check_survival_data_quality_no_fix():
    """With ``fix_issues=False`` issues are reported and data comes back as-is."""
    raw = pd.DataFrame({"time": [-1.0, 2.0], "status": [3, 1]})
    returned, report = check_survival_data_quality(
        raw, max_time=1.0, fix_issues=False
    )
    # The frame must come back unmodified.
    pd.testing.assert_frame_equal(raw, returned)
    # Exactly one occurrence of each invalid-value category is expected.
    invalid = report["invalid_values"]
    for key in ("negative_time", "excessive_time", "invalid_status"):
        assert invalid[key] == 1
35+
36+
37+
def test_compare_survival_datasets_basic():
    """The comparison should have a column per dataset and an ``n_subjects`` row."""
    # Settings shared by both datasets; only ``beta`` differs between them.
    shared = dict(
        model="cphm",
        n=5,
        model_cens="uniform",
        cens_par=1.0,
        covariate_range=1.0,
    )
    datasets = {
        "A": generate(beta=0.5, **shared),
        "B": generate(beta=1.0, **shared),
    }
    comparison = compare_survival_datasets(datasets)
    assert {"A", "B"}.issubset(comparison.columns)
    assert "n_subjects" in comparison.index
43+
44+
45+
def test_compare_survival_datasets_with_covariates_and_empty_error():
    """A covariate row is expected when requested; an empty mapping must raise."""
    dataset = generate(
        model="cphm",
        n=3,
        model_cens="uniform",
        cens_par=1.0,
        beta=0.5,
        covariate_range=1.0,
    )
    comparison = compare_survival_datasets({"only": dataset}, covariate_cols=["X0"])
    assert "only" in comparison.columns
    assert "X0_mean" in comparison.index

    # Comparing nothing at all is expected to raise ValueError.
    with pytest.raises(ValueError):
        compare_survival_datasets({})

tests/test_summary_more.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import pandas as pd
2+
import pytest
3+
from gen_surv.summary import (
4+
summarize_survival_dataset,
5+
check_survival_data_quality,
6+
_print_summary,
7+
)
8+
9+
10+
def test_summarize_survival_dataset_errors():
    """Each missing-column scenario is expected to raise ``ValueError``."""
    frame = pd.DataFrame({"time": [1, 2], "status": [1, 0]})

    # No time column present.
    without_time = frame.drop(columns=["time"])
    with pytest.raises(ValueError):
        summarize_survival_dataset(without_time)

    # An ID column is named but absent from the frame.
    with pytest.raises(ValueError):
        summarize_survival_dataset(frame, id_col="id")

    # A requested covariate column does not exist.
    with pytest.raises(ValueError):
        summarize_survival_dataset(frame, covariate_cols=["bad"])
21+
22+
23+
def test_summarize_survival_dataset_verbose_output(capsys):
    """Captured stdout should contain the summary header and covariate labels."""
    frame = pd.DataFrame(
        {
            "time": [1.0, 2.0, 3.0],
            "status": [1, 0, 1],
            "id": [1, 2, 3],
            "age": [30, 40, 50],
            "group": ["A", "B", "A"],
        }
    )
    summary = summarize_survival_dataset(
        frame,
        id_col="id",
        covariate_cols=["age", "group"],
    )
    _print_summary(summary, "time", "status", "id", ["age", "group"])

    # Everything written to stdout so far in this test.
    printed = capsys.readouterr().out
    for fragment in ("SURVIVAL DATASET SUMMARY", "age:", "Categorical"):
        assert fragment in printed
41+
42+
43+
def test_check_survival_data_quality_duplicates_and_fix():
    """Duplicates are counted without fixing; fixing is expected to drop rows."""
    raw = pd.DataFrame(
        {
            "time": [1.0, -1.0, 2.0, 1.0],
            "status": [1, 1, 0, 1],
            "id": [1, 1, 2, 1],
        }
    )

    # Report-only pass: only the issue report is inspected.
    _, report = check_survival_data_quality(raw, id_col="id", fix_issues=False)
    duplicates = report["duplicates"]
    assert duplicates["duplicate_rows"] == 1
    assert duplicates["duplicate_ids"] == 2

    # Fixing pass: the result must be shorter than the input.
    repaired, fix_report = check_survival_data_quality(
        raw, id_col="id", max_time=2.0, fix_issues=True
    )
    assert len(repaired) < len(raw)
    assert fix_report["modifications"]["rows_dropped"] > 0

tests/test_validate.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,16 @@ def test_validate_gen_aft_log_normal_inputs_valid():
7171
def test_validate_dg_biv_inputs_valid_weibull():
    """A valid Weibull parameter set must be accepted without raising."""
    weibull_params = [1.0, 1.0, 1.0, 1.0]
    v.validate_dg_biv_inputs(5, "weibull", 0.1, weibull_params)
74+
75+
76+
def test_validate_gen_aft_weibull_inputs_and_log_logistic():
    """Both AFT validators are expected to raise ``ValueError`` on these inputs."""
    failing_calls = (
        # Weibull validator with 0 as its first argument.
        (v.validate_gen_aft_weibull_inputs, (0, [0.1], 1.0, 1.0, "uniform", 1.0)),
        # Log-logistic validator with -1.0 as its third argument.
        (
            v.validate_gen_aft_log_logistic_inputs,
            (1, [0.1], -1.0, 1.0, "uniform", 1.0),
        ),
    )
    for validator, args in failing_calls:
        with pytest.raises(ValueError):
            validator(*args)
81+
82+
83+
def test_validate_competing_risks_inputs():
    """One argument set must raise ``ValueError``; the other must pass."""
    rejected_args = (1, 2, [0.1], None, "uniform", 1.0)
    accepted_args = (1, 1, [0.5], [[0.1]], "uniform", 0.5)

    with pytest.raises(ValueError):
        v.validate_competing_risks_inputs(*rejected_args)

    # Called outside the ``raises`` block: this call must complete without error.
    v.validate_competing_risks_inputs(*accepted_args)

0 commit comments

Comments
 (0)