Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 18 additions & 22 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,32 +30,28 @@ poetry install

## 🧪 Example

```python
from gen_surv.cphm import gen_cphm

df = gen_cphm(
n=100,
model_cens="uniform",
cens_par=1.0,
beta=0.5,
covar=2.0
)
print(df.head())
```

```python
from gen_surv import generate

df = generate(
model="cphm",
n=100,
model_cens="uniform",
cens_par=1.0,
beta=0.5,
covar=2.0
)
# CPHM
generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0)

# AFT Log-Normal
generate(model="aft_ln", n=100, beta=[0.5, -0.3], sigma=1.0, model_cens="exponential", cens_par=3.0)

# CMM
generate(model="cmm", n=100, model_cens="exponential", cens_par=2.0,
qmat=[[0, 0.1], [0.05, 0]], p0=[1.0, 0.0])

# TDCM
generate(model="tdcm", n=100, dist="weibull", corr=0.5,
dist_par=[1, 2, 1, 2], model_cens="uniform", cens_par=1.0,
beta=[0.1, 0.2, 0.3], lam=1.0)

print(df.head())
# THMM
generate(model="thmm", n=100, qmat=[[0, 0.2, 0], [0.1, 0, 0.1], [0, 0.3, 0]],
emission_pars={"mu": [0.0, 1.0, 2.0], "sigma": [0.5, 0.5, 0.5]},
p0=[1.0, 0.0, 0.0], model_cens="exponential", cens_par=3.0)
```

## 🔧 Available Generators
Expand Down
33 changes: 18 additions & 15 deletions docs/source/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ It includes generators for:
- **Continuous-Time Markov Models (CMM)**
- **Time-Dependent Covariate Models (TDCM)**
- **Time-Homogeneous Hidden Markov Models (THMM)**
- **Accelerated Failure Time (AFT) Log-Normal Models**

---

Expand All @@ -25,25 +26,27 @@ theory
# 🚀 Usage Example

```python
from gen_surv.cphm import gen_cphm
from gen_surv import generate

df = gen_cphm(n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0)
print(df.head())
```
# CPHM
generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0)

```python
from gen_surv import generate
# AFT Log-Normal
generate(model="aft_ln", n=100, beta=[0.5, -0.3], sigma=1.0, model_cens="exponential", cens_par=3.0)

# CMM
generate(model="cmm", n=100, model_cens="exponential", cens_par=2.0,
qmat=[[0, 0.1], [0.05, 0]], p0=[1.0, 0.0])

df = generate(
model="cphm",
n=100,
model_cens="uniform",
cens_par=1.0,
beta=0.5,
covar=2.0
)
# TDCM
generate(model="tdcm", n=100, dist="weibull", corr=0.5,
dist_par=[1, 2, 1, 2], model_cens="uniform", cens_par=1.0,
beta=[0.1, 0.2, 0.3], lam=1.0)

print(df.head())
# THMM
generate(model="thmm", n=100, qmat=[[0, 0.2, 0], [0.1, 0, 0.1], [0, 0.3, 0]],
emission_pars={"mu": [0.0, 1.0, 2.0], "sigma": [0.5, 0.5, 0.5]},
p0=[1.0, 0.0, 0.0], model_cens="exponential", cens_par=3.0)
```

## 🔗 Project Links
Expand Down
17 changes: 17 additions & 0 deletions examples/run_aft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import sys
import os

# Put the parent of this examples/ directory first on sys.path so that the
# gen_surv package next to it is the one that gets imported.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

# Import from the package root, like the other example scripts and the README;
# gen_surv/__init__ re-exports `generate` from gen_surv.interface.
from gen_surv import generate

# Generate synthetic survival data using Log-Normal AFT model
df = generate(
    model="aft_ln",
    n=100,
    beta=[0.5, -0.3],
    sigma=1.0,
    model_cens="exponential",
    cens_par=3.0,
    seed=123,
)

# Show the first rows of the generated data.
print(df.head())
17 changes: 17 additions & 0 deletions examples/run_cmm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from gen_surv import generate

# Keyword arguments for the continuous-time Markov model (CMM) example.
cmm_settings = {
    "model": "cmm",
    "n": 100,
    "model_cens": "exponential",
    "cens_par": 2.0,
    "qmat": [[0, 0.1], [0.05, 0]],
    "p0": [1.0, 0.0],
    "seed": 42,  # fixed seed value
}

df = generate(**cmm_settings)

# Show the first rows of the generated data.
print(df.head())
17 changes: 17 additions & 0 deletions examples/run_cphm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from gen_surv import generate

# Keyword arguments for the CPHM example.
cphm_settings = {
    "model": "cphm",
    "n": 100,
    "model_cens": "uniform",
    "cens_par": 1.0,
    "beta": 0.5,
    "covar": 2.0,
    "seed": 42,  # fixed seed value
}

df = generate(**cphm_settings)

