Skip to content

Commit 9fa0e5d

Browse files
feat: share covariate utilities (#73)
1 parent 1717ef9 commit 9fa0e5d

File tree

3 files changed

+147
-96
lines changed

3 files changed

+147
-96
lines changed

gen_surv/_covariates.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
"""Utilities for generating covariate matrices with validation."""
2+
3+
from typing import Literal, cast
4+
5+
import numpy as np
6+
from numpy.typing import NDArray
7+
8+
from ._validation import ParameterError, ensure_positive
9+
10+
# Parameter mapping for a covariate distribution, keyed by parameter name.
# The helpers in this module read the keys "mean"/"std" (normal),
# "low"/"high" (uniform) and "p" (binary), and treat every value they read
# as a scalar float; the tuple alternative is allowed by the type only.
_CovParams = dict[str, float | tuple[float, float]]
11+
12+
13+
def set_covariate_params(
14+
covariate_dist: Literal["normal", "uniform", "binary"],
15+
covariate_params: _CovParams | None,
16+
) -> _CovParams:
17+
"""Return covariate distribution parameters with defaults filled in."""
18+
if covariate_params is not None:
19+
return covariate_params
20+
if covariate_dist == "normal":
21+
return {"mean": 0.0, "std": 1.0}
22+
if covariate_dist == "uniform":
23+
return {"low": 0.0, "high": 1.0}
24+
if covariate_dist == "binary":
25+
return {"p": 0.5}
26+
raise ParameterError(
27+
"covariate_dist",
28+
covariate_dist,
29+
"unsupported covariate distribution; choose from 'normal', 'uniform', or 'binary'",
30+
)
31+
32+
33+
def generate_covariates(
    n: int,
    n_covariates: int,
    covariate_dist: Literal["normal", "uniform", "binary"],
    covariate_params: _CovParams,
) -> NDArray[np.float64]:
    """Draw an ``(n, n_covariates)`` covariate matrix from ``covariate_dist``.

    Samples come from NumPy's global ``np.random`` state.

    Parameters
    ----------
    n : int
        Number of rows. Not validated here; it is passed to NumPy as part
        of the output shape.
    n_covariates : int
        Number of columns. Not validated here either.
    covariate_dist : {"normal", "uniform", "binary"}
        Distribution to sample from.
    covariate_params : dict
        Distribution parameters. Keys read, with the fallback used when a
        key is absent: ``"mean"`` (0.0) and ``"std"`` (1.0) for normal,
        ``"low"`` (0.0) and ``"high"`` (1.0) for uniform, ``"p"`` (0.5)
        for binary. Other keys are ignored.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(n, n_covariates)``; binary draws are cast
        to float.

    Raises
    ------
    ParameterError
        If ``low < high`` does not hold for uniform (including NaN bounds),
        if ``p`` is outside ``[0, 1]`` for binary, or if ``covariate_dist``
        is unsupported. ``std`` is checked with ``ensure_positive``.
    """
    shape = (n, n_covariates)

    if covariate_dist == "normal":
        std = cast(float, covariate_params.get("std", 1.0))
        ensure_positive(std, "covariate_params['std']")
        mean = cast(float, covariate_params.get("mean", 0.0))
        return np.random.normal(mean, std, size=shape)

    if covariate_dist == "uniform":
        low = cast(float, covariate_params.get("low", 0.0))
        high = cast(float, covariate_params.get("high", 1.0))
        # Phrased as ``not low < high`` rather than ``high <= low`` so that
        # NaN bounds fail validation here instead of reaching NumPy; this
        # mirrors the NaN-rejecting form of the ``p`` check below.
        if not low < high:
            raise ParameterError(
                "covariate_params['high']",
                high,
                "must be greater than 'low'",
            )
        return np.random.uniform(low, high, size=shape)

    if covariate_dist == "binary":
        p = cast(float, covariate_params.get("p", 0.5))
        if not 0 <= p <= 1:
            raise ParameterError(
                "covariate_params['p']",
                p,
                "must be between 0 and 1",
            )
        return np.random.binomial(1, p, size=shape).astype(float)

    raise ParameterError(
        "covariate_dist",
        covariate_dist,
        "unsupported covariate distribution; choose from 'normal', 'uniform', or 'binary'",
    )

gen_surv/mixture.py

Lines changed: 32 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -11,36 +11,22 @@
1111
import pandas as pd
1212
from numpy.typing import NDArray
1313

14-
_TAIL_FRACTION = 0.1
15-
_SMOOTH_MIN_TAIL = 3
14+
_TAIL_FRACTION: float = 0.1
15+
_SMOOTH_MIN_TAIL: int = 3
1616

17+
from ._covariates import generate_covariates, set_covariate_params
1718
from ._validation import (
1819
LengthError,
1920
ParameterError,
2021
ensure_censoring_model,
2122
ensure_in_choices,
23+
ensure_numeric_sequence,
2224
ensure_positive,
25+
ensure_positive_int,
2326
)
2427
from .censoring import rexpocens, runifcens
2528

2629

27-
def _set_covariate_params(
28-
covariate_dist: str,
29-
covariate_params: dict[str, float | tuple[float, float]] | None,
30-
) -> dict[str, float | tuple[float, float]]:
31-
if covariate_params is not None:
32-
return covariate_params
33-
if covariate_dist == "normal":
34-
return {"mean": 0.0, "std": 1.0}
35-
if covariate_dist == "uniform":
36-
return {"low": 0.0, "high": 1.0}
37-
if covariate_dist == "binary":
38-
return {"p": 0.5}
39-
raise ParameterError(
40-
"covariate_dist", covariate_dist, "must be one of {'normal','uniform','binary'}"
41-
)
42-
43-
4430
def _prepare_betas(
4531
betas_survival: list[float] | None,
4632
betas_cure: list[float] | None,
@@ -49,49 +35,26 @@ def _prepare_betas(
4935
if betas_survival is None:
5036
betas_survival_arr = np.random.normal(0, 0.5, size=n_covariates)
5137
else:
38+
ensure_numeric_sequence(betas_survival, "betas_survival")
5239
betas_survival_arr = np.asarray(betas_survival, dtype=float)
5340
n_covariates = len(betas_survival_arr)
5441

5542
if betas_cure is None:
5643
betas_cure_arr = np.random.normal(0, 0.5, size=n_covariates)
5744
else:
45+
ensure_numeric_sequence(betas_cure, "betas_cure")
5846
betas_cure_arr = np.asarray(betas_cure, dtype=float)
5947
if len(betas_cure_arr) != n_covariates:
6048
raise LengthError("betas_cure", len(betas_cure_arr), n_covariates)
6149

6250
return betas_survival_arr, betas_cure_arr, n_covariates
6351

6452

65-
def _generate_covariates(
66-
n: int,
67-
n_covariates: int,
68-
covariate_dist: str,
69-
covariate_params: dict[str, float | tuple[float, float]],
70-
) -> NDArray[np.float64]:
71-
if covariate_dist == "normal":
72-
return np.random.normal(
73-
covariate_params.get("mean", 0.0),
74-
covariate_params.get("std", 1.0),
75-
size=(n, n_covariates),
76-
)
77-
if covariate_dist == "uniform":
78-
return np.random.uniform(
79-
covariate_params.get("low", 0.0),
80-
covariate_params.get("high", 1.0),
81-
size=(n, n_covariates),
82-
)
83-
if covariate_dist == "binary":
84-
return np.random.binomial(
85-
1, covariate_params.get("p", 0.5), size=(n, n_covariates)
86-
).astype(float)
87-
raise ParameterError(
88-
"covariate_dist", covariate_dist, "must be one of {'normal','uniform','binary'}"
89-
)
90-
91-
9253
def _cure_status(
9354
lp_cure: NDArray[np.float64], cure_fraction: float
9455
) -> NDArray[np.int64]:
56+
if not 0 < cure_fraction < 1:
57+
raise ParameterError("cure_fraction", cure_fraction, "must be between 0 and 1")
9558
cure_probs = 1 / (
9659
1 + np.exp(-(np.log(cure_fraction / (1 - cure_fraction)) + lp_cure))
9760
)
@@ -104,6 +67,9 @@ def _survival_times(
10467
baseline_hazard: float,
10568
max_time: float | None,
10669
) -> NDArray[np.float64]:
70+
ensure_positive(baseline_hazard, "baseline_hazard")
71+
if max_time is not None:
72+
ensure_positive(max_time, "max_time")
10773
n = cured.size
10874
times = np.zeros(n, dtype=float)
10975
non_cured = cured == 0
@@ -122,6 +88,10 @@ def _apply_censoring(
12288
cens_par: float,
12389
max_time: float | None,
12490
) -> tuple[NDArray[np.float64], NDArray[np.int64]]:
91+
ensure_censoring_model(model_cens)
92+
ensure_positive(cens_par, "cens_par")
93+
if max_time is not None:
94+
ensure_positive(max_time, "max_time")
12595
rfunc = runifcens if model_cens == "uniform" else rexpocens
12696
cens_times = rfunc(len(survival_times), cens_par)
12797
observed = np.minimum(survival_times, cens_times)
@@ -213,16 +183,21 @@ def gen_mixture_cure(
213183
if seed is not None:
214184
np.random.seed(seed)
215185

186+
ensure_positive_int(n, "n")
187+
ensure_positive_int(n_covariates, "n_covariates")
188+
ensure_positive(baseline_hazard, "baseline_hazard")
189+
ensure_positive(cens_par, "cens_par")
190+
if max_time is not None:
191+
ensure_positive(max_time, "max_time")
216192
if not 0 <= cure_fraction <= 1:
217193
raise ParameterError("cure_fraction", cure_fraction, "must be between 0 and 1")
218-
ensure_positive(baseline_hazard, "baseline_hazard")
219194

220195
ensure_in_choices(covariate_dist, "covariate_dist", {"normal", "uniform", "binary"})
221-
covariate_params = _set_covariate_params(covariate_dist, covariate_params)
196+
covariate_params = set_covariate_params(covariate_dist, covariate_params)
222197
betas_survival_arr, betas_cure_arr, n_covariates = _prepare_betas(
223198
betas_survival, betas_cure, n_covariates
224199
)
225-
X = _generate_covariates(n, n_covariates, covariate_dist, covariate_params)
200+
X = generate_covariates(n, n_covariates, covariate_dist, covariate_params)
226201
lp_survival = X @ betas_survival_arr
227202
lp_cure = X @ betas_cure_arr
228203
cured = _cure_status(lp_cure, cure_fraction)
@@ -274,6 +249,14 @@ def cure_fraction_estimate(
274249
based on the plateau of the survival curve. It may not be accurate for
275250
small sample sizes or heavy censoring.
276251
"""
252+
if time_col not in data.columns or status_col not in data.columns:
253+
missing = [c for c in (time_col, status_col) if c not in data.columns]
254+
raise ParameterError(
255+
"data",
256+
data.columns.tolist(),
257+
f"missing required column(s): {', '.join(missing)}",
258+
)
259+
ensure_positive(bandwidth, "bandwidth")
277260
# Sort data by time
278261
sorted_data = data.sort_values(by=time_col).copy()
279262

gen_surv/piecewise.py

Lines changed: 47 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -5,32 +5,52 @@
55
exponential distributions with time-dependent hazards.
66
"""
77

8-
from typing import Dict, List, Literal, Optional, Tuple, Union
8+
from typing import Literal
99

1010
import numpy as np
1111
import pandas as pd
12+
from numpy.typing import NDArray
1213

14+
from ._covariates import generate_covariates, set_covariate_params
1315
from ._validation import (
1416
ParameterError,
1517
ensure_censoring_model,
1618
ensure_in_choices,
19+
ensure_numeric_sequence,
20+
ensure_positive,
21+
ensure_positive_int,
1722
ensure_positive_sequence,
1823
ensure_sequence_length,
1924
)
2025
from .censoring import rexpocens, runifcens
2126

2227

28+
def _validate_piecewise_params(
    breakpoints: list[float], hazard_rates: list[float]
) -> None:
    """Check that breakpoints and hazard rates describe a valid piecewise model.

    Checks run in this order: ``hazard_rates`` must have exactly one more
    entry than ``breakpoints``; both sequences are passed through
    ``ensure_positive_sequence``; finally ``breakpoints`` must be strictly
    increasing, otherwise a ``ParameterError`` is raised.
    """
    expected_rates = len(breakpoints) + 1
    ensure_sequence_length(hazard_rates, expected_rates, "hazard_rates")

    for values, label in (
        (breakpoints, "breakpoints"),
        (hazard_rates, "hazard_rates"),
    ):
        ensure_positive_sequence(values, label)

    # A non-positive step between neighbours means a repeated or
    # out-of-order breakpoint.
    steps = np.diff(breakpoints)
    if not np.any(steps <= 0):
        return
    raise ParameterError(
        "breakpoints",
        breakpoints,
        "must be a strictly increasing sequence",
    )
41+
42+
2343
def gen_piecewise_exponential(
2444
n: int,
25-
breakpoints: List[float],
26-
hazard_rates: List[float],
27-
betas: Optional[Union[List[float], np.ndarray]] = None,
45+
breakpoints: list[float],
46+
hazard_rates: list[float],
47+
betas: list[float] | NDArray[np.float64] | None = None,
2848
n_covariates: int = 2,
2949
covariate_dist: Literal["normal", "uniform", "binary"] = "normal",
30-
covariate_params: Optional[Dict[str, Union[float, Tuple[float, float]]]] = None,
50+
covariate_params: dict[str, float | tuple[float, float]] | None = None,
3151
model_cens: Literal["uniform", "exponential"] = "uniform",
3252
cens_par: float = 5.0,
33-
seed: Optional[int] = None,
53+
seed: int | None = None,
3454
) -> pd.DataFrame:
3555
"""
3656
Generate survival data using a piecewise exponential distribution.
@@ -88,55 +108,27 @@ def gen_piecewise_exponential(
88108
if seed is not None:
89109
np.random.seed(seed)
90110

111+
ensure_positive_int(n, "n")
112+
ensure_positive_int(n_covariates, "n_covariates")
113+
ensure_positive(cens_par, "cens_par")
114+
91115
# Validate inputs
92-
ensure_sequence_length(hazard_rates, len(breakpoints) + 1, "hazard_rates")
93-
ensure_positive_sequence(breakpoints, "breakpoints")
94-
ensure_positive_sequence(hazard_rates, "hazard_rates")
95-
if np.any(np.diff(breakpoints) <= 0):
96-
raise ParameterError("breakpoints", breakpoints, "must be in ascending order")
116+
_validate_piecewise_params(breakpoints, hazard_rates)
97117

98118
ensure_censoring_model(model_cens)
99119
ensure_in_choices(covariate_dist, "covariate_dist", {"normal", "uniform", "binary"})
100-
101-
# Set default covariate parameters if not provided
102-
if covariate_params is None:
103-
if covariate_dist == "normal":
104-
covariate_params = {"mean": 0.0, "std": 1.0}
105-
elif covariate_dist == "uniform":
106-
covariate_params = {"low": 0.0, "high": 1.0}
107-
elif covariate_dist == "binary":
108-
covariate_params = {"p": 0.5}
120+
covariate_params = set_covariate_params(covariate_dist, covariate_params)
109121

110122
# Set default betas if not provided
111123
if betas is None:
112124
betas = np.random.normal(0, 0.5, size=n_covariates)
113125
else:
114-
betas = np.array(betas)
126+
ensure_numeric_sequence(betas, "betas")
127+
betas = np.array(betas, dtype=float)
115128
n_covariates = len(betas)
116129

117130
# Generate covariates
118-
if covariate_dist == "normal":
119-
X = np.random.normal(
120-
covariate_params.get("mean", 0.0),
121-
covariate_params.get("std", 1.0),
122-
size=(n, n_covariates),
123-
)
124-
elif covariate_dist == "uniform":
125-
X = np.random.uniform(
126-
covariate_params.get("low", 0.0),
127-
covariate_params.get("high", 1.0),
128-
size=(n, n_covariates),
129-
)
130-
elif covariate_dist == "binary":
131-
X = np.random.binomial(
132-
1, covariate_params.get("p", 0.5), size=(n, n_covariates)
133-
)
134-
else: # pragma: no cover - validated above
135-
raise ParameterError(
136-
"covariate_dist",
137-
covariate_dist,
138-
"must be one of {'normal', 'uniform', 'binary'}",
139-
)
131+
X = generate_covariates(n, n_covariates, covariate_dist, covariate_params)
140132

141133
# Calculate linear predictor
142134
linear_predictor = X @ betas
@@ -209,8 +201,10 @@ def gen_piecewise_exponential(
209201

210202

211203
def piecewise_hazard_function(
212-
t: Union[float, np.ndarray], breakpoints: List[float], hazard_rates: List[float]
213-
) -> Union[float, np.ndarray]:
204+
t: float | NDArray[np.float64],
205+
breakpoints: list[float],
206+
hazard_rates: list[float],
207+
) -> float | NDArray[np.float64]:
214208
"""
215209
Calculate the hazard function value at time t for a piecewise exponential distribution.
216210
@@ -228,6 +222,8 @@ def piecewise_hazard_function(
228222
float or array
229223
Hazard function value(s) at time t.
230224
"""
225+
_validate_piecewise_params(breakpoints, hazard_rates)
226+
231227
# Convert scalar input to array for consistent processing
232228
scalar_input = np.isscalar(t)
233229
t_array = np.atleast_1d(t)
@@ -253,8 +249,10 @@ def piecewise_hazard_function(
253249

254250

255251
def piecewise_survival_function(
256-
t: Union[float, np.ndarray], breakpoints: List[float], hazard_rates: List[float]
257-
) -> Union[float, np.ndarray]:
252+
t: float | NDArray[np.float64],
253+
breakpoints: list[float],
254+
hazard_rates: list[float],
255+
) -> float | NDArray[np.float64]:
258256
"""
259257
Calculate the survival function at time t for a piecewise exponential distribution.
260258
@@ -272,6 +270,8 @@ def piecewise_survival_function(
272270
float or array
273271
Survival function value(s) at time t.
274272
"""
273+
_validate_piecewise_params(breakpoints, hazard_rates)
274+
275275
# Convert scalar input to array for consistent processing
276276
scalar_input = np.isscalar(t)
277277
t_array = np.atleast_1d(t)

0 commit comments

Comments
 (0)