55exponential distributions with time-dependent hazards.
66"""
77
8- from typing import Dict , List , Literal , Optional , Tuple , Union
8+ from typing import Literal
99
1010import numpy as np
1111import pandas as pd
12+ from numpy .typing import NDArray
1213
14+ from ._covariates import generate_covariates , set_covariate_params
1315from ._validation import (
1416 ParameterError ,
1517 ensure_censoring_model ,
1618 ensure_in_choices ,
19+ ensure_numeric_sequence ,
20+ ensure_positive ,
21+ ensure_positive_int ,
1722 ensure_positive_sequence ,
1823 ensure_sequence_length ,
1924)
2025from .censoring import rexpocens , runifcens
2126
2227
28+ def _validate_piecewise_params (
29+ breakpoints : list [float ], hazard_rates : list [float ]
30+ ) -> None :
31+ """Validate breakpoint and hazard rate sequences."""
32+ ensure_sequence_length (hazard_rates , len (breakpoints ) + 1 , "hazard_rates" )
33+ ensure_positive_sequence (breakpoints , "breakpoints" )
34+ ensure_positive_sequence (hazard_rates , "hazard_rates" )
35+ if np .any (np .diff (breakpoints ) <= 0 ):
36+ raise ParameterError (
37+ "breakpoints" ,
38+ breakpoints ,
39+ "must be a strictly increasing sequence" ,
40+ )
41+
42+
2343def gen_piecewise_exponential (
2444 n : int ,
25- breakpoints : List [float ],
26- hazard_rates : List [float ],
27- betas : Optional [ Union [ List [ float ], np .ndarray ]] = None ,
45+ breakpoints : list [float ],
46+ hazard_rates : list [float ],
47+ betas : list [ float ] | NDArray [ np .float64 ] | None = None ,
2848 n_covariates : int = 2 ,
2949 covariate_dist : Literal ["normal" , "uniform" , "binary" ] = "normal" ,
30- covariate_params : Optional [ Dict [ str , Union [ float , Tuple [float , float ]]]] = None ,
50+ covariate_params : dict [ str , float | tuple [float , float ]] | None = None ,
3151 model_cens : Literal ["uniform" , "exponential" ] = "uniform" ,
3252 cens_par : float = 5.0 ,
33- seed : Optional [ int ] = None ,
53+ seed : int | None = None ,
3454) -> pd .DataFrame :
3555 """
3656 Generate survival data using a piecewise exponential distribution.
@@ -88,55 +108,27 @@ def gen_piecewise_exponential(
88108 if seed is not None :
89109 np .random .seed (seed )
90110
111+ ensure_positive_int (n , "n" )
112+ ensure_positive_int (n_covariates , "n_covariates" )
113+ ensure_positive (cens_par , "cens_par" )
114+
91115 # Validate inputs
92- ensure_sequence_length (hazard_rates , len (breakpoints ) + 1 , "hazard_rates" )
93- ensure_positive_sequence (breakpoints , "breakpoints" )
94- ensure_positive_sequence (hazard_rates , "hazard_rates" )
95- if np .any (np .diff (breakpoints ) <= 0 ):
96- raise ParameterError ("breakpoints" , breakpoints , "must be in ascending order" )
116+ _validate_piecewise_params (breakpoints , hazard_rates )
97117
98118 ensure_censoring_model (model_cens )
99119 ensure_in_choices (covariate_dist , "covariate_dist" , {"normal" , "uniform" , "binary" })
100-
101- # Set default covariate parameters if not provided
102- if covariate_params is None :
103- if covariate_dist == "normal" :
104- covariate_params = {"mean" : 0.0 , "std" : 1.0 }
105- elif covariate_dist == "uniform" :
106- covariate_params = {"low" : 0.0 , "high" : 1.0 }
107- elif covariate_dist == "binary" :
108- covariate_params = {"p" : 0.5 }
120+ covariate_params = set_covariate_params (covariate_dist , covariate_params )
109121
110122 # Set default betas if not provided
111123 if betas is None :
112124 betas = np .random .normal (0 , 0.5 , size = n_covariates )
113125 else :
114- betas = np .array (betas )
126+ ensure_numeric_sequence (betas , "betas" )
127+ betas = np .array (betas , dtype = float )
115128 n_covariates = len (betas )
116129
117130 # Generate covariates
118- if covariate_dist == "normal" :
119- X = np .random .normal (
120- covariate_params .get ("mean" , 0.0 ),
121- covariate_params .get ("std" , 1.0 ),
122- size = (n , n_covariates ),
123- )
124- elif covariate_dist == "uniform" :
125- X = np .random .uniform (
126- covariate_params .get ("low" , 0.0 ),
127- covariate_params .get ("high" , 1.0 ),
128- size = (n , n_covariates ),
129- )
130- elif covariate_dist == "binary" :
131- X = np .random .binomial (
132- 1 , covariate_params .get ("p" , 0.5 ), size = (n , n_covariates )
133- )
134- else : # pragma: no cover - validated above
135- raise ParameterError (
136- "covariate_dist" ,
137- covariate_dist ,
138- "must be one of {'normal', 'uniform', 'binary'}" ,
139- )
131+ X = generate_covariates (n , n_covariates , covariate_dist , covariate_params )
140132
141133 # Calculate linear predictor
142134 linear_predictor = X @ betas
@@ -209,8 +201,10 @@ def gen_piecewise_exponential(
209201
210202
211203def piecewise_hazard_function (
212- t : Union [float , np .ndarray ], breakpoints : List [float ], hazard_rates : List [float ]
213- ) -> Union [float , np .ndarray ]:
204+ t : float | NDArray [np .float64 ],
205+ breakpoints : list [float ],
206+ hazard_rates : list [float ],
207+ ) -> float | NDArray [np .float64 ]:
214208 """
215209 Calculate the hazard function value at time t for a piecewise exponential distribution.
216210
@@ -228,6 +222,8 @@ def piecewise_hazard_function(
228222 float or array
229223 Hazard function value(s) at time t.
230224 """
225+ _validate_piecewise_params (breakpoints , hazard_rates )
226+
231227 # Convert scalar input to array for consistent processing
232228 scalar_input = np .isscalar (t )
233229 t_array = np .atleast_1d (t )
@@ -253,8 +249,10 @@ def piecewise_hazard_function(
253249
254250
255251def piecewise_survival_function (
256- t : Union [float , np .ndarray ], breakpoints : List [float ], hazard_rates : List [float ]
257- ) -> Union [float , np .ndarray ]:
252+ t : float | NDArray [np .float64 ],
253+ breakpoints : list [float ],
254+ hazard_rates : list [float ],
255+ ) -> float | NDArray [np .float64 ]:
258256 """
259257 Calculate the survival function at time t for a piecewise exponential distribution.
260258
@@ -272,6 +270,8 @@ def piecewise_survival_function(
272270 float or array
273271 Survival function value(s) at time t.
274272 """
273+ _validate_piecewise_params (breakpoints , hazard_rates )
274+
275275 # Convert scalar input to array for consistent processing
276276 scalar_input = np .isscalar (t )
277277 t_array = np .atleast_1d (t )
0 commit comments