# Show the first rows of the generated data.
print(df.head())
20 changes: 20 additions & 0 deletions examples/run_tdcm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from gen_surv import generate

# Keyword arguments for the time-dependent covariate model (TDCM) example.
tdcm_settings = {
    "model": "tdcm",
    "n": 100,
    "dist": "weibull",
    "corr": 0.5,
    "dist_par": [1, 2, 1, 2],
    "model_cens": "uniform",
    "cens_par": 1.0,
    "beta": [0.1, 0.2, 0.3],
    "lam": 1.0,
    "seed": 42,  # fixed seed value
}

df = generate(**tdcm_settings)

# Show the first rows of the generated data.
print(df.head())
18 changes: 18 additions & 0 deletions examples/run_thmm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from gen_surv import generate

# Keyword arguments for the time-homogeneous hidden Markov model (THMM) example.
thmm_settings = {
    "model": "thmm",
    "n": 100,
    "qmat": [[0, 0.2, 0], [0.1, 0, 0.1], [0, 0.3, 0]],
    "emission_pars": {"mu": [0.0, 1.0, 2.0], "sigma": [0.5, 0.5, 0.5]},
    "p0": [1.0, 0.0, 0.0],
    "model_cens": "exponential",
    "cens_par": 3.0,
    "seed": 42,  # fixed seed value
}

df = generate(**thmm_settings)

# Show the first rows of the generated data.
print(df.head())
1 change: 1 addition & 0 deletions stubs/gen_surv/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .interface import generate as generate
3 changes: 3 additions & 0 deletions stubs/gen_surv/aft.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from _typeshed import Incomplete

# Stub for the log-normal AFT generator. Only `seed` is annotated (optional,
# type left as Incomplete); the remaining parameters and the return type are
# untyped here.
# NOTE(review): tests/test_aft.py expects a pandas DataFrame with "time",
# "status", "X0" and "X1" columns for a two-element beta — consider declaring
# `-> pd.DataFrame` once confirmed against the implementation.
def gen_aft_log_normal(n, beta, sigma, model_cens, cens_par, seed: Incomplete | None = None): ...
1 change: 1 addition & 0 deletions stubs/gen_surv/bivariate.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Untyped stub for the bivariate sampler.
# NOTE(review): validate_dg_biv_inputs in validate.pyi declares the same four
# parameter names as (int, str, float, list); presumably those types apply
# here too — confirm before annotating.
def sample_bivariate_distribution(n, dist, corr, dist_par): ...
4 changes: 4 additions & 0 deletions stubs/gen_surv/censoring.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import numpy as np

# Censoring-time samplers: each takes a sample size and a single censoring
# parameter and returns a NumPy array.
# NOTE(review): names suggest uniform and exponential censoring respectively —
# confirm against the implementation.
def runifcens(size: int, cens_par: float) -> np.ndarray: ...
def rexpocens(size: int, cens_par: float) -> np.ndarray: ...
5 changes: 5 additions & 0 deletions stubs/gen_surv/cmm.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from gen_surv.censoring import rexpocens as rexpocens, runifcens as runifcens
from gen_surv.validate import validate_gen_cmm_inputs as validate_gen_cmm_inputs

# Helper stub: takes one float plus `beta` and `rate` lists and returns a dict.
def generate_event_times(z1: float, beta: list, rate: list) -> dict: ...
# Entry point for the "cmm" model; untyped in this stub.
# NOTE(review): the README and examples/run_cmm.py call the "cmm" model with
# `qmat` and `p0` keywords, which this signature (beta, covar, rate) does not
# declare — confirm which side is current.
def gen_cmm(n, model_cens, cens_par, beta, covar, rate): ...
6 changes: 6 additions & 0 deletions stubs/gen_surv/cphm.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import pandas as pd
from gen_surv.censoring import rexpocens as rexpocens, runifcens as runifcens
from gen_surv.validate import validate_gen_cphm_inputs as validate_gen_cphm_inputs

# Untyped helper stub; `rfunc` is presumably one of the censoring samplers
# imported above (runifcens / rexpocens) — TODO confirm.
def generate_cphm_data(n, rfunc, cens_par, beta, covariate_range): ...
# Entry point for the "cphm" model: fully annotated, returns a pandas DataFrame.
def gen_cphm(n: int, model_cens: str, cens_par: float, beta: float, covar: float) -> pd.DataFrame: ...
7 changes: 7 additions & 0 deletions stubs/gen_surv/interface.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from gen_surv.aft import gen_aft_log_normal as gen_aft_log_normal
from gen_surv.cmm import gen_cmm as gen_cmm
from gen_surv.cphm import gen_cphm as gen_cphm
from gen_surv.tdcm import gen_tdcm as gen_tdcm
from gen_surv.thmm import gen_thmm as gen_thmm

