Skip to content

Commit 42d59df

Browse files
fix: tighten sequence checks and streamline competing-risk helpers (#81)
1 parent b0633d3 commit 42d59df

34 files changed

+994
-871
lines changed

.github/workflows/ci.yml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,20 @@ jobs:
4141
- name: Install dependencies
4242
run: poetry install --with dev --no-interaction --no-root
4343

44+
- name: Check runtime dependencies
45+
run: |
46+
poetry run python - <<'PY'
47+
import importlib, sys
48+
missing = [m for m in ("pandas",) if importlib.util.find_spec(m) is None]
49+
if missing:
50+
print("Missing dependencies:", ", ".join(missing))
51+
sys.exit(1)
52+
PY
53+
54+
- name: Verify version matches tag
55+
if: startsWith(github.ref, 'refs/tags/')
56+
run: python scripts/check_version_match.py
57+
4458
- name: Run tests with coverage
4559
run: poetry run pytest --cov=gen_surv --cov-report=xml --cov-report=term
4660

benchmarks/test_validation_benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
pytest.importorskip("pytest_benchmark")
55

6-
from gen_surv._validation import ensure_positive_sequence
6+
from gen_surv.validation import ensure_positive_sequence
77

88

99
def test_positive_sequence_benchmark(benchmark):

docs/source/api/index.md

Lines changed: 51 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -14,85 +14,76 @@ or `pip install scikit-survival` to enable this functionality.
1414

1515
## Core Interface
1616

17-
```{eval-rst}
18-
.. automodule:: gen_surv.interface
19-
:members:
20-
:undoc-members:
21-
:show-inheritance:
22-
```
17+
::: gen_surv.interface
18+
options:
19+
members: true
20+
undoc-members: true
21+
show-inheritance: true
2322

2423
## Model Generators
2524

2625
### Cox Proportional Hazards Model
27-
```{eval-rst}
28-
.. automodule:: gen_surv.cphm
29-
:members:
30-
:undoc-members:
31-
:show-inheritance:
32-
```
26+
::: gen_surv.cphm
27+
options:
28+
members: true
29+
undoc-members: true
30+
show-inheritance: true
3331

3432
### Accelerated Failure Time Models
35-
```{eval-rst}
36-
.. automodule:: gen_surv.aft
37-
:members:
38-
:undoc-members:
39-
:show-inheritance:
40-
```
33+
::: gen_surv.aft
34+
options:
35+
members: true
36+
undoc-members: true
37+
show-inheritance: true
4138

4239
### Continuous-Time Markov Models
43-
```{eval-rst}
44-
.. automodule:: gen_surv.cmm
45-
:members:
46-
:undoc-members:
47-
:show-inheritance:
48-
```
40+
::: gen_surv.cmm
41+
options:
42+
members: true
43+
undoc-members: true
44+
show-inheritance: true
4945

5046
### Time-Dependent Covariate Models
51-
```{eval-rst}
52-
.. automodule:: gen_surv.tdcm
53-
:members:
54-
:undoc-members:
55-
:show-inheritance:
56-
```
47+
::: gen_surv.tdcm
48+
options:
49+
members: true
50+
undoc-members: true
51+
show-inheritance: true
5752

5853
### Time-Homogeneous Markov Models
59-
```{eval-rst}
60-
.. automodule:: gen_surv.thmm
61-
:members:
62-
:undoc-members:
63-
:show-inheritance:
64-
```
54+
::: gen_surv.thmm
55+
options:
56+
members: true
57+
undoc-members: true
58+
show-inheritance: true
6559

6660
## Utility Functions
6761

6862
### Censoring Functions
69-
```{eval-rst}
70-
.. automodule:: gen_surv.censoring
71-
:members:
72-
:undoc-members:
73-
:show-inheritance:
74-
```
63+
::: gen_surv.censoring
64+
options:
65+
members: true
66+
undoc-members: true
67+
show-inheritance: true
7568

7669
### Bivariate Distributions
77-
```{eval-rst}
78-
.. automodule:: gen_surv.bivariate
79-
:members:
80-
:undoc-members:
81-
:show-inheritance:
82-
```
70+
::: gen_surv.bivariate
71+
options:
72+
members: true
73+
undoc-members: true
74+
show-inheritance: true
8375

8476
### Validation Functions
85-
```{eval-rst}
86-
.. automodule:: gen_surv.validate
87-
:members:
88-
:undoc-members:
89-
:show-inheritance:
90-
```
77+
::: gen_surv.validation
78+
options:
79+
members: true
80+
undoc-members: true
81+
show-inheritance: true
9182

9283
### Command Line Interface
93-
```{eval-rst}
94-
.. automodule:: gen_surv.cli
95-
:members:
96-
:undoc-members:
97-
:show-inheritance:
98-
```
84+
::: gen_surv.cli
85+
options:
86+
members: true
87+
undoc-members: true
88+
show-inheritance: true
89+

docs/source/modules.md

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,48 +8,59 @@ orphan: true
88
options:
99
members: true
1010
undoc-members: true
11+
show-inheritance: true
1112

1213
::: gen_surv.cmm
1314
options:
1415
members: true
1516
undoc-members: true
17+
show-inheritance: true
1618

1719
::: gen_surv.tdcm
1820
options:
1921
members: true
2022
undoc-members: true
23+
show-inheritance: true
2124

2225
::: gen_surv.thmm
2326
options:
2427
members: true
2528
undoc-members: true
29+
show-inheritance: true
2630

2731
::: gen_surv.interface
2832
options:
2933
members: true
3034
undoc-members: true
35+
show-inheritance: true
3136

3237
::: gen_surv.aft
3338
options:
3439
members: true
3540
undoc-members: true
41+
show-inheritance: true
3642

3743
::: gen_surv.bivariate
3844
options:
3945
members: true
4046
undoc-members: true
47+
show-inheritance: true
4148

4249
::: gen_surv.censoring
4350
options:
4451
members: true
4552
undoc-members: true
53+
show-inheritance: true
4654

4755
::: gen_surv.cli
4856
options:
4957
members: true
5058
undoc-members: true
59+
show-inheritance: true
5160

52-
::: gen_surv.validate
61+
::: gen_surv.validation
5362
options:
5463
members: true
5564
undoc-members: true
65+
show-inheritance: true
66+

gen_surv/_covariates.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
"""Utilities for generating covariate matrices with validation."""
22

3-
from typing import Literal, cast
3+
from typing import Literal
44

55
import numpy as np
6+
from numpy.random import Generator
67
from numpy.typing import NDArray
78

8-
from ._validation import ParameterError, ensure_positive
9+
from .validation import ParameterError, ensure_positive
910

10-
_CovParams = dict[str, float | tuple[float, float]]
11+
_CovParams = dict[str, float]
1112

1213

1314
def set_covariate_params(
@@ -30,37 +31,45 @@ def set_covariate_params(
3031
)
3132

3233

34+
def _get_float(params: _CovParams, key: str, default: float) -> float:
35+
val = params.get(key, default)
36+
if not isinstance(val, (int, float)):
37+
raise ParameterError(f"covariate_params['{key}']", val, "must be a number")
38+
return float(val)
39+
40+
3341
def generate_covariates(
3442
n: int,
3543
n_covariates: int,
3644
covariate_dist: Literal["normal", "uniform", "binary"],
3745
covariate_params: _CovParams,
46+
rng: Generator,
3847
) -> NDArray[np.float64]:
3948
"""Generate covariate matrix according to the specified distribution."""
4049
if covariate_dist == "normal":
41-
std = cast(float, covariate_params.get("std", 1.0))
50+
std = _get_float(covariate_params, "std", 1.0)
4251
ensure_positive(std, "covariate_params['std']")
43-
mean = cast(float, covariate_params.get("mean", 0.0))
44-
return np.random.normal(mean, std, size=(n, n_covariates))
52+
mean = _get_float(covariate_params, "mean", 0.0)
53+
return rng.normal(mean, std, size=(n, n_covariates))
4554
if covariate_dist == "uniform":
46-
low = cast(float, covariate_params.get("low", 0.0))
47-
high = cast(float, covariate_params.get("high", 1.0))
55+
low = _get_float(covariate_params, "low", 0.0)
56+
high = _get_float(covariate_params, "high", 1.0)
4857
if high <= low:
4958
raise ParameterError(
5059
"covariate_params['high']",
5160
high,
5261
"must be greater than 'low'",
5362
)
54-
return np.random.uniform(low, high, size=(n, n_covariates))
63+
return rng.uniform(low, high, size=(n, n_covariates))
5564
if covariate_dist == "binary":
56-
p = cast(float, covariate_params.get("p", 0.5))
65+
p = _get_float(covariate_params, "p", 0.5)
5766
if not 0 <= p <= 1:
5867
raise ParameterError(
5968
"covariate_params['p']",
6069
p,
6170
"must be between 0 and 1",
6271
)
63-
return np.random.binomial(1, p, size=(n, n_covariates)).astype(float)
72+
return rng.binomial(1, p, size=(n, n_covariates)).astype(float)
6473
raise ParameterError(
6574
"covariate_dist",
6675
covariate_dist,

0 commit comments

Comments
 (0)