Skip to content

Commit 47e4d7f

Browse files
committed
feat: Add competing risks models and enhance data visualization
1 parent b3abc8f commit 47e4d7f

File tree

16 files changed

+2896
-90
lines changed

16 files changed

+2896
-90
lines changed

.github/workflows/bump-version.yml

Lines changed: 8 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,11 @@ jobs:
2222
with:
2323
python-version: "3.11"
2424

25-
- run: pip install python-semantic-release
25+
- name: Install Poetry
26+
run: pip install poetry
27+
28+
- name: Install python-semantic-release
29+
run: pip install python-semantic-release
2630

2731
- name: Configure Git
2832
run: |
@@ -32,36 +36,10 @@ jobs:
3236
- name: Run Semantic Release
3337
env:
3438
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
35-
run: semantic-release version
39+
run: |
40+
# Run semantic-release to get the next version
41+
semantic-release version
3642
3743
- name: Push changes
3844
run: |
3945
git push --follow-tags
40-
- name: "Install Poetry"
41-
run: pip install poetry
42-
- name: "Determine version bump type"
43-
run: |
44-
git fetch --tags
45-
# This defaults to a patch type, unless a feature commit was pushed, then set type to minor
46-
LAST_TAG=$(git describe --tags $(git rev-list --tags --max-count=1))
47-
LAST_COMMIT=$(git log -1 --format='%H')
48-
echo "Last git tag: $LAST_TAG"
49-
echo "Last git commit: $LAST_COMMIT"
50-
echo "Commits:"
51-
git log --no-merges --pretty=oneline $LAST_TAG...$LAST_COMMIT
52-
git log --no-merges --pretty=format:"%s" $LAST_TAG...$LAST_COMMIT | grep -q ^feat: && BUMP_TYPE="minor" || BUMP_TYPE="patch"
53-
echo "Version bump type: $BUMP_TYPE"
54-
echo "BUMP_TYPE=$BUMP_TYPE" >> $GITHUB_ENV
55-
env:
56-
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
57-
- name: "Version bump"
58-
run: |
59-
poetry version $BUMP_TYPE
60-
- name: "Push new version"
61-
run: |
62-
git add pyproject.toml
63-
git commit -m "Update version to $(poetry version -s)"
64-
git pull --ff-only origin main
65-
git push origin main --follow-tags
66-
env:
67-
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

.github/workflows/ci.yml

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
name: CI
2+
3+
on:
4+
push:
5+
branches: [main]
6+
pull_request:
7+
branches: [main]
8+
9+
jobs:
10+
test:
11+
name: Test with Python ${{ matrix.python-version }}
12+
runs-on: ubuntu-latest
13+
strategy:
14+
matrix:
15+
python-version: ["3.9", "3.10", "3.11"]
16+
17+
steps:
18+
- name: Checkout code
19+
uses: actions/checkout@v4
20+
21+
- name: Set up Python ${{ matrix.python-version }}
22+
uses: actions/setup-python@v5
23+
with:
24+
python-version: ${{ matrix.python-version }}
25+
26+
- name: Install Poetry
27+
run: |
28+
curl -sSL https://install.python-poetry.org | python3 -
29+
echo "$HOME/.local/bin" >> $GITHUB_PATH
30+
31+
- name: Install dependencies
32+
run: poetry install
33+
34+
- name: Run tests with coverage
35+
run: poetry run pytest --cov=gen_surv --cov-report=xml --cov-report=term
36+
37+
- name: Upload coverage to Codecov
38+
uses: codecov/codecov-action@v5
39+
with:
40+
files: coverage.xml
41+
token: ${{ secrets.CODECOV_TOKEN }} # optional if public repo
42+
43+
lint:
44+
name: Code Quality
45+
runs-on: ubuntu-latest
46+
47+
steps:
48+
- name: Checkout code
49+
uses: actions/checkout@v4
50+
51+
- name: Set up Python
52+
uses: actions/setup-python@v5
53+
with:
54+
python-version: "3.11"
55+
56+
- name: Install Poetry
57+
run: |
58+
curl -sSL https://install.python-poetry.org | python3 -
59+
echo "$HOME/.local/bin" >> $GITHUB_PATH
60+
61+
- name: Install dependencies
62+
run: poetry install
63+
64+
- name: Run black
65+
run: poetry run black --check gen_surv tests examples
66+
67+
- name: Run isort
68+
run: poetry run isort --check gen_surv tests examples
69+
70+
- name: Run flake8
71+
run: poetry run flake8 gen_surv tests examples
72+
73+
- name: Run mypy
74+
run: poetry run mypy gen_surv

