diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..33df3b1 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,31 @@ +name: Run Tests + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.9" + + - name: Install Poetry + run: | + curl -sSL https://install.python-poetry.org | python3 - + echo "$HOME/.local/bin" >> $GITHUB_PATH + + - name: Install dependencies + run: poetry install + + - name: Run tests + run: poetry run pytest diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4518323 --- /dev/null +++ b/.gitignore @@ -0,0 +1,52 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# Poetry virtualenvs +.env +.venv +poetry.lock + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover + +# Pytest +.pytest_cache/ + +# Jupyter Notebook checkpoints +.ipynb_checkpoints + +# PyCharm +.idea/ + +# VSCode +.vscode/ + +# MacOS +.DS_Store + +# System files +Thumbs.db +ehthumbs.db + +# Build artifacts +build/ +dist/ +*.egg-info/ + +# Temporary +*.log +*.tmp diff --git a/LICENCE b/LICENCE new file mode 100644 index 0000000..f72d15e --- /dev/null +++ b/LICENCE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 [Diogo Ribeiro] + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index cf014df..4c4799f 100644 --- a/README.md +++ b/README.md @@ -1,17 +1,68 @@ -# private_repo_template +# gen_surv + +**gen_surv** is a Python package for simulating survival data under a variety of models, inspired by the R package [`genSurv`](https://cran.r-project.org/package=genSurv). It supports data generation for: + +- Cox Proportional Hazards Models (CPHM) +- Continuous-Time Markov Models (CMM) +- Time-Dependent Covariate Models (TDCM) +- Time-Homogeneous Hidden Markov Models (THMM) + +--- + +## ๐Ÿ“ฆ Installation + +```bash +poetry install +``` +## โœจ Features + +- Consistent interface across models +- Censoring support (`uniform` or `exponential`) +- Easy integration with `pandas` and `NumPy` +- Suitable for benchmarking survival algorithms and teaching + +## ๐Ÿงช Example + +```python +from gen_surv.cphm import gen_cphm + +df = gen_cphm( + n=100, + model_cens="uniform", + cens_par=1.0, + beta=0.5, + covar=2.0 +) +print(df.head()) +``` + +## ๐Ÿ”ง Available Generators + +| Function | Description | +|--------------|--------------------------------------------| +| `gen_cphm()` | Cox Proportional Hazards Model | +| `gen_cmm()` | Continuous-Time Multi-State Markov Model | +| `gen_tdcm()` | Time-Dependent Covariate Model | +| `gen_thmm()` | Time-Homogeneous Markov Model | + ```text +genSurvPy/ gen_surv/ -โ”œโ”€โ”€ gen_surv/ -โ”‚ โ”œโ”€โ”€ __init__.py -โ”‚ โ”œโ”€โ”€ cphm.py โ† put CPHM logic here -โ”‚ โ”œโ”€โ”€ validate.py โ† validation functions here -โ”‚ โ”œโ”€โ”€ censoring.py โ† censoring functions here -โ”‚ โ””โ”€โ”€ utils.py โ† for any shared tools -โ”œโ”€โ”€ tests/ -โ”‚ โ”œโ”€โ”€ __init__.py -โ”‚ โ””โ”€โ”€ test_cphm.py โ† tests for CPHM -โ”œโ”€โ”€ pyproject.toml -โ””โ”€โ”€ README.md โ† rename README.rst if you prefer -LICENSE -``` \ No newline at end of file +โ”œโ”€โ”€ cphm.py +โ”œโ”€โ”€ cmm.py +โ”œโ”€โ”€ tdcm.py +โ”œโ”€โ”€ thmm.py +โ”œโ”€โ”€ censoring.py +โ”œโ”€โ”€ validate.py +examples/ +โ”œโ”€โ”€ run_cphm.py +โ”œโ”€โ”€ run_cmm.py +โ”œโ”€โ”€ run_tdcm.py +โ”œโ”€โ”€ run_thmm.py +โ””โ”€โ”€ utils.py # optional for shared config (e.g. seeding) +``` + +## ๐Ÿง  License + +MIT License. See [LICENSE](LICENSE) for details. diff --git a/gen_surv-stubs/gen_surv/__init__.pyi b/gen_surv-stubs/gen_surv/__init__.pyi new file mode 100644 index 0000000..e69de29 diff --git a/gen_surv-stubs/gen_surv/censoring.pyi b/gen_surv-stubs/gen_surv/censoring.pyi new file mode 100644 index 0000000..bb9ffd0 --- /dev/null +++ b/gen_surv-stubs/gen_surv/censoring.pyi @@ -0,0 +1,4 @@ +import numpy as np + +def runifcens(size: int, cens_par: float) -> np.ndarray: ... +def rexpocens(size: int, cens_par: float) -> np.ndarray: ... diff --git a/gen_surv-stubs/gen_surv/cmm.pyi b/gen_surv-stubs/gen_surv/cmm.pyi new file mode 100644 index 0000000..44a810e --- /dev/null +++ b/gen_surv-stubs/gen_surv/cmm.pyi @@ -0,0 +1,5 @@ +from gen_surv.censoring import rexpocens as rexpocens, runifcens as runifcens +from gen_surv.validate import validate_gen_cmm_inputs as validate_gen_cmm_inputs + +def generate_event_times(z1: float, beta: list, rate: list) -> dict: ... +def gen_cmm(n, model_cens, cens_par, beta, covar, rate): ... diff --git a/gen_surv-stubs/gen_surv/cphm.pyi b/gen_surv-stubs/gen_surv/cphm.pyi new file mode 100644 index 0000000..cac39b3 --- /dev/null +++ b/gen_surv-stubs/gen_surv/cphm.pyi @@ -0,0 +1,6 @@ +import pandas as pd +from gen_surv.censoring import rexpocens as rexpocens, runifcens as runifcens +from gen_surv.validate import validate_gen_cphm_inputs as validate_gen_cphm_inputs + +def generate_cphm_data(n, rfunc, cens_par, beta, covariate_range): ... +def gen_cphm(n: int, model_cens: str, cens_par: float, beta: float, covar: float) -> pd.DataFrame: ... diff --git a/gen_surv-stubs/gen_surv/tdcm.pyi b/gen_surv-stubs/gen_surv/tdcm.pyi new file mode 100644 index 0000000..2d849c7 --- /dev/null +++ b/gen_surv-stubs/gen_surv/tdcm.pyi @@ -0,0 +1,5 @@ +from gen_surv.censoring import rexpocens as rexpocens, runifcens as runifcens +from gen_surv.validate import validate_gen_tdcm_inputs as validate_gen_tdcm_inputs + +def generate_censored_observations(n, dist_par, model_cens, cens_par, beta, lam, b): ... +def gen_tdcm(n, dist, corr, dist_par, model_cens, cens_par, beta, lam): ... diff --git a/gen_surv-stubs/gen_surv/thmm.pyi b/gen_surv-stubs/gen_surv/thmm.pyi new file mode 100644 index 0000000..915c59d --- /dev/null +++ b/gen_surv-stubs/gen_surv/thmm.pyi @@ -0,0 +1,5 @@ +from gen_surv.censoring import rexpocens as rexpocens, runifcens as runifcens +from gen_surv.validate import validate_gen_thmm_inputs as validate_gen_thmm_inputs + +def calculate_transitions(z1: float, cens_par: float, beta: list, rate: list, rfunc) -> dict: ... +def gen_thmm(n, model_cens, cens_par, beta, covar, rate): ... diff --git a/gen_surv-stubs/gen_surv/validate.pyi b/gen_surv-stubs/gen_surv/validate.pyi new file mode 100644 index 0000000..4d49ea9 --- /dev/null +++ b/gen_surv-stubs/gen_surv/validate.pyi @@ -0,0 +1,4 @@ +def validate_gen_cphm_inputs(n: int, model_cens: str, cens_par: float, covar: float): ... +def validate_gen_cmm_inputs(n: int, model_cens: str, cens_par: float, beta: list, covar: float, rate: list): ... +def validate_gen_tdcm_inputs(n: int, dist: str, corr: float, dist_par: list, model_cens: str, cens_par: float, beta: list, lam: float): ... +def validate_gen_thmm_inputs(n: int, model_cens: str, cens_par: float, beta: list, covar: float, rate: list): ... diff --git a/gen_surv/__main__.py b/gen_surv/__main__.py new file mode 100644 index 0000000..7d4c7b1 --- /dev/null +++ b/gen_surv/__main__.py @@ -0,0 +1,30 @@ +import argparse +import pandas as pd +from gen_surv.cphm import gen_cphm +from gen_surv.cmm import gen_cmm +from gen_surv.tdcm import gen_tdcm +from gen_surv.thmm import gen_thmm + +def run_example(model: str): + if model == "cphm": + df = gen_cphm(n=10, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0) + elif model == "cmm": + df = gen_cmm(n=10, model_cens="exponential", cens_par=1.0, + beta=[0.5, 0.2, -0.1], covar=2.0, rate=[0.1, 1.0, 0.2, 1.0, 0.3, 1.0]) + elif model == "tdcm": + df = gen_tdcm(n=10, dist="weibull", corr=0.5, dist_par=[1, 2, 1, 2], + model_cens="uniform", cens_par=0.5, beta=[0.1, 0.2, 0.3], lam=1.0) + elif model == "thmm": + df = gen_thmm(n=10, model_cens="uniform", cens_par=0.5, + beta=[0.1, 0.2, 0.3], covar=1.0, rate=[0.5, 0.6, 0.7]) + else: + raise ValueError(f"Unknown model: {model}") + + print(df) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run gen_surv model example.") + parser.add_argument("model", choices=["cphm", "cmm", "tdcm", "thmm"], + help="Model to run (cphm, cmm, tdcm, thmm)") + args = parser.parse_args() + run_example(args.model) diff --git a/gen_surv/bivariate.py b/gen_surv/bivariate.py new file mode 100644 index 0000000..070fbd0 --- /dev/null +++ b/gen_surv/bivariate.py @@ -0,0 +1,40 @@ +import numpy as np + +def sample_bivariate_distribution(n, dist, corr, dist_par): + """ + Generate samples from a bivariate distribution with specified correlation. + + Parameters: + - n (int): Number of samples + - dist (str): 'weibull' or 'exponential' + - corr (float): Correlation coefficient between [-1, 1] + - dist_par (list): Parameters for the marginals + + Returns: + - np.ndarray of shape (n, 2) + """ + if dist not in {"weibull", "exponential"}: + raise ValueError("Only 'weibull' and 'exponential' distributions are supported.") + + # Step 1: Generate correlated standard normals using Cholesky + mean = [0, 0] + cov = [[1, corr], [corr, 1]] + z = np.random.multivariate_normal(mean, cov, size=n) + u = 1 - np.exp(-0.5 * z**2) # transform normals to uniform via chi-squared approx + u = np.clip(u, 1e-10, 1 - 1e-10) # avoid infs in tails + + # Step 2: Transform to marginals + if dist == "exponential": + if len(dist_par) != 2: + raise ValueError("Exponential distribution requires 2 positive rate parameters.") + x1 = -np.log(1 - u[:, 0]) / dist_par[0] + x2 = -np.log(1 - u[:, 1]) / dist_par[1] + + elif dist == "weibull": + if len(dist_par) != 4: + raise ValueError("Weibull distribution requires 4 positive parameters [a1, b1, a2, b2].") + a1, b1, a2, b2 = dist_par + x1 = (-np.log(1 - u[:, 0]) / a1) ** (1 / b1) + x2 = (-np.log(1 - u[:, 1]) / a2) ** (1 / b2) + + return np.column_stack([x1, x2]) diff --git a/gen_surv/tdcm.py b/gen_surv/tdcm.py new file mode 100644 index 0000000..b349137 --- /dev/null +++ b/gen_surv/tdcm.py @@ -0,0 +1,76 @@ +import numpy as np +import pandas as pd +from gen_surv.validate import validate_gen_tdcm_inputs +from gen_surv.bivariate import sample_bivariate_distribution +from gen_surv.censoring import runifcens, rexpocens + +def generate_censored_observations(n, dist_par, model_cens, cens_par, beta, lam, b): + """ + Generate censored TDCM observations. + + Parameters: + - n (int): Number of individuals + - dist_par (list): Not directly used here (kept for API compatibility) + - model_cens (str): "uniform" or "exponential" + - cens_par (float): Parameter for the censoring model + - beta (list): Length-2 list of regression coefficients + - lam (float): Rate parameter + - b (np.ndarray): Covariate matrix with 2 columns [., z1] + + Returns: + - np.ndarray: Shape (n, 6) with columns: + [id, start, stop, status, covariate1 (z1), covariate2 (z2)] + """ + rfunc = runifcens if model_cens == "uniform" else rexpocens + observations = np.zeros((n, 6)) + + for k in range(n): + z1 = b[k, 1] + c = rfunc(1, cens_par)[0] + u = np.random.uniform() + + # Determine path based on u threshold + threshold = 1 - np.exp(-lam * b[k, 0] * np.exp(beta[0] * z1)) + if u < threshold: + t = -np.log(1 - u) / (lam * np.exp(beta[0] * z1)) + z2 = 0 + else: + t = ( + -np.log(1 - u) + + lam * b[k, 0] * np.exp(beta[0] * z1) * (1 - np.exp(beta[1])) + ) / (lam * np.exp(beta[0] * z1 + beta[1])) + z2 = 1 + + time = min(t, c) + status = int(t <= c) + + observations[k] = [k + 1, 0, time, status, z1, z2] + + return observations + + +def gen_tdcm(n, dist, corr, dist_par, model_cens, cens_par, beta, lam): + """ + Generate TDCM (Time-Dependent Covariate Model) survival data. + + Parameters: + - n (int): Number of individuals. + - dist (str): "weibull" or "exponential". + - corr (float): Correlation coefficient. + - dist_par (list): Distribution parameters. + - model_cens (str): "uniform" or "exponential". + - cens_par (float): Censoring parameter. + - beta (list): Length-2 regression coefficients. + - lam (float): Lambda rate parameter. + + Returns: + - pd.DataFrame: Columns are ["id", "start", "stop", "status", "covariate", "tdcov"] + """ + validate_gen_tdcm_inputs(n, dist, corr, dist_par, model_cens, cens_par, beta, lam) + + # Generate covariate matrix from bivariate distribution + b = sample_bivariate_distribution(n, dist, corr, dist_par) + + data = generate_censored_observations(n, dist_par, model_cens, cens_par, beta, lam, b) + + return pd.DataFrame(data, columns=["id", "start", "stop", "status", "covariate", "tdcov"]) diff --git a/gen_surv/thmm.py b/gen_surv/thmm.py new file mode 100644 index 0000000..22feebf --- /dev/null +++ b/gen_surv/thmm.py @@ -0,0 +1,66 @@ +import numpy as np +import pandas as pd +from gen_surv.validate import validate_gen_thmm_inputs +from gen_surv.censoring import runifcens, rexpocens + +def calculate_transitions(z1: float, cens_par: float, beta: list, rate: list, rfunc) -> dict: + """ + Calculate transition and censoring times for THMM. + + Parameters: + - z1 (float): Covariate value. + - cens_par (float): Censoring parameter. + - beta (list of float): Coefficients for rate modification (length 3). + - rate (list of float): Base rates (length 3). + - rfunc (callable): Censoring function, e.g. runifcens or rexpocens. + + Returns: + - dict with keys 'c', 't12', 't13', 't23' + """ + c = rfunc(1, cens_par)[0] + rate12 = rate[0] * np.exp(beta[0] * z1) + rate13 = rate[1] * np.exp(beta[1] * z1) + rate23 = rate[2] * np.exp(beta[2] * z1) + + t12 = np.random.exponential(scale=1 / rate12) + t13 = np.random.exponential(scale=1 / rate13) + t23 = np.random.exponential(scale=1 / rate23) + + return {"c": c, "t12": t12, "t13": t13, "t23": t23} + + +def gen_thmm(n, model_cens, cens_par, beta, covar, rate): + """ + Generate THMM (Time-Homogeneous Markov Model) survival data. + + Parameters: + - n (int): Number of individuals. + - model_cens (str): "uniform" or "exponential". + - cens_par (float): Censoring parameter. + - beta (list): Length-3 regression coefficients. + - covar (float): Covariate upper bound. + - rate (list): Length-3 transition rates. + + Returns: + - pd.DataFrame: Columns = ["id", "time", "state", "covariate"] + """ + validate_gen_thmm_inputs(n, model_cens, cens_par, beta, covar, rate) + rfunc = runifcens if model_cens == "uniform" else rexpocens + records = [] + + for k in range(n): + z1 = np.random.uniform(0, covar) + trans = calculate_transitions(z1, cens_par, beta, rate, rfunc) + t12, t13, c = trans["t12"], trans["t13"], trans["c"] + + if min(t12, t13) < c: + if t12 <= t13: + time, state = t12, 2 + else: + time, state = t13, 3 + else: + time, state = c, 1 # censored + + records.append([k + 1, time, state, z1]) + + return pd.DataFrame(records, columns=["id", "time", "state", "covariate"]) diff --git a/gen_surv/validate.py b/gen_surv/validate.py index 57a9b8f..e050e0b 100644 --- a/gen_surv/validate.py +++ b/gen_surv/validate.py @@ -14,12 +14,14 @@ def validate_gen_cphm_inputs(n: int, model_cens: str, cens_par: float, covar: fl if n <= 0: raise ValueError("Argument 'n' must be greater than 0") if model_cens not in {"uniform", "exponential"}: - raise ValueError("Argument 'model_cens' must be one of 'uniform' or 'exponential'") + raise ValueError( + "Argument 'model_cens' must be one of 'uniform' or 'exponential'") if cens_par <= 0: raise ValueError("Argument 'cens_par' must be greater than 0") if covar <= 0: raise ValueError("Argument 'covar' must be greater than 0") + def validate_gen_cmm_inputs(n: int, model_cens: str, cens_par: float, beta: list, covar: float, rate: list): """ Validate inputs for generating CMM (Continuous-Time Markov Model) data. @@ -38,7 +40,8 @@ def validate_gen_cmm_inputs(n: int, model_cens: str, cens_par: float, beta: list if n <= 0: raise ValueError("Argument 'n' must be greater than 0") if model_cens not in {"uniform", "exponential"}: - raise ValueError("Argument 'model_cens' must be one of 'uniform' or 'exponential'") + raise ValueError( + "Argument 'model_cens' must be one of 'uniform' or 'exponential'") if cens_par <= 0: raise ValueError("Argument 'cens_par' must be greater than 0") if len(beta) != 3: @@ -47,3 +50,127 @@ def validate_gen_cmm_inputs(n: int, model_cens: str, cens_par: float, beta: list raise ValueError("Argument 'covar' must be greater than 0") if len(rate) != 6: raise ValueError("Argument 'rate' must be a list of length 6") + + +def validate_gen_tdcm_inputs(n: int, dist: str, corr: float, dist_par: list, + model_cens: str, cens_par: float, beta: list, lam: float): + """ + Validate inputs for generating TDCM (Time-Dependent Covariate Model) data. + + Parameters: + - n (int): Number of observations. + - dist (str): "weibull" or "exponential". + - corr (float): Correlation coefficient. + - dist_par (list): Distribution parameters. + - model_cens (str): "uniform" or "exponential". + - cens_par (float): Censoring parameter. + - beta (list): Length-2 list of regression coefficients. + - lam (float): Lambda parameter, must be > 0. + + Raises: + - ValueError: For any invalid input. + """ + if n <= 0: + raise ValueError("Argument 'n' must be greater than 0") + + if dist not in {"weibull", "exponential"}: + raise ValueError( + "Argument 'dist' must be one of 'weibull' or 'exponential'") + + if dist == "weibull": + if not (0 < corr <= 1): + raise ValueError("With dist='weibull', 'corr' must be in (0,1]") + if len(dist_par) != 4 or any(p <= 0 for p in dist_par): + raise ValueError( + "With dist='weibull', 'dist_par' must be a positive list of length 4") + + if dist == "exponential": + if not (-1 <= corr <= 1): + raise ValueError( + "With dist='exponential', 'corr' must be in [-1,1]") + if len(dist_par) != 2 or any(p <= 0 for p in dist_par): + raise ValueError( + "With dist='exponential', 'dist_par' must be a positive list of length 2") + + if model_cens not in {"uniform", "exponential"}: + raise ValueError( + "Argument 'model_cens' must be one of 'uniform' or 'exponential'") + + if cens_par <= 0: + raise ValueError("Argument 'cens_par' must be greater than 0") + + if not isinstance(beta, list) or len(beta) != 3: + raise ValueError("Argument 'beta' must be a list of length 3") + + if lam <= 0: + raise ValueError("Argument 'lambda' must be greater than 0") + + +def validate_gen_thmm_inputs(n: int, model_cens: str, cens_par: float, beta: list, covar: float, rate: list): + """ + Validate inputs for generating THMM (Time-Homogeneous Markov Model) data. + + Parameters: + - n (int): Number of samples, must be > 0. + - model_cens (str): Must be "uniform" or "exponential". + - cens_par (float): Must be > 0. + - beta (list): List of length 3 (regression coefficients). + - covar (float): Positive covariate value. + - rate (list): List of length 3 (transition rates). + + Raises: + - ValueError if any input is invalid. + """ + if not isinstance(n, int) or n <= 0: + raise ValueError("Argument 'n' must be a positive integer.") + + if model_cens not in {"uniform", "exponential"}: + raise ValueError( + "Argument 'model_cens' must be one of 'uniform' or 'exponential'") + + if not isinstance(cens_par, (int, float)) or cens_par <= 0: + raise ValueError("Argument 'cens_par' must be a positive number.") + + if not isinstance(beta, list) or len(beta) != 3: + raise ValueError("Argument 'beta' must be a list of length 3.") + + if not isinstance(covar, (int, float)) or covar <= 0: + raise ValueError("Argument 'covar' must be greater than 0.") + + if not isinstance(rate, list) or len(rate) != 3: + raise ValueError("Argument 'rate' must be a list of length 3.") + + +def validate_dg_biv_inputs(n: int, dist: str, corr: float, dist_par: list): + """ + Validate inputs for the sample_bivariate_distribution function. + + Parameters: + - n (int): Number of samples to generate. + - dist (str): Must be "weibull" or "exponential". + - corr (float): Must be between -1 and 1. + - dist_par (list): Must contain positive values, and correct length for the distribution. + + Raises: + - ValueError if any input is invalid. + """ + if not isinstance(n, int) or n <= 0: + raise ValueError("Argument 'n' must be a positive integer.") + + if dist not in {"weibull", "exponential"}: + raise ValueError("Argument 'dist' must be one of 'weibull' or 'exponential'.") + + if not isinstance(corr, (int, float)) or not (-1 < corr < 1): + raise ValueError("Argument 'corr' must be a numeric value between -1 and 1.") + + if not isinstance(dist_par, list) or len(dist_par) == 0: + raise ValueError("Argument 'dist_par' must be a non-empty list of positive values.") + + if any(p <= 0 for p in dist_par): + raise ValueError("All elements in 'dist_par' must be greater than 0.") + + if dist == "exponential" and len(dist_par) != 2: + raise ValueError("Exponential distribution requires exactly 2 positive parameters.") + + if dist == "weibull" and len(dist_par) != 4: + raise ValueError("Weibull distribution requires exactly 4 positive parameters.") diff --git a/pyproject.toml b/pyproject.toml index 929eece..9e0a8b8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,9 +10,11 @@ packages = [{ include = "gen_surv" }] [tool.poetry.dependencies] python = "^3.9" numpy = "^1.26" +pandas = "^2.2.3" -[tool.poetry.dev-dependencies] -pytest = "^8.0" + +[tool.poetry.group.dev.dependencies] +pytest = "^8.3.5" [tool.semantic_release] version_source = "tag" diff --git a/tests/test_cmm.py b/tests/test_cmm.py new file mode 100644 index 0000000..48f2467 --- /dev/null +++ b/tests/test_cmm.py @@ -0,0 +1,10 @@ +import sys +import os +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +from gen_surv.cmm import gen_cmm + +def test_gen_cmm_shape(): + df = gen_cmm(n=50, model_cens="uniform", cens_par=1.0, beta=[0.1, 0.2, 0.3], + covar=2.0, rate=[0.1, 1.0, 0.2, 1.0, 0.3, 1.0]) + assert df.shape[1] == 6 + assert "transition" in df.columns diff --git a/tests/test_cphm.py b/tests/test_cphm.py new file mode 100644 index 0000000..05cc652 --- /dev/null +++ b/tests/test_cphm.py @@ -0,0 +1,13 @@ +import sys +import os +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +from gen_surv.cphm import gen_cphm + +def test_gen_cphm_output_shape(): + df = gen_cphm(n=50, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0) + assert df.shape == (50, 3) + assert list(df.columns) == ["time", "status", "covariate"] + +def test_gen_cphm_status_range(): + df = gen_cphm(n=100, model_cens="exponential", cens_par=0.8, beta=0.3, covar=1.5) + assert df["status"].isin([0, 1]).all() diff --git a/tests/test_tdcm.py b/tests/test_tdcm.py new file mode 100644 index 0000000..507b51f --- /dev/null +++ b/tests/test_tdcm.py @@ -0,0 +1,10 @@ +import sys +import os +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +from gen_surv.tdcm import gen_tdcm + +def test_gen_tdcm_shape(): + df = gen_tdcm(n=50, dist="weibull", corr=0.5, dist_par=[1, 2, 1, 2], + model_cens="uniform", cens_par=1.0, beta=[0.1, 0.2, 0.3], lam=1.0) + assert df.shape[1] == 6 + assert "tdcov" in df.columns diff --git a/tests/test_thmm.py b/tests/test_thmm.py new file mode 100644 index 0000000..b53b197 --- /dev/null +++ b/tests/test_thmm.py @@ -0,0 +1,11 @@ +import sys +import os +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +from gen_surv.thmm import gen_thmm + +def test_gen_thmm_shape(): + df = gen_thmm(n=50, model_cens="uniform", cens_par=1.0, + beta=[0.1, 0.2, 0.3], covar=2.0, rate=[0.5, 0.6, 0.7]) + assert df.shape[1] == 4 + assert set(df["state"].unique()).issubset({1, 2, 3})