Skip to content

Commit 7f7c28a

Browse files
committed
feat: Add competing risks models and enhance data visualization
1 parent 47e4d7f commit 7f7c28a

File tree

6 files changed

+606
-4
lines changed

6 files changed

+606
-4
lines changed

gen_surv/competing_risks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import numpy as np
99
import pandas as pd
10-
from typing import Dict, List, Optional, Tuple, Union, Literal
10+
from typing import Dict, List, Optional, Tuple, Union, Literal, Any
1111

1212

1313
def gen_competing_risks(

gen_surv/interface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
>>> df = generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0)
77
"""
88

9-
from typing import Any, Dict, Literal, Optional, Union, List, Tuple, cast
9+
from typing import Any, Literal
1010
import pandas as pd
1111

1212
from gen_surv.cphm import gen_cphm

gen_surv/mixture.py

Lines changed: 281 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,281 @@
1+
"""
2+
Mixture Cure Models for survival data simulation.
3+
4+
This module provides functions to generate survival data with a cure fraction,
5+
i.e., a proportion of subjects who are immune to the event of interest.
6+
"""
7+
8+
import numpy as np
9+
import pandas as pd
10+
from typing import Dict, List, Optional, Tuple, Union, Literal
11+
12+
13+
def gen_mixture_cure(
    n: int,
    cure_fraction: float,
    baseline_hazard: float = 0.5,
    betas_survival: Optional[List[float]] = None,
    betas_cure: Optional[List[float]] = None,
    n_covariates: int = 2,
    covariate_dist: Literal["normal", "uniform", "binary"] = "normal",
    covariate_params: Optional[Dict[str, Union[float, Tuple[float, float]]]] = None,
    model_cens: Literal["uniform", "exponential"] = "uniform",
    cens_par: float = 5.0,
    max_time: Optional[float] = 10.0,
    seed: Optional[int] = None
) -> pd.DataFrame:
    """
    Generate survival data with a cure fraction using a mixture cure model.

    Each subject is first classified as cured or not through a logistic model
    whose intercept is ``logit(cure_fraction)`` and whose slope terms are
    ``betas_cure``. Non-cured subjects receive an exponential event time with
    rate ``baseline_hazard * exp(X @ betas_survival)``; cured subjects never
    experience the event. Independent censoring and an optional administrative
    cut-off at ``max_time`` are then applied.

    Parameters
    ----------
    n : int
        Number of subjects.
    cure_fraction : float
        Baseline probability of being cured (immune to the event), i.e. the
        cure probability of a subject whose cure linear predictor is zero.
        Must be between 0 and 1 inclusive; 0 means nobody is cured and 1 means
        everybody is cured, regardless of covariates.
    baseline_hazard : float, default=0.5
        Baseline hazard rate for the non-cured population. Must be positive.
    betas_survival : list of float, optional
        Coefficients for covariates in the survival component.
        If None, ``n_covariates`` coefficients are drawn from N(0, 0.5).
    betas_cure : list of float, optional
        Coefficients for covariates in the cure component.
        If None, coefficients are drawn from N(0, 0.5). If given, its length
        must equal the number of covariates.
    n_covariates : int, default=2
        Number of covariates to generate when ``betas_survival`` is None.
        When ``betas_survival`` is given, its length is used instead.
    covariate_dist : {"normal", "uniform", "binary"}, default="normal"
        Distribution to generate covariates from.
    covariate_params : dict, optional
        Parameters for covariate distribution:
        - "normal": {"mean": float, "std": float}
        - "uniform": {"low": float, "high": float}
        - "binary": {"p": float}
        Missing keys (or None) fall back to mean=0, std=1, low=0, high=1,
        p=0.5 respectively.
    model_cens : {"uniform", "exponential"}, default="uniform"
        Censoring mechanism: Uniform(0, cens_par) or Exponential with
        scale ``cens_par``.
    cens_par : float, default=5.0
        Parameter for censoring distribution. Must be positive.
    max_time : float, optional, default=10.0
        Maximum simulation time. Observations beyond it are truncated to
        ``max_time`` and marked censored. Set to None for no limit.
    seed : int, optional
        Random seed for reproducibility. When given, NumPy's global random
        state is seeded with it.

    Returns
    -------
    pd.DataFrame
        DataFrame with columns:
        - "id": Subject identifier
        - "time": Time to event or censoring
        - "status": Event indicator (1=event, 0=censored)
        - "cured": Indicator of cure status (1=cured, 0=not cured)
        - "X0", "X1", ...: Covariates

    Raises
    ------
    ValueError
        If ``cure_fraction`` is outside [0, 1], ``baseline_hazard`` or
        ``cens_par`` is not positive, ``covariate_dist`` or ``model_cens`` is
        not a supported option, or ``betas_cure`` has the wrong length.

    Examples
    --------
    >>> from gen_surv.mixture import gen_mixture_cure
    >>>
    >>> # Generate data with 30% baseline cure fraction
    >>> df = gen_mixture_cure(
    ...     n=100,
    ...     cure_fraction=0.3,
    ...     betas_survival=[0.8, -0.5],
    ...     betas_cure=[-0.5, 0.8],
    ...     seed=42
    ... )
    >>>
    >>> # Check cure proportion
    >>> print(f"Cured subjects: {df['cured'].mean():.2%}")
    """
    if seed is not None:
        # All draws below go through NumPy's global generator.
        np.random.seed(seed)

    # Validate inputs up front so no random numbers are consumed on bad input.
    if not 0 <= cure_fraction <= 1:
        raise ValueError("cure_fraction must be between 0 and 1")
    if baseline_hazard <= 0:
        raise ValueError("baseline_hazard must be positive")
    if cens_par <= 0:
        raise ValueError("cens_par must be positive")
    if covariate_dist not in ("normal", "uniform", "binary"):
        raise ValueError(f"Unknown covariate distribution: {covariate_dist}")
    if model_cens not in ("uniform", "exponential"):
        raise ValueError("model_cens must be 'uniform' or 'exponential'")

    # An absent dict behaves like an empty one: every lookup uses its default.
    params = {} if covariate_params is None else covariate_params

    # Resolve coefficient vectors (random draws happen in a fixed order:
    # survival betas, cure betas, covariates, cure status, events, censoring).
    if betas_survival is None:
        coef_survival = np.random.normal(0, 0.5, size=n_covariates)
    else:
        coef_survival = np.asarray(betas_survival, dtype=float)
        n_covariates = len(coef_survival)

    if betas_cure is None:
        coef_cure = np.random.normal(0, 0.5, size=n_covariates)
    else:
        coef_cure = np.asarray(betas_cure, dtype=float)
        if len(coef_cure) != n_covariates:
            raise ValueError(
                f"betas_cure must have the same length as betas_survival, "
                f"got {len(coef_cure)} vs {n_covariates}"
            )

    # Generate covariates
    shape = (n, n_covariates)
    if covariate_dist == "normal":
        X = np.random.normal(
            params.get("mean", 0.0), params.get("std", 1.0), size=shape
        )
    elif covariate_dist == "uniform":
        X = np.random.uniform(
            params.get("low", 0.0), params.get("high", 1.0), size=shape
        )
    else:  # "binary"
        X = np.random.binomial(1, params.get("p", 0.5), size=shape)

    # Linear predictors for the two model components
    lp_survival = X @ coef_survival
    lp_cure = X @ coef_cure

    # Cure status from a logistic model. The logit is undefined at exactly
    # 0 and 1 (log(0) / division by zero), so those degenerate cases are
    # handled explicitly instead of going through the formula.
    if cure_fraction in (0, 1):
        cure_probs = np.full(n, float(cure_fraction))
    else:
        intercept = np.log(cure_fraction / (1 - cure_fraction))
        cure_probs = 1 / (1 + np.exp(-(intercept + lp_cure)))
    cured = np.random.binomial(1, cure_probs)

    # Event times: exponential with covariate-adjusted rate for non-cured
    # subjects, infinite for cured subjects (they can only be censored).
    is_cured = cured == 1
    survival_times = np.full(n, np.inf)
    hazards = baseline_hazard * np.exp(lp_survival[~is_cured])
    survival_times[~is_cured] = np.random.exponential(scale=1 / hazards)

    # Generate censoring times
    if model_cens == "uniform":
        cens_times = np.random.uniform(0, cens_par, size=n)
    else:  # "exponential"
        cens_times = np.random.exponential(scale=cens_par, size=n)

    # Determine observed time and status
    observed_times = np.minimum(survival_times, cens_times)
    status = (survival_times <= cens_times).astype(int)

    # Administrative censoring at max_time
    if max_time is not None:
        over_max = observed_times > max_time
        observed_times[over_max] = max_time
        status[over_max] = 0

    columns = {
        "id": np.arange(n),
        "time": observed_times,
        "status": status,
        "cured": cured,
    }
    for j in range(n_covariates):
        columns[f"X{j}"] = X[:, j]

    return pd.DataFrame(columns)
209+
210+
211+
def cure_fraction_estimate(
    data: pd.DataFrame,
    time_col: str = "time",
    status_col: str = "status",
    bandwidth: float = 0.1
) -> float:
    """
    Estimate the cure fraction from observed data using non-parametric methods.

    A Kaplan-Meier survival curve is computed from the data and the cure
    fraction is read off its right-hand plateau: the last value of the curve
    for small samples, or a kernel-weighted average of the final 10% of the
    curve (weighted towards the very last point) otherwise.

    Parameters
    ----------
    data : pd.DataFrame
        DataFrame with survival data.
    time_col : str, default="time"
        Name of the time column.
    status_col : str, default="status"
        Name of the status column (1=event, 0=censored).
    bandwidth : float, default=0.1
        Bandwidth parameter for smoothing the tail of the survival curve.
        Must be positive. Only affects the result when the tail contains
        more than 3 observations.

    Returns
    -------
    float
        Estimated cure fraction. Returns 0.0 for an empty DataFrame.

    Raises
    ------
    ValueError
        If ``bandwidth`` is not positive.

    Notes
    -----
    This function uses a non-parametric approach to estimate the cure fraction
    based on the plateau of the survival curve. It may not be accurate for
    small sample sizes or heavy censoring.
    """
    if bandwidth <= 0:
        # A zero bandwidth would divide by zero in the kernel and yield NaN.
        raise ValueError("bandwidth must be positive")

    n = len(data)
    if n == 0:
        return 0.0

    # Sort by time; at tied times put events before censored rows so that
    # censored subjects are still counted in the risk set when the event
    # occurs (the standard Kaplan-Meier convention).
    ordered = data.sort_values(by=[time_col, status_col], ascending=[True, False])
    status = ordered[status_col].to_numpy()

    # Kaplan-Meier: with rows processed one at a time, the i-th row has
    # n - i subjects still at risk; only events shrink the survival curve.
    at_risk = n - np.arange(n)
    factors = np.where(status == 1, 1.0 - 1.0 / at_risk, 1.0)
    survival = np.cumprod(factors)

    # Plateau estimate from the last 10% of the curve (at least one point).
    tail_size = max(int(n * 0.1), 1)
    if tail_size <= 3:
        # Too few points to smooth: use the final survival probability.
        return float(survival[-1])

    # Gaussian-shaped weights centred on the last point of the tail.
    tail_survival = survival[-tail_size:]
    offsets = np.arange(tail_size) - tail_size + 1
    weights = np.exp(-offsets**2 / (2 * bandwidth * tail_size)**2)
    weights = weights / weights.sum()
    return float(np.sum(tail_survival * weights))

0 commit comments

Comments
 (0)