Skip to content

Commit c3571dd

Browse files
Fix black formatting (#43)
1 parent 5a184ca commit c3571dd

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+685
-572
lines changed

examples/run_aft.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
import sys
21
import os
3-
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
2+
import sys
3+
4+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
45
from gen_surv.interface import generate
56

67
# Generate synthetic survival data using Log-Normal AFT model
@@ -11,7 +12,7 @@
1112
sigma=1.0,
1213
model_cens="exponential",
1314
cens_par=3.0,
14-
seed=123
15+
seed=123,
1516
)
1617

1718
print(df.head())

examples/run_aft_weibull.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,21 @@
22
Example demonstrating Weibull AFT model and visualization capabilities.
33
"""
44

5-
import sys
65
import os
6+
import sys
7+
78
import matplotlib.pyplot as plt
89
import numpy as np
910
import pandas as pd
1011

11-
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
12+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
1213

1314
from gen_surv import generate
1415
from gen_surv.visualization import (
15-
plot_survival_curve,
16-
plot_hazard_comparison,
16+
describe_survival,
1717
plot_covariate_effect,
18-
describe_survival
18+
plot_hazard_comparison,
19+
plot_survival_curve,
1920
)
2021

2122
# 1. Generate data from different models for comparison
@@ -28,7 +29,7 @@
2829
scale=2.0,
2930
model_cens="uniform",
3031
cens_par=5.0,
31-
seed=42
32+
seed=42,
3233
),
3334
"Weibull AFT (shape=1.0)": generate(
3435
model="aft_weibull",
@@ -38,7 +39,7 @@
3839
scale=2.0,
3940
model_cens="uniform",
4041
cens_par=5.0,
41-
seed=42
42+
seed=42,
4243
),
4344
"Weibull AFT (shape=2.0)": generate(
4445
model="aft_weibull",
@@ -48,8 +49,8 @@
4849
scale=2.0,
4950
model_cens="uniform",
5051
cens_par=5.0,
51-
seed=42
52-
)
52+
seed=42,
53+
),
5354
}
5455

5556
# Print sample data
@@ -59,18 +60,15 @@
5960

6061
# 2. Compare survival curves from different models
6162
fig1, ax1 = plot_survival_curve(
62-
data=pd.concat(
63-
[df.assign(_model=name) for name, df in models.items()]
64-
),
63+
data=pd.concat([df.assign(_model=name) for name, df in models.items()]),
6564
group_col="_model",
66-
title="Comparing Survival Curves with Different Weibull Shapes"
65+
title="Comparing Survival Curves with Different Weibull Shapes",
6766
)
6867
plt.savefig("survival_curve_comparison.png", dpi=300, bbox_inches="tight")
6968

7069
# 3. Compare hazard functions
7170
fig2, ax2 = plot_hazard_comparison(
72-
models=models,
73-
title="Comparing Hazard Functions with Different Weibull Shapes"
71+
models=models, title="Comparing Hazard Functions with Different Weibull Shapes"
7472
)
7573
plt.savefig("hazard_comparison.png", dpi=300, bbox_inches="tight")
7674

@@ -79,7 +77,7 @@
7977
data=models["Weibull AFT (shape=2.0)"],
8078
covariate_col="X0",
8179
n_groups=3,
82-
title="Effect of X0 Covariate on Survival"
80+
title="Effect of X0 Covariate on Survival",
8381
)
8482
plt.savefig("covariate_effect.png", dpi=300, bbox_inches="tight")
8583

examples/run_cmm.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
import sys
21
import os
3-
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
2+
import sys
3+
4+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
45

56
from gen_surv import generate
67

@@ -11,7 +12,7 @@
1112
cens_par=2.0,
1213
qmat=[[0, 0.1], [0.05, 0]],
1314
p0=[1.0, 0.0],
14-
seed=42
15+
seed=42,
1516
)
1617

1718
print(df.head())

examples/run_competing_risks.py

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,53 +2,64 @@
22
Example demonstrating the Competing Risks models and visualization.
33
"""
44

5-
import sys
65
import os
6+
import sys
7+
78
import matplotlib.pyplot as plt
89
import numpy as np
910
import pandas as pd
1011