# Single public entry point: takes a model name plus arbitrary keyword
# arguments. The README and examples use the names "cphm", "aft_ln", "cmm",
# "tdcm" and "thmm"; presumably each maps to one of the gen_* functions
# imported above — confirm against the implementation. Return type is not
# declared in this stub.
def generate(model: str, **kwargs): ...
6 changes: 6 additions & 0 deletions stubs/gen_surv/tdcm.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from gen_surv.bivariate import sample_bivariate_distribution as sample_bivariate_distribution
from gen_surv.censoring import rexpocens as rexpocens, runifcens as runifcens
from gen_surv.validate import validate_gen_tdcm_inputs as validate_gen_tdcm_inputs

# Untyped helper stub.
def generate_censored_observations(n, dist_par, model_cens, cens_par, beta, lam, b): ...
# Entry point for the "tdcm" model; untyped here. Its parameter names match
# those of validate_gen_tdcm_inputs in validate.pyi.
def gen_tdcm(n, dist, corr, dist_par, model_cens, cens_par, beta, lam): ...
5 changes: 5 additions & 0 deletions stubs/gen_surv/thmm.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from gen_surv.censoring import rexpocens as rexpocens, runifcens as runifcens
from gen_surv.validate import validate_gen_thmm_inputs as validate_gen_thmm_inputs

# Helper stub returning a dict; `rfunc` is presumably one of the censoring
# samplers imported above (runifcens / rexpocens) — TODO confirm.
def calculate_transitions(z1: float, cens_par: float, beta: list, rate: list, rfunc) -> dict: ...
# Entry point for the "thmm" model; untyped in this stub.
# NOTE(review): the README and examples/run_thmm.py call the "thmm" model with
# `qmat`, `emission_pars` and `p0` keywords, which this signature (beta, covar,
# rate) does not declare — confirm which side is current.
def gen_thmm(n, model_cens, cens_par, beta, covar, rate): ...
6 changes: 6 additions & 0 deletions stubs/gen_surv/validate.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Input validators, one per generator. Only the AFT validator declares a
# return type (None); the others leave it unspecified.
# Note: the CPHM validator does not take `beta`, although gen_cphm does.
def validate_gen_cphm_inputs(n: int, model_cens: str, cens_par: float, covar: float): ...
def validate_gen_cmm_inputs(n: int, model_cens: str, cens_par: float, beta: list, covar: float, rate: list): ...
def validate_gen_tdcm_inputs(n: int, dist: str, corr: float, dist_par: list, model_cens: str, cens_par: float, beta: list, lam: float): ...
def validate_gen_thmm_inputs(n: int, model_cens: str, cens_par: float, beta: list, covar: float, rate: list): ...
def validate_dg_biv_inputs(n: int, dist: str, corr: float, dist_par: list): ...
# Parameters are untyped in this stub.
def validate_gen_aft_log_normal_inputs(n, beta, sigma, model_cens, cens_par) -> None: ...
23 changes: 23 additions & 0 deletions tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,26 @@ def publish(c):
@task
def clean(c):
    # Paths removed by this task, relative to the directory the task runs in.
    targets = (
        "dist",
        "build",
        "docs/build",
        ".pytest_cache",
        ".mypy_cache",
        "coverage.xml",
        ".coverage",
        "stubs",
    )
    # Single `rm -rf` invocation with all targets space-separated.
    c.run("rm -rf " + " ".join(targets))

@task
def git_push(c):
    """
    Stage all changes, prompt for a commit message, create a signed commit, and push.

    An empty (or whitespace-only) message aborts before committing; note that
    `git add .` has already run at that point, so changes stay staged.
    Ctrl-C at the prompt or while a git command is running aborts with a message.
    """
    # Imported locally so the task does not depend on a module-level import
    # being present; the message is shell-quoted before being interpolated
    # into the command string handed to c.run().
    import shlex

    c.run("git add .")

    try:
        # Prompt for a commit message
        message = input("Enter commit message: ").strip()
        if not message:
            print("Aborting: empty commit message.")
            return

        sanitized_message = shlex.quote(message)
        # -S requests a signed commit.
        c.run(f"git commit -S -m {sanitized_message}")
        c.run("git push")
    except KeyboardInterrupt:
        print("\nAborted by user.")

19 changes: 19 additions & 0 deletions tests/test_aft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import pandas as pd
from gen_surv.aft import gen_aft_log_normal

def test_gen_aft_log_normal_runs():
    """Smoke test: the log-normal AFT generator returns a well-formed DataFrame."""
    settings = {
        "n": 10,
        "beta": [0.5, -0.2],
        "sigma": 1.0,
        "model_cens": "uniform",
        "cens_par": 5.0,
        "seed": 42,
    }
    result = gen_aft_log_normal(**settings)

    assert isinstance(result, pd.DataFrame)
    assert not result.empty

    # Outcome columns plus one covariate column per beta entry used above.
    for column in ("time", "status", "X0", "X1"):
        assert column in result.columns

    # The status indicator may only take the values 0 and 1.
    observed_status = set(result["status"].unique())
    assert observed_status <= {0, 1}