22Example demonstrating the Competing Risks models and visualization.
33"""
44
5- import sys
65import os
6+ import sys
7+
78import matplotlib .pyplot as plt
89import numpy as np
910import pandas as pd
1011
11- sys .path .insert (0 , os .path .abspath (os .path .join (os .path .dirname (__file__ ), '..' )))
12+ sys .path .insert (0 , os .path .abspath (os .path .join (os .path .dirname (__file__ ), ".." )))
1213
1314from gen_surv import generate
14- from gen_surv .competing_risks import gen_competing_risks , gen_competing_risks_weibull , cause_specific_cumulative_incidence
15- from gen_surv .summary import summarize_survival_dataset , compare_survival_datasets
15+ from gen_surv .competing_risks import (
16+ cause_specific_cumulative_incidence ,
17+ gen_competing_risks ,
18+ gen_competing_risks_weibull ,
19+ )
20+ from gen_surv .summary import compare_survival_datasets , summarize_survival_dataset
1621
1722
1823def plot_cause_specific_cumulative_incidence (df , time_points = None , figsize = (10 , 6 )):
1924 """Plot the cause-specific cumulative incidence functions."""
2025 if time_points is None :
2126 max_time = df ["time" ].max ()
2227 time_points = np .linspace (0 , max_time , 100 )
23-
28+
2429 # Get unique causes (excluding censoring)
2530 causes = sorted ([c for c in df ["status" ].unique () if c > 0 ])
26-
31+
2732 # Create the plot
2833 fig , ax = plt .subplots (figsize = figsize )
29-
34+
3035 for cause in causes :
3136 cif = cause_specific_cumulative_incidence (df , time_points , cause = cause )
3237 ax .plot (cif ["time" ], cif ["incidence" ], label = f"Cause { cause } " )
33-
38+
3439 # Add overlay showing number of subjects at each time
3540 time_bins = np .linspace (0 , df ["time" ].max (), 10 )
3641 event_counts = np .histogram (df .loc [df ["status" ] > 0 , "time" ], bins = time_bins )[0 ]
37-
42+
3843 # Add a secondary y-axis for event counts
3944 ax2 = ax .twinx ()
40- ax2 .bar (time_bins [:- 1 ], event_counts , width = time_bins [1 ]- time_bins [0 ],
41- alpha = 0.2 , color = 'gray' , align = 'edge' )
42- ax2 .set_ylabel ('Number of events' )
45+ ax2 .bar (
46+ time_bins [:- 1 ],
47+ event_counts ,
48+ width = time_bins [1 ] - time_bins [0 ],
49+ alpha = 0.2 ,
50+ color = "gray" ,
51+ align = "edge" ,
52+ )
53+ ax2 .set_ylabel ("Number of events" )
4354 ax2 .grid (False )
44-
55+
4556 # Format the main plot
4657 ax .set_xlabel ("Time" )
4758 ax .set_ylabel ("Cumulative Incidence" )
4859 ax .set_title ("Cause-Specific Cumulative Incidence Functions" )
4960 ax .legend ()
5061 ax .grid (alpha = 0.3 )
51-
62+
5263 return fig , ax
5364
5465
@@ -61,7 +72,7 @@ def plot_cause_specific_cumulative_incidence(df, time_points=None, figsize=(10,
6172 betas = [[0.8 , - 0.5 ], [0.2 , 0.7 ]],
6273 model_cens = "uniform" ,
6374 cens_par = 2.0 ,
64- seed = 42
75+ seed = 42 ,
6576)
6677
6778# 2. Generate data with Weibull hazards (different shapes)
@@ -74,7 +85,7 @@ def plot_cause_specific_cumulative_incidence(df, time_points=None, figsize=(10,
7485 betas = [[0.8 , - 0.5 ], [0.2 , 0.7 ]],
7586 model_cens = "uniform" ,
7687 cens_par = 2.0 ,
77- seed = 42
88+ seed = 42 ,
7889)
7990
8091# 3. Print summary statistics for both datasets
@@ -96,17 +107,13 @@ def plot_cause_specific_cumulative_incidence(df, time_points=None, figsize=(10,
96107time_points = np .linspace (0 , 5 , 100 )
97108
98109fig1 , ax1 = plot_cause_specific_cumulative_incidence (
99- data_exponential ,
100- time_points = time_points ,
101- figsize = (10 , 6 )
110+ data_exponential , time_points = time_points , figsize = (10 , 6 )
102111)
103112plt .title ("Cumulative Incidence Functions (Exponential Hazards)" )
104113plt .savefig ("cr_exponential_cif.png" , dpi = 300 , bbox_inches = "tight" )
105114
106115fig2 , ax2 = plot_cause_specific_cumulative_incidence (
107- data_weibull ,
108- time_points = time_points ,
109- figsize = (10 , 6 )
116+ data_weibull , time_points = time_points , figsize = (10 , 6 )
110117)
111118plt .title ("Cumulative Incidence Functions (Weibull Hazards)" )
112119plt .savefig ("cr_weibull_cif.png" , dpi = 300 , bbox_inches = "tight" )
@@ -121,16 +128,15 @@ def plot_cause_specific_cumulative_incidence(df, time_points=None, figsize=(10,
121128 betas = [[0.8 , - 0.5 ], [0.2 , 0.7 ]],
122129 model_cens = "uniform" ,
123130 cens_par = 2.0 ,
124- seed = 42
131+ seed = 42 ,
125132)
126133print (data_unified .head ())
127134
128135# 7. Compare datasets
129136print ("\n Comparing datasets:" )
130- comparison = compare_survival_datasets ({
131- "Exponential" : data_exponential ,
132- "Weibull" : data_weibull
133- })
137+ comparison = compare_survival_datasets (
138+ {"Exponential" : data_exponential , "Weibull" : data_weibull }
139+ )
134140print (comparison )
135141
136142# Show plots if running interactively
0 commit comments