11-
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
12+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
1213

1314
from gen_surv import generate
14-
from gen_surv.competing_risks import gen_competing_risks, gen_competing_risks_weibull, cause_specific_cumulative_incidence
15-
from gen_surv.summary import summarize_survival_dataset, compare_survival_datasets
15+
from gen_surv.competing_risks import (
16+
cause_specific_cumulative_incidence,
17+
gen_competing_risks,
18+
gen_competing_risks_weibull,
19+
)
20+
from gen_surv.summary import compare_survival_datasets, summarize_survival_dataset
1621

1722

1823
def plot_cause_specific_cumulative_incidence(df, time_points=None, figsize=(10, 6)):
1924
"""Plot the cause-specific cumulative incidence functions."""
2025
if time_points is None:
2126
max_time = df["time"].max()
2227
time_points = np.linspace(0, max_time, 100)
23-
28+
2429
# Get unique causes (excluding censoring)
2530
causes = sorted([c for c in df["status"].unique() if c > 0])
26-
31+
2732
# Create the plot
2833
fig, ax = plt.subplots(figsize=figsize)
29-
34+
3035
for cause in causes:
3136
cif = cause_specific_cumulative_incidence(df, time_points, cause=cause)
3237
ax.plot(cif["time"], cif["incidence"], label=f"Cause {cause}")
33-
38+
3439
# Add overlay showing number of subjects at each time
3540
time_bins = np.linspace(0, df["time"].max(), 10)
3641
event_counts = np.histogram(df.loc[df["status"] > 0, "time"], bins=time_bins)[0]
37-
42+
3843
# Add a secondary y-axis for event counts
3944
ax2 = ax.twinx()
40-
ax2.bar(time_bins[:-1], event_counts, width=time_bins[1]-time_bins[0],
41-
alpha=0.2, color='gray', align='edge')
42-
ax2.set_ylabel('Number of events')
45+
ax2.bar(
46+
time_bins[:-1],
47+
event_counts,
48+
width=time_bins[1] - time_bins[0],
49+
alpha=0.2,
50+
color="gray",
51+
align="edge",
52+
)
53+
ax2.set_ylabel("Number of events")
4354
ax2.grid(False)
44-
55+
4556
# Format the main plot
4657
ax.set_xlabel("Time")
4758
ax.set_ylabel("Cumulative Incidence")
4859
ax.set_title("Cause-Specific Cumulative Incidence Functions")
4960
ax.legend()
5061
ax.grid(alpha=0.3)
51-
62+
5263
return fig, ax
5364

5465