examples/run_aft_weibull.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
"""
2+
Example demonstrating Weibull AFT model and visualization capabilities.
3+
"""
4+
5+
import sys
6+
import os
7+
import matplotlib.pyplot as plt
8+
import numpy as np
9+
import pandas as pd
10+
11+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
12+
13+
from gen_surv import generate
14+
from gen_surv.visualization import (
15+
plot_survival_curve,
16+
plot_hazard_comparison,
17+
plot_covariate_effect,
18+
describe_survival
19+
)
20+
21+
# 1. Generate data from different models for comparison
22+
models = {
23+
"Weibull AFT (shape=0.5)": generate(
24+
model="aft_weibull",
25+
n=200,
26+
beta=[0.5, -0.3],
27+
shape=0.5, # Decreasing hazard
28+
scale=2.0,
29+
model_cens="uniform",
30+
cens_par=5.0,
31+
seed=42
32+
),
33+
"Weibull AFT (shape=1.0)": generate(
34+
model="aft_weibull",
35+
n=200,
36+
beta=[0.5, -0.3],
37+
shape=1.0, # Constant hazard
38+
scale=2.0,
39+
model_cens="uniform",
40+
cens_par=5.0,
41+
seed=42
42+
),
43+
"Weibull AFT (shape=2.0)": generate(
44+
model="aft_weibull",
45+
n=200,
46+
beta=[0.5, -0.3],
47+
shape=2.0, # Increasing hazard
48+
scale=2.0,
49+
model_cens="uniform",
50+
cens_par=5.0,
51+
seed=42
52+
)
53+
}
54+
55+
# Print sample data
56+
print("Sample data from Weibull AFT model (shape=2.0):")
57+
print(models["Weibull AFT (shape=2.0)"].head())
58+
print("\n")
59+
60+
# 2. Compare survival curves from different models
61+
fig1, ax1 = plot_survival_curve(
62+
data=pd.concat(
63+
[df.assign(_model=name) for name, df in models.items()]
64+
),
65+
group_col="_model",
66+
title="Comparing Survival Curves with Different Weibull Shapes"
67+
)
68+
plt.savefig("survival_curve_comparison.png", dpi=300, bbox_inches="tight")
69+
70+
# 3. Compare hazard functions
71+
fig2, ax2 = plot_hazard_comparison(
72+
models=models,
73+
title="Comparing Hazard Functions with Different Weibull Shapes"
74+
)
75+
plt.savefig("hazard_comparison.png", dpi=300, bbox_inches="tight")
76+
77+
# 4. Visualize covariate effect on survival
78+
fig3, ax3 = plot_covariate_effect(
79+
data=models["Weibull AFT (shape=2.0)"],
80+
covariate_col="X0",
81+
n_groups=3,
82+
title="Effect of X0 Covariate on Survival"
83+
)
84+
plt.savefig("covariate_effect.png", dpi=300, bbox_inches="tight")
85+
86+
# 5. Summary statistics
87+
for name, df in models.items():
88+
print(f"Summary for {name}:")
89+
summary = describe_survival(df)
90+
print(summary)
91+
print("\n")
92+
93+
print("Plots saved to current directory.")
94+
95+
# Show plots if running interactively
96+
if __name__ == "__main__":
97+
plt.show()

