diff --git a/CITATION.cff b/CITATION.cff new file mode 100644 index 0000000..479d864 --- /dev/null +++ b/CITATION.cff @@ -0,0 +1,18 @@ +cff-version: 1.2.0 +message: "If you use this software, please cite it using the metadata below." + +# Basic information +preferred-citation: + type: software + title: "gen_surv" + version: "1.0.0" + url: "https://github.com/DiogoRibeiro7/genSurvPy" + authors: + - family-names: Ribeiro + given-names: Diogo + orcid: "https://orcid.org/0009-0001-2022-7072" + affiliation: "ESMAD - Instituto PolitΓ©cnico do Porto" + email: "dfr@esmad.ipp.pt" + license: "MIT" + date-released: "2024-01-01" + diff --git a/README.md b/README.md index 2fc5bef..b4f34af 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,7 @@ poetry install - Easy integration with `pandas` and `NumPy` - Suitable for benchmarking survival algorithms and teaching - Accelerated Failure Time (Log-Normal) model generator +- Command-line interface powered by `Typer` ## πŸ§ͺ Example @@ -55,6 +56,15 @@ generate(model="thmm", n=100, qmat=[[0, 0.2, 0], [0.1, 0, 0.1], [0, 0.3, 0]], p0=[1.0, 0.0, 0.0], model_cens="exponential", cens_par=3.0) ``` +## ⌨️ Command-Line Usage + +Install the package and use ``python -m gen_surv`` to generate datasets without +writing Python code: + +```bash +python -m gen_surv dataset aft_ln --n 100 > data.csv +``` + ## πŸ”§ Available Generators | Function | Description | @@ -115,3 +125,9 @@ expectations for participants in this project. ## 🀝 Contributing Please read [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines on setting up your environment, running tests, and submitting pull requests. + +## πŸ“‘ Citation + +If you use **gen_surv** in your work, please cite it using the metadata in +[`CITATION.cff`](CITATION.cff). Many reference managers can import this file +directly. diff --git a/TODO.md b/TODO.md index 21ebe2e..91b4fe4 100644 --- a/TODO.md +++ b/TODO.md @@ -4,11 +4,22 @@ This document outlines future enhancements, features, and ideas for improving th --- +## ✨ Priority Items + +- [βœ…] Add property-based tests using Hypothesis to cover edge cases +- [βœ…] Build a CLI for generating datasets from the terminal +- [ ] Expand documentation with multilingual support and more usage examples +- [ ] Implement Weibull and log-logistic AFT models and add visualization utilities +- [βœ…] Provide CITATION metadata for proper referencing +- [ ] Ensure all functions include Google-style docstrings with inline comments + +--- + ## πŸ“¦ 1. Interface and UX - [βœ…] Create a `generate(..., return_type="df" | "dict")` interface -- [ ] Add `__version__` using `importlib.metadata` or `poetry-dynamic-versioning` -- [ ] Build a CLI with `typer` or `click` +- [βœ…] Add `__version__` using `importlib.metadata` or `poetry-dynamic-versioning` +- [βœ…] Build a CLI with `typer` or `click` - [βœ…] Add example notebooks or scripts for each model (`examples/` folder) --- @@ -25,7 +36,7 @@ This document outlines future enhancements, features, and ideas for improving th ## πŸ§ͺ 3. Testing and Quality - [βœ…] Add tests for each model (e.g., `test_tdcm.py`, `test_thmm.py`, `test_aft.py`) -- [ ] Add property-based tests with `hypothesis` +- [βœ…] Add property-based tests with `hypothesis` - [ ] Cover edge cases (e.g., invalid parameters, n=0, negative censoring) - [ ] Run tests on multiple Python versions (CI matrix) diff --git a/docs/source/index.md b/docs/source/index.md index 69011e3..83e034d 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -49,6 +49,14 @@ generate(model="thmm", n=100, qmat=[[0, 0.2, 0], [0.1, 0, 0.1], [0, 0.3, 0]], p0=[1.0, 0.0, 0.0], model_cens="exponential", cens_par=3.0) ``` +## ⌨️ Command-Line Usage + +Generate datasets directly from the terminal: + +```bash +python -m gen_surv dataset aft_ln --n 100 > data.csv +``` + ## πŸ”— Project Links - [Source Code](https://github.com/DiogoRibeiro7/genSurvPy) diff --git a/gen_surv/__init__.py b/gen_surv/__init__.py index e27e4b9..8939886 100644 --- a/gen_surv/__init__.py +++ b/gen_surv/__init__.py @@ -1 +1,17 @@ -from .interface import generate \ No newline at end of file +"""Top-level package for ``gen_surv``. + +This module exposes the :func:`generate` function and provides access to the +package version via ``__version__``. +""" + +from importlib.metadata import PackageNotFoundError, version + +from .interface import generate + +try: + __version__ = version("gen_surv") +except PackageNotFoundError: # pragma: no cover - fallback when package not installed + __version__ = "0.0.0" + +__all__ = ["generate", "__version__"] + diff --git a/gen_surv/__main__.py b/gen_surv/__main__.py index 7d4c7b1..2bef122 100644 --- a/gen_surv/__main__.py +++ b/gen_surv/__main__.py @@ -1,30 +1,4 @@ -import argparse -import pandas as pd -from gen_surv.cphm import gen_cphm -from gen_surv.cmm import gen_cmm -from gen_surv.tdcm import gen_tdcm -from gen_surv.thmm import gen_thmm - -def run_example(model: str): - if model == "cphm": - df = gen_cphm(n=10, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0) - elif model == "cmm": - df = gen_cmm(n=10, model_cens="exponential", cens_par=1.0, - beta=[0.5, 0.2, -0.1], covar=2.0, rate=[0.1, 1.0, 0.2, 1.0, 0.3, 1.0]) - elif model == "tdcm": - df = gen_tdcm(n=10, dist="weibull", corr=0.5, dist_par=[1, 2, 1, 2], - model_cens="uniform", cens_par=0.5, beta=[0.1, 0.2, 0.3], lam=1.0) - elif model == "thmm": - df = gen_thmm(n=10, model_cens="uniform", cens_par=0.5, - beta=[0.1, 0.2, 0.3], covar=1.0, rate=[0.5, 0.6, 0.7]) - else: - raise ValueError(f"Unknown model: {model}") - - print(df) +from gen_surv.cli import app if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Run gen_surv model example.") - parser.add_argument("model", choices=["cphm", "cmm", "tdcm", "thmm"], - help="Model to run (cphm, cmm, tdcm, thmm)") - args = parser.parse_args() - run_example(args.model) + app() diff --git a/gen_surv/cli.py b/gen_surv/cli.py new file mode 100644 index 0000000..542ea51 --- /dev/null +++ b/gen_surv/cli.py @@ -0,0 +1,36 @@ +import csv +from typing import Optional +import typer +from gen_surv.interface import generate + +app = typer.Typer(help="Generate synthetic survival datasets.") + +@app.command() +def dataset( + model: str = typer.Argument( + ..., help="Model to simulate [cphm, cmm, tdcm, thmm, aft_ln]" + ), + n: int = typer.Option(100, help="Number of samples"), + output: Optional[str] = typer.Option( + None, "-o", help="Output CSV file. Prints to stdout if omitted." + ), +) -> None: + """Generate survival data and optionally save to CSV. + + Args: + model: Identifier of the generator to use. + n: Number of samples to create. + output: Optional path to save the CSV file. + + Returns: + None + """ + df = generate(model=model, n=n) + if output: + df.to_csv(output, index=False) + typer.echo(f"Saved dataset to {output}") + else: + typer.echo(df.to_csv(index=False)) + +if __name__ == "__main__": + app() diff --git a/gen_surv/interface.py b/gen_surv/interface.py index 65e53fd..549ad45 100644 --- a/gen_surv/interface.py +++ b/gen_surv/interface.py @@ -6,6 +6,9 @@ >>> df = generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0) """ +from typing import Any +import pandas as pd + from gen_surv.cphm import gen_cphm from gen_surv.cmm import gen_cmm from gen_surv.tdcm import gen_tdcm @@ -22,13 +25,13 @@ } -def generate(model: str, **kwargs): - """ - Generic interface to generate survival data from various models. +def generate(model: str, **kwargs: Any) -> pd.DataFrame: + """Generate survival data from a specific model. - Parameters: - model (str): One of ["cphm", "cmm", "tdcm", "thmm"] - **kwargs: Arguments forwarded to the selected model generator. + Args: + model: Name of the generator to run. Must be one of ``cphm``, ``cmm``, + ``tdcm``, ``thmm`` or ``aft_ln``. + **kwargs: Arguments forwarded to the chosen generator. Returns: pd.DataFrame: Simulated survival data. diff --git a/pyproject.toml b/pyproject.toml index 9b648cf..8f9b94b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "gen_surv" -version = "1.0.1" +version = "1.0.0" description = "A Python package for simulating survival data, inspired by the R package genSurv" authors = ["Diogo Ribeiro "] license = "MIT" @@ -16,12 +16,14 @@ numpy = "^1.26" pandas = "^2.2.3" pytest-cov = "^6.1.1" invoke = "^2.2.0" +typer = "^0.12.3" [tool.poetry.group.dev.dependencies] pytest = "^8.3.5" python-semantic-release = "^9.21.0" mypy = "^1.15.0" invoke = "^2.2.0" +hypothesis = "^6.98" [tool.poetry.group.docs.dependencies] diff --git a/tasks.py b/tasks.py index 693f162..d0389b1 100644 --- a/tasks.py +++ b/tasks.py @@ -5,6 +5,7 @@ + @task def test(c: Context) -> None: """ @@ -49,6 +50,33 @@ def test(c: Context) -> None: print(stderr_output) +@task +def check_version(c: Context) -> None: + """Validate that ``pyproject.toml`` matches the latest git tag. + + This task runs the ``scripts/check_version_match.py`` helper using Poetry + and reports whether the version numbers are aligned. + + Args: + c: Invoke context used to run shell commands. + + Returns: + None + """ + if not isinstance(c, Context): + raise TypeError(f"Expected Invoke Context, got {type(c).__name__!r}") + + # Execute the version check script with Poetry. + cmd = "poetry run python scripts/check_version_match.py" + result = c.run(cmd, warn=True, pty=False) + + # Report based on the exit code from the script. + if result.ok: + print("βœ”οΈ pyproject version matches the latest git tag.") + else: + print("❌ Version mismatch detected.") + print(result.stderr) + @task def docs(c: Context) -> None: """ diff --git a/tests/test_aft_property.py b/tests/test_aft_property.py new file mode 100644 index 0000000..18a6412 --- /dev/null +++ b/tests/test_aft_property.py @@ -0,0 +1,22 @@ +from hypothesis import given, strategies as st +from gen_surv.aft import gen_aft_log_normal + +@given( + n=st.integers(min_value=1, max_value=20), + sigma=st.floats(min_value=0.1, max_value=2.0, allow_nan=False, allow_infinity=False), + cens_par=st.floats(min_value=0.1, max_value=10.0, allow_nan=False, allow_infinity=False), + seed=st.integers(min_value=0, max_value=1000) +) +def test_gen_aft_log_normal_properties(n, sigma, cens_par, seed): + df = gen_aft_log_normal( + n=n, + beta=[0.5, -0.2], + sigma=sigma, + model_cens="uniform", + cens_par=cens_par, + seed=seed + ) + assert df.shape[0] == n + assert set(df["status"].unique()).issubset({0, 1}) + assert (df["time"] >= 0).all() + assert df.filter(regex="^X[0-9]+$").shape[1] == 2 diff --git a/tests/test_bivariate.py b/tests/test_bivariate.py new file mode 100644 index 0000000..b011130 --- /dev/null +++ b/tests/test_bivariate.py @@ -0,0 +1,26 @@ +import numpy as np +from gen_surv.bivariate import sample_bivariate_distribution +import pytest + + +def test_sample_bivariate_exponential_shape(): + """Exponential distribution should return an array of shape (n, 2).""" + result = sample_bivariate_distribution(5, "exponential", 0.0, [1.0, 1.0]) + assert isinstance(result, np.ndarray) + assert result.shape == (5, 2) + + +def test_sample_bivariate_invalid_dist(): + """Unsupported distributions should raise ValueError.""" + with pytest.raises(ValueError): + sample_bivariate_distribution(10, "invalid", 0.0, [1, 1]) + +def test_sample_bivariate_exponential_param_length_error(): + """Exponential distribution with wrong param length should raise ValueError.""" + with pytest.raises(ValueError): + sample_bivariate_distribution(5, "exponential", 0.0, [1.0]) + +def test_sample_bivariate_weibull_param_length_error(): + """Weibull distribution with wrong param length should raise ValueError.""" + with pytest.raises(ValueError): + sample_bivariate_distribution(5, "weibull", 0.0, [1.0, 1.0]) diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..d5fd8d7 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,45 @@ +import pandas as pd +from gen_surv.cli import dataset +import runpy + + +def test_cli_dataset_stdout(monkeypatch, capsys): + """Dataset command prints CSV to stdout when no output file is given.""" + + def fake_generate(model: str, n: int): + return pd.DataFrame({"time": [1.0], "status": [1], "X0": [0.1], "X1": [0.2]}) + + # Patch the generate function used in the CLI to avoid heavy computation. + monkeypatch.setattr("gen_surv.cli.generate", fake_generate) + # Call the command function directly to sidestep Click argument parsing + dataset(model="cphm", n=1, output=None) + captured = capsys.readouterr() + assert "time,status,X0,X1" in captured.out + + +def test_main_entry_point(monkeypatch): + """Running the module as a script should invoke the CLI app.""" + + called = [] + + def fake_app(): + called.append(True) + + # Patch the CLI app before the module is executed + monkeypatch.setattr("gen_surv.cli.app", fake_app) + monkeypatch.setattr("sys.argv", ["gen_surv", "dataset", "cphm"]) + runpy.run_module("gen_surv.__main__", run_name="__main__") + assert called + +def test_cli_dataset_file_output(monkeypatch, tmp_path): + """Dataset command writes CSV to file when output path is provided.""" + + def fake_generate(model: str, n: int): + return pd.DataFrame({"time": [1.0], "status": [1], "X0": [0.1], "X1": [0.2]}) + + monkeypatch.setattr("gen_surv.cli.generate", fake_generate) + out_file = tmp_path / "out.csv" + dataset(model="cphm", n=1, output=str(out_file)) + assert out_file.exists() + content = out_file.read_text() + assert "time,status,X0,X1" in content diff --git a/tests/test_interface.py b/tests/test_interface.py index cee6a81..ad2dc3c 100644 --- a/tests/test_interface.py +++ b/tests/test_interface.py @@ -1,4 +1,6 @@ from gen_surv import generate +import pytest + def test_generate_tdcm_runs(): df = generate( @@ -10,7 +12,11 @@ def test_generate_tdcm_runs(): model_cens="uniform", cens_par=1.0, beta=[0.1, 0.2, 0.3], - lam=1.0 + lam=1.0, ) assert not df.empty - \ No newline at end of file + + +def test_generate_invalid_model(): + with pytest.raises(ValueError): + generate(model="unknown") diff --git a/tests/test_validate.py b/tests/test_validate.py new file mode 100644 index 0000000..ffbbad0 --- /dev/null +++ b/tests/test_validate.py @@ -0,0 +1,72 @@ +import pytest +import gen_surv.validate as v + + +def test_validate_gen_cphm_inputs_valid(): + """Ensure valid inputs pass without raising an exception.""" + v.validate_gen_cphm_inputs(1, "uniform", 0.5, 1.0) + + +@pytest.mark.parametrize( + "n, model_cens, cens_par, covar", + [ + (0, "uniform", 0.5, 1.0), + (1, "bad", 0.5, 1.0), + (1, "uniform", -1.0, 1.0), + (1, "uniform", 0.5, -1.0), + ], +) +def test_validate_gen_cphm_inputs_invalid(n, model_cens, cens_par, covar): + """Invalid parameter combinations should raise ValueError.""" + with pytest.raises(ValueError): + v.validate_gen_cphm_inputs(n, model_cens, cens_par, covar) + + +def test_validate_dg_biv_inputs_invalid(): + """Invalid distribution names should raise an error.""" + with pytest.raises(ValueError): + v.validate_dg_biv_inputs(10, "normal", 0.1, [1, 1]) + + +def test_validate_gen_cmm_inputs_invalid_beta_length(): + """Invalid beta length should raise a ValueError.""" + with pytest.raises(ValueError): + v.validate_gen_cmm_inputs( + 1, + "uniform", + 0.5, + [0.1, 0.2], + covar=1.0, + rate=[0.1] * 6, + ) + + +def test_validate_gen_tdcm_inputs_invalid_lambda(): + """Lambda <= 0 should raise a ValueError.""" + with pytest.raises(ValueError): + v.validate_gen_tdcm_inputs( + 1, + "weibull", + 0.5, + [1, 2, 1, 2], + "uniform", + 1.0, + beta=[0.1, 0.2, 0.3], + lam=0, + ) + + +def test_validate_gen_aft_log_normal_inputs_valid(): + """Valid parameters should not raise an error for AFT log-normal.""" + v.validate_gen_aft_log_normal_inputs( + 1, + [0.1, 0.2], + 1.0, + "uniform", + 0.5, + ) + + +def test_validate_dg_biv_inputs_valid_weibull(): + """Valid parameters for a Weibull distribution should pass.""" + v.validate_dg_biv_inputs(5, "weibull", 0.1, [1.0, 1.0, 1.0, 1.0]) diff --git a/tests/test_version.py b/tests/test_version.py new file mode 100644 index 0000000..6fb4a02 --- /dev/null +++ b/tests/test_version.py @@ -0,0 +1,7 @@ +from importlib.metadata import version +from gen_surv import __version__ + + +def test_package_version_matches_metadata(): + """The exported __version__ should match package metadata.""" + assert __version__ == version("gen_surv")