@@ -61,7 +72,7 @@ def plot_cause_specific_cumulative_incidence(df, time_points=None, figsize=(10,
6172
betas=[[0.8, -0.5], [0.2, 0.7]],
6273
model_cens="uniform",
6374
cens_par=2.0,
64-
seed=42
75+
seed=42,
6576
)
6677

6778
# 2. Generate data with Weibull hazards (different shapes)
@@ -74,7 +85,7 @@ def plot_cause_specific_cumulative_incidence(df, time_points=None, figsize=(10,
7485
betas=[[0.8, -0.5], [0.2, 0.7]],
7586
model_cens="uniform",
7687
cens_par=2.0,
77-
seed=42
88+
seed=42,
7889
)
7990

8091
# 3. Print summary statistics for both datasets
@@ -96,17 +107,13 @@ def plot_cause_specific_cumulative_incidence(df, time_points=None, figsize=(10,
96107
time_points = np.linspace(0, 5, 100)
97108

98109
fig1, ax1 = plot_cause_specific_cumulative_incidence(
99-
data_exponential,
100-
time_points=time_points,
101-
figsize=(10, 6)
110+
data_exponential, time_points=time_points, figsize=(10, 6)
102111
)
103112
plt.title("Cumulative Incidence Functions (Exponential Hazards)")
104113
plt.savefig("cr_exponential_cif.png", dpi=300, bbox_inches="tight")
105114

106115
fig2, ax2 = plot_cause_specific_cumulative_incidence(
107-
data_weibull,
108-
time_points=time_points,
109-
figsize=(10, 6)
116+
data_weibull, time_points=time_points, figsize=(10, 6)
110117
)
111118
plt.title("Cumulative Incidence Functions (Weibull Hazards)")
112119
plt.savefig("cr_weibull_cif.png", dpi=300, bbox_inches="tight")
@@ -121,16 +128,15 @@ def plot_cause_specific_cumulative_incidence(df, time_points=None, figsize=(10,
121128
betas=[[0.8, -0.5], [0.2, 0.7]],
122129
model_cens="uniform",
123130
cens_par=2.0,
124-
seed=42
131+
seed=42,
125132
)
126133
print(data_unified.head())
127134

128135
# 7. Compare datasets
129136
print("\nComparing datasets:")
130-
comparison = compare_survival_datasets({
131-
"Exponential": data_exponential,
132-
"Weibull": data_weibull
133-
})
137+
comparison = compare_survival_datasets(
138+
{"Exponential": data_exponential, "Weibull": data_weibull}
139+
)
134140
print(comparison)
135141

136142
# Show plots if running interactively

examples/run_cphm.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
import sys
21
import os
3-
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
2+
import sys
3+
4+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
45

56
from gen_surv import generate
67

@@ -11,7 +12,7 @@
1112
cens_par=1.0,
1213
beta=0.5,
1314
covariate_range=2.0,
14-
seed=42
15+
seed=42,
1516
)
1617

1718
print(df.head())

examples/run_tdcm.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
import sys
21
import os
3-
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
2+
import sys
3+
4+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
45

56
from gen_surv import generate
67

@@ -14,7 +15,7 @@
1415
cens_par=1.0,
1516
beta=[0.1, 0.2, 0.3],
1617
lam=1.0,
17-
seed=42
18+
seed=42,
1819
)
1920

2021
print(df.head())

examples/run_thmm.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
import sys
21
import os
3-
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
2+
import sys
3+
4+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
45

56
from gen_surv import generate
67

@@ -12,7 +13,7 @@
1213
p0=[1.0, 0.0, 0.0],
1314
model_cens="exponential",
1415
cens_par=3.0,
15-
seed=42
16+
seed=42,
1617
)
1718

1819
print(df.head())

gen_surv/__init__.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,31 +5,32 @@
55

66
from importlib.metadata import PackageNotFoundError, version
77

8-
# Main interface
9-
from .interface import generate
8+
from .aft import gen_aft_log_logistic, gen_aft_log_normal, gen_aft_weibull
109

11-
# Individual generators
12-
from .cphm import gen_cphm
10+
# Helper functions
11+
from .bivariate import sample_bivariate_distribution
12+
from .censoring import rexpocens, runifcens
1313
from .cmm import gen_cmm
14-
from .tdcm import gen_tdcm
15-
from .thmm import gen_thmm
16-
from .aft import gen_aft_log_normal, gen_aft_weibull, gen_aft_log_logistic
1714
from .competing_risks import gen_competing_risks, gen_competing_risks_weibull
18-
from .mixture import gen_mixture_cure, cure_fraction_estimate
19-
from .piecewise import gen_piecewise_exponential
15+
16+
# Individual generators
17+
from .cphm import gen_cphm
2018
from .export import export_dataset
2119

22-
# Helper functions
23-
from .bivariate import sample_bivariate_distribution
24-
from .censoring import runifcens, rexpocens
20+
# Main interface
21+
from .interface import generate
22+
from .mixture import cure_fraction_estimate, gen_mixture_cure
23+
from .piecewise import gen_piecewise_exponential
24+
from .tdcm import gen_tdcm
25+
from .thmm import gen_thmm
2526

2627
# Visualization tools (requires matplotlib and lifelines)
2728
try:
2829
from .visualization import (
29-
plot_survival_curve,
30-
plot_hazard_comparison,
31-
plot_covariate_effect,
3230
describe_survival,
31+
plot_covariate_effect,
32+
plot_hazard_comparison,
33+
plot_survival_curve,
3334
)
3435

3536
_has_visualization = True

0 commit comments

Comments
 (0)