examples/run_competing_risks.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
"""
2+
Example demonstrating the Competing Risks models and visualization.
3+
"""
4+
5+
import sys
6+
import os
7+
import matplotlib.pyplot as plt
8+
import numpy as np
9+
import pandas as pd
10+
11+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
12+
13+
from gen_surv import generate
14+
from gen_surv.competing_risks import gen_competing_risks, gen_competing_risks_weibull, cause_specific_cumulative_incidence
15+
from gen_surv.summary import summarize_survival_dataset, compare_survival_datasets
16+
17+
18+
def plot_cause_specific_cumulative_incidence(df, time_points=None, figsize=(10, 6)):
19+
"""Plot the cause-specific cumulative incidence functions."""
20+
if time_points is None:
21+
max_time = df["time"].max()
22+
time_points = np.linspace(0, max_time, 100)
23+
24+
# Get unique causes (excluding censoring)
25+
causes = sorted([c for c in df["status"].unique() if c > 0])
26+
27+
# Create the plot
28+
fig, ax = plt.subplots(figsize=figsize)
29+
30+
for cause in causes:
31+
cif = cause_specific_cumulative_incidence(df, time_points, cause=cause)
32+
ax.plot(cif["time"], cif["incidence"], label=f"Cause {cause}")
33+
34+
# Add overlay showing number of subjects at each time
35+
time_bins = np.linspace(0, df["time"].max(), 10)
36+
event_counts = np.histogram(df.loc[df["status"] > 0, "time"], bins=time_bins)[0]
37+
38+
# Add a secondary y-axis for event counts
39+
ax2 = ax.twinx()
40+
ax2.bar(time_bins[:-1], event_counts, width=time_bins[1]-time_bins[0],
41+
alpha=0.2, color='gray', align='edge')
42+
ax2.set_ylabel('Number of events')
43+
ax2.grid(False)
44+
45+
# Format the main plot
46+
ax.set_xlabel("Time")
47+
ax.set_ylabel("Cumulative Incidence")
48+
ax.set_title("Cause-Specific Cumulative Incidence Functions")
49+
ax.legend()
50+
ax.grid(alpha=0.3)
51+
52+
return fig, ax
53+
54+
55+
# 1. Generate data with 2 competing risks
56+
print("Generating data with exponential hazards...")
57+
data_exponential = gen_competing_risks(
58+
n=500,
59+
n_risks=2,
60+
baseline_hazards=[0.5, 0.3],
61+
betas=[[0.8, -0.5], [0.2, 0.7]],
62+
model_cens="uniform",
63+
cens_par=2.0,
64+
seed=42
65+
)
66+
67+
# 2. Generate data with Weibull hazards (different shapes)
68+
print("Generating data with Weibull hazards...")
69+
data_weibull = gen_competing_risks_weibull(
70+
n=500,
71+
n_risks=2,
72+
shape_params=[0.8, 1.5], # Decreasing vs increasing hazard
73+
scale_params=[2.0, 3.0],
74+
betas=[[0.8, -0.5], [0.2, 0.7]],
75+
model_cens="uniform",
76+
cens_par=2.0,
77+
seed=42
78+
)
79+
80+
# 3. Print summary statistics for both datasets
81+
print("\nSummary of Exponential Hazards dataset:")
82+
summarize_survival_dataset(data_exponential)
83+
84+
print("\nSummary of Weibull Hazards dataset:")
85+
summarize_survival_dataset(data_weibull)
86+
87+
# 4. Compare event distributions
88+
print("\nEvent distribution (Exponential Hazards):")
89+
print(data_exponential["status"].value_counts())
90+
91+
print("\nEvent distribution (Weibull Hazards):")
92+
print(data_weibull["status"].value_counts())
93+
94+
# 5. Plot cause-specific cumulative incidence functions
95+
print("\nPlotting cumulative incidence functions...")
96+
time_points = np.linspace(0, 5, 100)
97+
98+
fig1, ax1 = plot_cause_specific_cumulative_incidence(
99+
data_exponential,
100+
time_points=time_points,
101+
figsize=(10, 6)
102+
)
103+
plt.title("Cumulative Incidence Functions (Exponential Hazards)")
104+
plt.savefig("cr_exponential_cif.png", dpi=300, bbox_inches="tight")
105+
106+
fig2, ax2 = plot_cause_specific_cumulative_incidence(
107+
data_weibull,
108+
time_points=time_points,
109+
figsize=(10, 6)
110+
)
111+
plt.title("Cumulative Incidence Functions (Weibull Hazards)")
112+
plt.savefig("cr_weibull_cif.png", dpi=300, bbox_inches="tight")
113+
114+
# 6. Demonstrate using the unified generate() interface
115+
print("\nUsing the unified generate() interface:")
116+
data_unified = generate(
117+
model="competing_risks",
118+
n=100,
119+
n_risks=2,
120+
baseline_hazards=[0.5, 0.3],
121+
betas=[[0.8, -0.5], [0.2, 0.7]],
122+
model_cens="uniform",
123+
cens_par=2.0,
124+
seed=42
125+
)
126+
print(data_unified.head())
127+
128+
# 7. Compare datasets
129+
print("\nComparing datasets:")
130+
comparison = compare_survival_datasets({
131+
"Exponential": data_exponential,
132+
"Weibull": data_weibull
133+
})
134+
print(comparison)
135+
136+
# Show plots if running interactively
137+
if __name__ == "__main__":
138+
plt.show()

0 commit comments

Comments
 (0)