diff --git a/README.md b/README.md index a3da27a..f25fb81 100644 --- a/README.md +++ b/README.md @@ -30,32 +30,28 @@ poetry install ## ๐Ÿงช Example -```python -from gen_surv.cphm import gen_cphm - -df = gen_cphm( - n=100, - model_cens="uniform", - cens_par=1.0, - beta=0.5, - covar=2.0 -) -print(df.head()) -``` - ```python from gen_surv import generate -df = generate( - model="cphm", - n=100, - model_cens="uniform", - cens_par=1.0, - beta=0.5, - covar=2.0 -) +# CPHM +generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0) + +# AFT Log-Normal +generate(model="aft_ln", n=100, beta=[0.5, -0.3], sigma=1.0, model_cens="exponential", cens_par=3.0) + +# CMM +generate(model="cmm", n=100, model_cens="exponential", cens_par=2.0, + qmat=[[0, 0.1], [0.05, 0]], p0=[1.0, 0.0]) + +# TDCM +generate(model="tdcm", n=100, dist="weibull", corr=0.5, + dist_par=[1, 2, 1, 2], model_cens="uniform", cens_par=1.0, + beta=[0.1, 0.2, 0.3], lam=1.0) -print(df.head()) +# THMM +generate(model="thmm", n=100, qmat=[[0, 0.2, 0], [0.1, 0, 0.1], [0, 0.3, 0]], + emission_pars={"mu": [0.0, 1.0, 2.0], "sigma": [0.5, 0.5, 0.5]}, + p0=[1.0, 0.0, 0.0], model_cens="exponential", cens_par=3.0) ``` ## ๐Ÿ”ง Available Generators diff --git a/docs/source/index.md b/docs/source/index.md index 51297ed..bba60fd 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -8,6 +8,7 @@ It includes generators for: - **Continuous-Time Markov Models (CMM)** - **Time-Dependent Covariate Models (TDCM)** - **Time-Homogeneous Hidden Markov Models (THMM)** +- **Accelerated Failure Time (AFT) Log-Normal Models** --- @@ -25,25 +26,27 @@ theory # ๐Ÿš€ Usage Example ```python -from gen_surv.cphm import gen_cphm +from gen_surv import generate -df = gen_cphm(n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0) -print(df.head()) -``` +# CPHM +generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0) -```python -from gen_surv import generate +# AFT Log-Normal +generate(model="aft_ln", n=100, beta=[0.5, -0.3], sigma=1.0, model_cens="exponential", cens_par=3.0) + +# CMM +generate(model="cmm", n=100, model_cens="exponential", cens_par=2.0, + qmat=[[0, 0.1], [0.05, 0]], p0=[1.0, 0.0]) -df = generate( - model="cphm", - n=100, - model_cens="uniform", - cens_par=1.0, - beta=0.5, - covar=2.0 -) +# TDCM +generate(model="tdcm", n=100, dist="weibull", corr=0.5, + dist_par=[1, 2, 1, 2], model_cens="uniform", cens_par=1.0, + beta=[0.1, 0.2, 0.3], lam=1.0) -print(df.head()) +# THMM +generate(model="thmm", n=100, qmat=[[0, 0.2, 0], [0.1, 0, 0.1], [0, 0.3, 0]], + emission_pars={"mu": [0.0, 1.0, 2.0], "sigma": [0.5, 0.5, 0.5]}, + p0=[1.0, 0.0, 0.0], model_cens="exponential", cens_par=3.0) ``` ## ๐Ÿ”— Project Links diff --git a/examples/run_aft.py b/examples/run_aft.py new file mode 100644 index 0000000..5ba6388 --- /dev/null +++ b/examples/run_aft.py @@ -0,0 +1,17 @@ +import sys +import os +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) +from gen_surv.interface import generate + +# Generate synthetic survival data using Log-Normal AFT model +df = generate( + model="aft_ln", + n=100, + beta=[0.5, -0.3], + sigma=1.0, + model_cens="exponential", + cens_par=3.0, + seed=123 +) + +print(df.head()) diff --git a/examples/run_cmm.py b/examples/run_cmm.py new file mode 100644 index 0000000..590b7a1 --- /dev/null +++ b/examples/run_cmm.py @@ -0,0 +1,17 @@ +import sys +import os +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from gen_surv import generate + +df = generate( + model="cmm", + n=100, + model_cens="exponential", + cens_par=2.0, + qmat=[[0, 0.1], [0.05, 0]], + p0=[1.0, 0.0], + seed=42 +) + +print(df.head()) diff --git a/examples/run_cphm.py b/examples/run_cphm.py new file mode 100644 index 0000000..c02b01b --- /dev/null +++ b/examples/run_cphm.py @@ -0,0 +1,17 @@ +import sys +import os +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from gen_surv import generate + +df = generate( + model="cphm", + n=100, + model_cens="uniform", + cens_par=1.0, + beta=0.5, + covar=2.0, + seed=42 +) + +print(df.head()) diff --git a/examples/run_tdcm.py b/examples/run_tdcm.py new file mode 100644 index 0000000..dd5204c --- /dev/null +++ b/examples/run_tdcm.py @@ -0,0 +1,20 @@ +import sys +import os +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from gen_surv import generate + +df = generate( + model="tdcm", + n=100, + dist="weibull", + corr=0.5, + dist_par=[1, 2, 1, 2], + model_cens="uniform", + cens_par=1.0, + beta=[0.1, 0.2, 0.3], + lam=1.0, + seed=42 +) + +print(df.head()) diff --git a/examples/run_thmm.py b/examples/run_thmm.py new file mode 100644 index 0000000..73721ad --- /dev/null +++ b/examples/run_thmm.py @@ -0,0 +1,18 @@ +import sys +import os +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from gen_surv import generate + +df = generate( + model="thmm", + n=100, + qmat=[[0, 0.2, 0], [0.1, 0, 0.1], [0, 0.3, 0]], + emission_pars={"mu": [0.0, 1.0, 2.0], "sigma": [0.5, 0.5, 0.5]}, + p0=[1.0, 0.0, 0.0], + model_cens="exponential", + cens_par=3.0, + seed=42 +) + +print(df.head()) diff --git a/stubs/gen_surv/__init__.pyi b/stubs/gen_surv/__init__.pyi new file mode 100644 index 0000000..1bd66c2 --- /dev/null +++ b/stubs/gen_surv/__init__.pyi @@ -0,0 +1 @@ +from .interface import generate as generate diff --git a/stubs/gen_surv/aft.pyi b/stubs/gen_surv/aft.pyi new file mode 100644 index 0000000..ca00051 --- /dev/null +++ b/stubs/gen_surv/aft.pyi @@ -0,0 +1,3 @@ +from _typeshed import Incomplete + +def gen_aft_log_normal(n, beta, sigma, model_cens, cens_par, seed: Incomplete | None = None): ... diff --git a/stubs/gen_surv/bivariate.pyi b/stubs/gen_surv/bivariate.pyi new file mode 100644 index 0000000..23596f2 --- /dev/null +++ b/stubs/gen_surv/bivariate.pyi @@ -0,0 +1 @@ +def sample_bivariate_distribution(n, dist, corr, dist_par): ... diff --git a/stubs/gen_surv/censoring.pyi b/stubs/gen_surv/censoring.pyi new file mode 100644 index 0000000..bb9ffd0 --- /dev/null +++ b/stubs/gen_surv/censoring.pyi @@ -0,0 +1,4 @@ +import numpy as np + +def runifcens(size: int, cens_par: float) -> np.ndarray: ... +def rexpocens(size: int, cens_par: float) -> np.ndarray: ... diff --git a/stubs/gen_surv/cmm.pyi b/stubs/gen_surv/cmm.pyi new file mode 100644 index 0000000..44a810e --- /dev/null +++ b/stubs/gen_surv/cmm.pyi @@ -0,0 +1,5 @@ +from gen_surv.censoring import rexpocens as rexpocens, runifcens as runifcens +from gen_surv.validate import validate_gen_cmm_inputs as validate_gen_cmm_inputs + +def generate_event_times(z1: float, beta: list, rate: list) -> dict: ... +def gen_cmm(n, model_cens, cens_par, beta, covar, rate): ... diff --git a/stubs/gen_surv/cphm.pyi b/stubs/gen_surv/cphm.pyi new file mode 100644 index 0000000..cac39b3 --- /dev/null +++ b/stubs/gen_surv/cphm.pyi @@ -0,0 +1,6 @@ +import pandas as pd +from gen_surv.censoring import rexpocens as rexpocens, runifcens as runifcens +from gen_surv.validate import validate_gen_cphm_inputs as validate_gen_cphm_inputs + +def generate_cphm_data(n, rfunc, cens_par, beta, covariate_range): ... +def gen_cphm(n: int, model_cens: str, cens_par: float, beta: float, covar: float) -> pd.DataFrame: ... diff --git a/stubs/gen_surv/interface.pyi b/stubs/gen_surv/interface.pyi new file mode 100644 index 0000000..a6a0a85 --- /dev/null +++ b/stubs/gen_surv/interface.pyi @@ -0,0 +1,7 @@ +from gen_surv.aft import gen_aft_log_normal as gen_aft_log_normal +from gen_surv.cmm import gen_cmm as gen_cmm +from gen_surv.cphm import gen_cphm as gen_cphm +from gen_surv.tdcm import gen_tdcm as gen_tdcm +from gen_surv.thmm import gen_thmm as gen_thmm + +def generate(model: str, **kwargs): ... diff --git a/stubs/gen_surv/tdcm.pyi b/stubs/gen_surv/tdcm.pyi new file mode 100644 index 0000000..15e7dd0 --- /dev/null +++ b/stubs/gen_surv/tdcm.pyi @@ -0,0 +1,6 @@ +from gen_surv.bivariate import sample_bivariate_distribution as sample_bivariate_distribution +from gen_surv.censoring import rexpocens as rexpocens, runifcens as runifcens +from gen_surv.validate import validate_gen_tdcm_inputs as validate_gen_tdcm_inputs + +def generate_censored_observations(n, dist_par, model_cens, cens_par, beta, lam, b): ... +def gen_tdcm(n, dist, corr, dist_par, model_cens, cens_par, beta, lam): ... diff --git a/stubs/gen_surv/thmm.pyi b/stubs/gen_surv/thmm.pyi new file mode 100644 index 0000000..915c59d --- /dev/null +++ b/stubs/gen_surv/thmm.pyi @@ -0,0 +1,5 @@ +from gen_surv.censoring import rexpocens as rexpocens, runifcens as runifcens +from gen_surv.validate import validate_gen_thmm_inputs as validate_gen_thmm_inputs + +def calculate_transitions(z1: float, cens_par: float, beta: list, rate: list, rfunc) -> dict: ... +def gen_thmm(n, model_cens, cens_par, beta, covar, rate): ... diff --git a/stubs/gen_surv/validate.pyi b/stubs/gen_surv/validate.pyi new file mode 100644 index 0000000..9ba0053 --- /dev/null +++ b/stubs/gen_surv/validate.pyi @@ -0,0 +1,6 @@ +def validate_gen_cphm_inputs(n: int, model_cens: str, cens_par: float, covar: float): ... +def validate_gen_cmm_inputs(n: int, model_cens: str, cens_par: float, beta: list, covar: float, rate: list): ... +def validate_gen_tdcm_inputs(n: int, dist: str, corr: float, dist_par: list, model_cens: str, cens_par: float, beta: list, lam: float): ... +def validate_gen_thmm_inputs(n: int, model_cens: str, cens_par: float, beta: list, covar: float, rate: list): ... +def validate_dg_biv_inputs(n: int, dist: str, corr: float, dist_par: list): ... +def validate_gen_aft_log_normal_inputs(n, beta, sigma, model_cens, cens_par) -> None: ... diff --git a/tasks.py b/tasks.py index f02b94e..a3d64b1 100644 --- a/tasks.py +++ b/tasks.py @@ -23,3 +23,26 @@ def publish(c): @task def clean(c): c.run("rm -rf dist build docs/build .pytest_cache .mypy_cache coverage.xml .coverage stubs") + +@task +def git_push(c): + """ + Stage all changes, prompt for a commit message, create a signed commit, and push. + """ + import getpass + + c.run("git add .") + + try: + # Prompt for a commit message + message = input("Enter commit message: ").strip() + if not message: + print("Aborting: empty commit message.") + return + + sanitized_message = shlex.quote(message) + c.run(f"git commit -S -m {sanitized_message}") + c.run("git push") + except KeyboardInterrupt: + print("\nAborted by user.") + diff --git a/tests/test_aft.py b/tests/test_aft.py new file mode 100644 index 0000000..2688144 --- /dev/null +++ b/tests/test_aft.py @@ -0,0 +1,19 @@ +import pandas as pd +from gen_surv.aft import gen_aft_log_normal + +def test_gen_aft_log_normal_runs(): + df = gen_aft_log_normal( + n=10, + beta=[0.5, -0.2], + sigma=1.0, + model_cens="uniform", + cens_par=5.0, + seed=42 + ) + assert isinstance(df, pd.DataFrame) + assert not df.empty + assert "time" in df.columns + assert "status" in df.columns + assert "X0" in df.columns + assert "X1" in df.columns + assert set(df["status"].unique()).issubset({0, 1}) \ No newline at end of file