-
Notifications
You must be signed in to change notification settings - Fork 0
feat: implement THMM data generator and finalize full model suite #4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 1 commit
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,52 @@ | ||
| # Byte-compiled / optimized / DLL files | ||
| __pycache__/ | ||
| *.py[cod] | ||
| *$py.class | ||
|
|
||
| # Poetry virtualenvs | ||
| .env | ||
| .venv | ||
| poetry.lock | ||
|
|
||
| # Installer logs | ||
| pip-log.txt | ||
| pip-delete-this-directory.txt | ||
|
|
||
| # Unit test / coverage reports | ||
| htmlcov/ | ||
| .tox/ | ||
| .nox/ | ||
| .coverage | ||
| .cache | ||
| nosetests.xml | ||
| coverage.xml | ||
| *.cover | ||
| *.py,cover | ||
|
|
||
| # Pytest | ||
| .pytest_cache/ | ||
|
|
||
| # Jupyter Notebook checkpoints | ||
| .ipynb_checkpoints | ||
|
|
||
| # PyCharm | ||
| .idea/ | ||
|
|
||
| # VSCode | ||
| .vscode/ | ||
|
|
||
| # MacOS | ||
| .DS_Store | ||
|
|
||
| # System files | ||
| Thumbs.db | ||
| ehthumbs.db | ||
|
|
||
| # Build artifacts | ||
| build/ | ||
| dist/ | ||
| *.egg-info/ | ||
|
|
||
| # Temporary | ||
| *.log | ||
| *.tmp |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| MIT License | ||
|
|
||
| Copyright (c) 2025 [Diogo Ribeiro] | ||
|
|
||
| Permission is hereby granted, free of charge, to any person obtaining a copy | ||
| of this software and associated documentation files (the "Software"), to deal | ||
| in the Software without restriction, including without limitation the rights | ||
| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
| copies of the Software, and to permit persons to whom the Software is | ||
| furnished to do so, subject to the following conditions: | ||
|
|
||
| The above copyright notice and this permission notice shall be included in all | ||
| copies or substantial portions of the Software. | ||
|
|
||
| THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
| IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
| FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
| AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
| LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
| OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
| SOFTWARE. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,17 +1,62 @@ | ||
| # private_repo_template | ||
| # gen_surv | ||
|
|
||
| **gen_surv** is a Python package for simulating survival data under a variety of models, inspired by the R package [`genSurv`](https://cran.r-project.org/package=genSurv). It supports data generation for: | ||
|
|
||
| - Cox Proportional Hazards Models (CPHM) | ||
| - Continuous-Time Markov Models (CMM) | ||
| - Time-Dependent Covariate Models (TDCM) | ||
| - Time-Homogeneous Hidden Markov Models (THMM) | ||
|
|
||
| --- | ||
|
|
||
| ## 📦 Installation | ||
|
|
||
| ```bash | ||
| poetry install | ||
| ``` | ||
| ## ✨ Features | ||
|
|
||
| - Consistent interface across models | ||
| - Censoring support (`uniform` or `exponential`) | ||
| - Easy integration with `pandas` and `NumPy` | ||
| - Suitable for benchmarking survival algorithms and teaching | ||
|
|
||
| ## 🧪 Example | ||
|
|
||
| ```python | ||
| from gen_surv.cphm import gen_cphm | ||
|
|
||
| df = gen_cphm( | ||
| n=100, | ||
| model_cens="uniform", | ||
| cens_par=1.0, | ||
| beta=0.5, | ||
| covar=2.0 | ||
| ) | ||
| print(df.head()) | ||
| ``` | ||
|
|
||
| ## 🔧 Available Generators | ||
|
|
||
| | Function | Description | | ||
| |--------------|--------------------------------------------| | ||
| | `gen_cphm()` | Cox Proportional Hazards Model | | ||
| | `gen_cmm()` | Continuous-Time Multi-State Markov Model | | ||
| | `gen_tdcm()` | Time-Dependent Covariate Model | | ||
| | `gen_thmm()` | Time-Homogeneous Markov Model | | ||
|
|
||
|
|
||
| ```text | ||
| genSurvPy/ | ||
| gen_surv/ | ||
| ├── gen_surv/ | ||
| │ ├── __init__.py | ||
| │ ├── cphm.py ← put CPHM logic here | ||
| │ ├── validate.py ← validation functions here | ||
| │ ├── censoring.py ← censoring functions here | ||
| │ └── utils.py ← for any shared tools | ||
| ├── tests/ | ||
| │ ├── __init__.py | ||
| │ └── test_cphm.py ← tests for CPHM | ||
| ├── pyproject.toml | ||
| └── README.md ← rename README.rst if you prefer | ||
| LICENSE | ||
| ``` | ||
| ├── cphm.py | ||
| ├── cmm.py | ||
| ├── tdcm.py | ||
| ├── thmm.py | ||
| ├── censoring.py | ||
| ├── validate.py | ||
| ``` | ||
|
|
||
| ## 🧠 License | ||
|
|
||
| MIT License. See [LICENSE](LICENSE) for details. |
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| import numpy as np | ||
|
|
||
| def runifcens(size: int, cens_par: float) -> np.ndarray: ... | ||
| def rexpocens(size: int, cens_par: float) -> np.ndarray: ... |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| from gen_surv.censoring import rexpocens as rexpocens, runifcens as runifcens | ||
| from gen_surv.validate import validate_gen_cmm_inputs as validate_gen_cmm_inputs | ||
|
|
||
| def generate_event_times(z1: float, beta: list, rate: list) -> dict: ... | ||
| def gen_cmm(n, model_cens, cens_par, beta, covar, rate): ... |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
| import pandas as pd | ||
| from gen_surv.censoring import rexpocens as rexpocens, runifcens as runifcens | ||
| from gen_surv.validate import validate_gen_cphm_inputs as validate_gen_cphm_inputs | ||
|
|
||
| def generate_cphm_data(n, rfunc, cens_par, beta, covariate_range): ... | ||
| def gen_cphm(n: int, model_cens: str, cens_par: float, beta: float, covar: float) -> pd.DataFrame: ... |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| from gen_surv.censoring import rexpocens as rexpocens, runifcens as runifcens | ||
| from gen_surv.validate import validate_gen_tdcm_inputs as validate_gen_tdcm_inputs | ||
|
|
||
| def generate_censored_observations(n, dist_par, model_cens, cens_par, beta, lam, b): ... | ||
| def gen_tdcm(n, dist, corr, dist_par, model_cens, cens_par, beta, lam): ... |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| from gen_surv.censoring import rexpocens as rexpocens, runifcens as runifcens | ||
| from gen_surv.validate import validate_gen_thmm_inputs as validate_gen_thmm_inputs | ||
|
|
||
| def calculate_transitions(z1: float, cens_par: float, beta: list, rate: list, rfunc) -> dict: ... | ||
| def gen_thmm(n, model_cens, cens_par, beta, covar, rate): ... |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| def validate_gen_cphm_inputs(n: int, model_cens: str, cens_par: float, covar: float): ... | ||
| def validate_gen_cmm_inputs(n: int, model_cens: str, cens_par: float, beta: list, covar: float, rate: list): ... | ||
| def validate_gen_tdcm_inputs(n: int, dist: str, corr: float, dist_par: list, model_cens: str, cens_par: float, beta: list, lam: float): ... | ||
| def validate_gen_thmm_inputs(n: int, model_cens: str, cens_par: float, beta: list, covar: float, rate: list): ... |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,77 @@ | ||
| import numpy as np | ||
| import pandas as pd | ||
| from gen_surv.validate import validate_gen_tdcm_inputs | ||
| from gen_surv.censoring import runifcens, rexpocens | ||
|
|
||
| def generate_censored_observations(n, dist_par, model_cens, cens_par, beta, lam, b): | ||
| """ | ||
| Generate censored TDCM observations. | ||
|
|
||
| Parameters: | ||
| - n (int): Number of individuals | ||
| - dist_par (list): Not directly used here (kept for API compatibility) | ||
| - model_cens (str): "uniform" or "exponential" | ||
| - cens_par (float): Parameter for the censoring model | ||
| - beta (list): Length-2 list of regression coefficients | ||
| - lam (float): Rate parameter | ||
| - b (np.ndarray): Covariate matrix with 2 columns [., z1] | ||
|
|
||
| Returns: | ||
| - np.ndarray: Shape (n, 6) with columns: | ||
| [id, start, stop, status, covariate1 (z1), covariate2 (z2)] | ||
| """ | ||
| rfunc = runifcens if model_cens == "uniform" else rexpocens | ||
| observations = np.zeros((n, 6)) | ||
|
|
||
| for k in range(n): | ||
| z1 = b[k, 1] | ||
| c = rfunc(1, cens_par)[0] | ||
| u = np.random.uniform() | ||
|
|
||
| # Determine path based on u threshold | ||
| threshold = 1 - np.exp(-lam * b[k, 0] * np.exp(beta[0] * z1)) | ||
| if u < threshold: | ||
| t = -np.log(1 - u) / (lam * np.exp(beta[0] * z1)) | ||
| z2 = 0 | ||
| else: | ||
| t = ( | ||
| -np.log(1 - u) | ||
| + lam * b[k, 0] * np.exp(beta[0] * z1) * (1 - np.exp(beta[2])) | ||
| ) / (lam * np.exp(beta[0] * z1 + beta[1])) | ||
| z2 = 1 | ||
|
|
||
| time = min(t, c) | ||
| status = int(t <= c) | ||
|
|
||
| observations[k] = [k + 1, 0, time, status, z1, z2] | ||
|
|
||
| return observations | ||
|
|
||
|
|
||
| def gen_tdcm(n, dist, corr, dist_par, model_cens, cens_par, beta, lam): | ||
| """ | ||
| Generate TDCM (Time-Dependent Covariate Model) survival data. | ||
|
|
||
| Parameters: | ||
| - n (int): Number of individuals. | ||
| - dist (str): "weibull" or "exponential". | ||
| - corr (float): Correlation coefficient. | ||
| - dist_par (list): Distribution parameters. | ||
| - model_cens (str): "uniform" or "exponential". | ||
| - cens_par (float): Censoring parameter. | ||
| - beta (list): Length-2 regression coefficients. | ||
| - lam (float): Lambda rate parameter. | ||
|
|
||
| Returns: | ||
| - pd.DataFrame: Columns are ["id", "start", "stop", "status", "covariate", "tdcov"] | ||
| """ | ||
| validate_gen_tdcm_inputs(n, dist, corr, dist_par, model_cens, cens_par, beta, lam) | ||
|
|
||
| # Generate covariate matrix with correlation structure | ||
| mean = [0, 0] | ||
| cov_matrix = [[1, corr], [corr, 1]] | ||
| b = np.random.multivariate_normal(mean, cov_matrix, size=n) | ||
|
|
||
| data = generate_censored_observations(n, dist_par, model_cens, cens_par, beta, lam, b) | ||
|
|
||
| return pd.DataFrame(data, columns=["id", "start", "stop", "status", "covariate", "tdcov"]) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,66 @@ | ||
| import numpy as np | ||
| import pandas as pd | ||
| from gen_surv.validate import validate_gen_thmm_inputs | ||
| from gen_surv.censoring import runifcens, rexpocens | ||
|
|
||
| def calculate_transitions(z1: float, cens_par: float, beta: list, rate: list, rfunc) -> dict: | ||
| """ | ||
| Calculate transition and censoring times for THMM. | ||
|
|
||
| Parameters: | ||
| - z1 (float): Covariate value. | ||
| - cens_par (float): Censoring parameter. | ||
| - beta (list of float): Coefficients for rate modification (length 3). | ||
| - rate (list of float): Base rates (length 3). | ||
| - rfunc (callable): Censoring function, e.g. runifcens or rexpocens. | ||
|
|
||
| Returns: | ||
| - dict with keys 'c', 't12', 't13', 't23' | ||
| """ | ||
| c = rfunc(1, cens_par)[0] | ||
| rate12 = rate[0] * np.exp(beta[0] * z1) | ||
| rate13 = rate[1] * np.exp(beta[1] * z1) | ||
| rate23 = rate[2] * np.exp(beta[2] * z1) | ||
|
|
||
| t12 = np.random.exponential(scale=1 / rate12) | ||
| t13 = np.random.exponential(scale=1 / rate13) | ||
| t23 = np.random.exponential(scale=1 / rate23) | ||
|
|
||
| return {"c": c, "t12": t12, "t13": t13, "t23": t23} | ||
|
|
||
|
|
||
| def gen_thmm(n, model_cens, cens_par, beta, covar, rate): | ||
| """ | ||
| Generate THMM (Time-Homogeneous Markov Model) survival data. | ||
|
|
||
| Parameters: | ||
| - n (int): Number of individuals. | ||
| - model_cens (str): "uniform" or "exponential". | ||
| - cens_par (float): Censoring parameter. | ||
| - beta (list): Length-3 regression coefficients. | ||
| - covar (float): Covariate upper bound. | ||
| - rate (list): Length-3 transition rates. | ||
|
|
||
| Returns: | ||
| - pd.DataFrame: Columns = ["id", "time", "state", "covariate"] | ||
| """ | ||
| validate_gen_thmm_inputs(n, model_cens, cens_par, beta, covar, rate) | ||
| rfunc = runifcens if model_cens == "uniform" else rexpocens | ||
| records = [] | ||
|
|
||
| for k in range(n): | ||
| z1 = np.random.uniform(0, covar) | ||
| trans = calculate_transitions(z1, cens_par, beta, rate, rfunc) | ||
| t12, t13, c = trans["t12"], trans["t13"], trans["c"] | ||
|
|
||
| if min(t12, t13) < c: | ||
| if t12 <= t13: | ||
| time, state = t12, 2 | ||
| else: | ||
| time, state = t13, 3 | ||
| else: | ||
| time, state = c, 1 # censored | ||
|
|
||
| records.append([k + 1, time, state, z1]) | ||
|
|
||
| return pd.DataFrame(records, columns=["id", "time", "state", "covariate"]) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.