diff --git a/.github/workflows/bump-version.yml b/.github/workflows/bump-version.yml
index 5b41bd7..f99cd8b 100644
--- a/.github/workflows/bump-version.yml
+++ b/.github/workflows/bump-version.yml
@@ -1,67 +1,26 @@
 # .github/workflows/bump-version.yml
-name: Bump Version on Merge to Main
-
+name: Sync Version on Merge to Main
+
 on:
   push:
     branches:
       - main
 
 jobs:
-  bump-version:
+  tag-version:
     runs-on: ubuntu-latest
     permissions:
       contents: write
-      id-token: write
     steps:
       - uses: actions/checkout@v4
         with:
-          fetch-depth: 0  # Fetch all history for all branches and tags
+          fetch-depth: 0
       - uses: actions/setup-python@v5
         with:
          python-version: "3.11"
-      - run: pip install python-semantic-release
-
-      - name: Configure Git
-        run: |
-          git config user.name "github-actions[bot]"
-          git config user.email "github-actions[bot]@users.noreply.github.com"
-
-      - name: Run Semantic Release
-        env:
-          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        run: semantic-release version
+      - name: Synchronize tag and pyproject version
+        run: python scripts/check_version_match.py --fix --write
-      - name: Push changes
-        run: |
-          git push --follow-tags
-      - name: "Install Poetry"
-        run: pip install poetry
-      - name: "Determine version bump type"
-        run: |
-          git fetch --tags
-          # This defaults to a patch type, unless a feature commit was pushed, then set type to minor
-          LAST_TAG=$(git describe --tags $(git rev-list --tags --max-count=1))
-          LAST_COMMIT=$(git log -1 --format='%H')
-          echo "Last git tag: $LAST_TAG"
-          echo "Last git commit: $LAST_COMMIT"
-          echo "Commits:"
-          git log --no-merges --pretty=oneline $LAST_TAG...$LAST_COMMIT
-          git log --no-merges --pretty=format:"%s" $LAST_TAG...$LAST_COMMIT | grep -q ^feat: && BUMP_TYPE="minor" || BUMP_TYPE="patch"
-          echo "Version bump type: $BUMP_TYPE"
-          echo "BUMP_TYPE=$BUMP_TYPE" >> $GITHUB_ENV
-        env:
-          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-      - name: "Version bump"
-        run: |
-          poetry version $BUMP_TYPE
-      - name: "Push new version"
-        run: |
-          git add pyproject.toml
-          git commit -m "Update version to $(poetry version -s)"
-          git pull --ff-only origin main
-          git push origin main --follow-tags
-        env:
-          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
\ No newline at end of file
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
new file mode 100644
index 0000000..bc33b82
--- /dev/null
+++ b/.github/workflows/ci.yml
@@ -0,0 +1,96 @@
+name: CI
+
+on:
+  push:
+    branches:
+      - main
+    tags:
+      - 'v*'
+  pull_request:
+    branches:
+      - main
+
+jobs:
+  test:
+    name: Test with Python ${{ matrix.python-version }}
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.10", "3.11", "3.12"]
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Cache Poetry dependencies
+        uses: actions/cache@v4
+        with:
+          path: ~/.cache/pypoetry
+          key: ${{ runner.os }}-poetry-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
+
+      - name: Install Poetry
+        run: |
+          curl -sSL https://install.python-poetry.org | python3 -
+          echo "$HOME/.local/bin" >> $GITHUB_PATH
+
+      - name: Install dependencies
+        run: poetry install --with dev --no-interaction
+
+      - name: Check runtime dependencies
+        run: |
+          poetry run python - <<'PY'
+          import importlib.util, sys
+          missing = [m for m in ("pandas",) if importlib.util.find_spec(m) is None]
+          if missing:
+              print("Missing dependencies:", ", ".join(missing))
+              sys.exit(1)
+          PY
+
+      - name: Verify version matches tag
+        if:
startsWith(github.ref, 'refs/tags/') + run: python scripts/check_version_match.py + + - name: Run tests with coverage + run: poetry run pytest --cov=gen_surv --cov-report=xml --cov-report=term + + - name: Upload coverage to Codecov + if: matrix.python-version == '3.11' + uses: codecov/codecov-action@v5 + with: + files: coverage.xml + token: ${{ secrets.CODECOV_TOKEN }} + + lint: + name: Code Quality + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Cache Poetry dependencies + uses: actions/cache@v4 + with: + path: ~/.cache/pypoetry + key: ${{ runner.os }}-poetry-3.11-${{ hashFiles('**/poetry.lock') }} + + - name: Install Poetry + run: | + curl -sSL https://install.python-poetry.org | python3 - + echo "$HOME/.local/bin" >> $GITHUB_PATH + + - name: Install dependencies + run: poetry install --with dev --no-interaction + + - name: Run pre-commit checks + run: poetry run pre-commit run --all-files diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml new file mode 100644 index 0000000..9bf55bb --- /dev/null +++ b/.github/workflows/docs.yml @@ -0,0 +1,34 @@ +name: Documentation Build + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + docs: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: '3.11' + - name: Cache Poetry dependencies + uses: actions/cache@v4 + with: + path: ~/.cache/pypoetry + key: ${{ runner.os }}-poetry-3.11-${{ hashFiles('**/poetry.lock') }} + - name: Install Poetry + run: | + curl -sSL https://install.python-poetry.org | python3 - + echo "$HOME/.local/bin" >> $GITHUB_PATH + - name: Install dependencies + run: poetry install --with docs + - name: Build documentation + run: poetry run sphinx-build -W -b html docs/source docs/build + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + name: documentation + path: docs/build/ diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000..a3928e0 --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,27 @@ +name: Publish to PyPI + +on: + workflow_dispatch: + +jobs: + publish: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + - name: Verify version matches tag + run: python scripts/check_version_match.py + - name: Install Poetry + run: | + curl -sSL https://install.python-poetry.org | python3 - + echo "$HOME/.local/bin" >> $GITHUB_PATH + - name: Publish + env: + POETRY_PYPI_TOKEN_PYPI: ${{ secrets.PYPI_TOKEN }} + run: poetry publish --build --no-interaction diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml deleted file mode 100644 index bcd9349..0000000 --- a/.github/workflows/test.yml +++ /dev/null @@ -1,38 +0,0 @@ -name: Run Tests - -on: - push: - branches: [main] - pull_request: - branches: [main] - -jobs: - test: - runs-on: ubuntu-latest - - steps: - - name: Checkout code - uses: actions/checkout@v3 - - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: "3.9" - - - name: Install Poetry - run: | - curl -sSL https://install.python-poetry.org | python3 - - echo "$HOME/.local/bin" >> $GITHUB_PATH - - - name: Install dependencies - run: poetry install - - - name: Run tests - run: poetry run pytest --cov=gen_surv 
--cov-report=xml --cov-report=term - - - - name: Upload coverage to Codecov - uses: codecov/codecov-action@v5 - with: - files: coverage.xml - token: ${{ secrets.CODECOV_TOKEN }} # optional if public repo diff --git a/.gitignore b/.gitignore index 4518323..6144971 100644 --- a/.gitignore +++ b/.gitignore @@ -50,3 +50,4 @@ dist/ # Temporary *.log *.tmp +.hypothesis/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..cf1f872 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,20 @@ +repos: + - repo: https://github.com/psf/black + rev: 24.1.0 + hooks: + - id: black + - repo: https://github.com/pycqa/isort + rev: 5.13.2 + hooks: + - id: isort + - repo: https://github.com/pycqa/flake8 + rev: 6.1.0 + hooks: + - id: flake8 + additional_dependencies: [flake8-pyproject] + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.15.0 + hooks: + - id: mypy + pass_filenames: false + args: [--config-file=pyproject.toml, gen_surv] diff --git a/.readthedocs.yml b/.readthedocs.yml index a811ac3..a1f8a6e 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -4,10 +4,22 @@ build: os: ubuntu-22.04 tools: python: "3.11" + jobs: + post_create_environment: + - pip install poetry + post_install: + - poetry config virtualenvs.create false + - poetry install --with docs python: install: - - requirements: docs/requirements.txt + - method: pip + path: . sphinx: configuration: docs/source/conf.py + fail_on_warning: false + +formats: + - pdf + - epub diff --git a/CHANGELOG.md b/CHANGELOG.md index 500ee3d..c3ad68d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,40 @@ # CHANGELOG +## v1.0.9 (2025-08-02) + +### Features +- export datasets to RDS files +- test workflow runs on a Python version matrix +- scikit-learn compatible data generator +- compatibility helpers for lifelines and scikit-survival + +### Documentation +- updated usage examples and tutorials +- document optional scikit-survival dependency throughout the docs + +### Continuous Integration +- auto-tag releases using the version check script + +### Misc +- README quick example uses `covariate_range` + +## v1.0.8 (2025-07-30) + +### Documentation +- ensure absolute path resolution in `conf.py` +- drop unsupported theme option +- define bibliography anchors and headings +- fix tutorial links to non-existing docs +- add additional references to the bibliography + +### Testing +- add CLI integration test +- expand piecewise generator test coverage + +### Misc +- remove fix_recommendations.md + + ## v1.0.0 (2025-06-06) @@ -186,9 +221,6 @@ - Update pyproject ([`a4b25e4`](https://github.com/DiogoRibeiro7/genSurvPy/commit/a4b25e470954091254b1384a44a991a47341bf80)) -- Work - ([`5ac5130`](https://github.com/DiogoRibeiro7/genSurvPy/commit/5ac513098238a8298430d1a95c6fbeed99db4cad)) - ### Continuous Integration - Add GitHub Actions workflow for test automation @@ -218,6 +250,3 @@ - Implement THMM data generator and finalize full model suite ([`1e667ba`](https://github.com/DiogoRibeiro7/genSurvPy/commit/1e667babf28892c3a85c43477562f2de85f07f3c)) - -- Work - ([`45de359`](https://github.com/DiogoRibeiro7/genSurvPy/commit/45de359bbb0d7fbc671e41fa07d3a37b09e68e18)) diff --git a/CHECKLIST.md b/CHECKLIST.md new file mode 100644 index 0000000..0163f08 --- /dev/null +++ b/CHECKLIST.md @@ -0,0 +1,102 @@ +# βœ… Python Package Development Checklist + +A checklist to ensure quality, maintainability, and usability of your Python package. + +--- + +## 1. 
Purpose & Scope + +- [ ] Clear purpose and use cases defined +- [ ] Scoped to a specific problem/domain +- [ ] Project name is meaningful and not taken on PyPI + +--- + +## 2. Project Structure + +- [ ] Uses `src/` layout or appropriate flat structure +- [ ] All package folders contain `__init__.py` +- [ ] Main configuration handled via `pyproject.toml` +- [ ] Includes standard files: `README.md`, `LICENSE`, `.gitignore`, `CHANGELOG.md` + +--- + +## 3. Dependencies + +- [ ] All dependencies declared in `pyproject.toml` or `requirements.txt` +- [ ] Development dependencies separated from runtime dependencies +- [ ] Uses minimal, necessary dependencies only + +--- + +## 4. Code Quality + +- [ ] Follows PEP 8 formatting +- [ ] Imports sorted with `isort` or `ruff` +- [ ] No linter warnings (`ruff`, `flake8`, etc.) +- [ ] Fully typed using `typing` module +- [ ] No unresolved TODOs or FIXME comments + +--- + +## 5. Function & Module Design + +- [ ] Functions are small, pure, and single-responsibility +- [ ] Classes follow clear and simple roles +- [ ] Global state is avoided +- [ ] Public API defined explicitly (e.g. via `__all__`) + +--- + +## 6. Documentation + +- [ ] `README.md` includes overview, install, usage, contributing +- [ ] All functions/classes include docstrings (Google/NumPy style) +- [ ] API reference documentation auto-generated (e.g., Sphinx, MkDocs) +- [ ] Optional: `docs/` folder for additional documentation or site generator + +--- + +## 7. Testing + +- [ ] Unit and integration tests implemented +- [ ] Test coverage > 80% verified by `coverage` +- [ ] Tests are fast and deterministic +- [ ] Continuous Integration (CI) configured to run tests + +--- + +## 8. Versioning & Releases + +- [ ] Uses semantic versioning (MAJOR.MINOR.PATCH) +- [ ] Git tags created for releases +- [ ] `CHANGELOG.md` updated with each release +- [ ] Local build verified (`poetry build`, `hatch build`, or equivalent) +- [ ] Can be published to PyPI and/or TestPyPI + +--- + +## 9. CLI or Scripts (Optional) + +- [ ] CLI entrypoint works correctly (`__main__.py` or `entry_points`) +- [ ] CLI provides helpful messages (`--help`) and handles errors gracefully + +--- + +## 10. Examples / Tutorials + +- [ ] Usage examples included in `README.md` or `examples/` +- [ ] Optional: Jupyter notebooks with demonstrations +- [ ] Optional: Colab or Binder links for live usage + +--- + +## 11. Licensing & Attribution + +- [ ] LICENSE file included (MIT, Apache 2.0, GPL, etc.) +- [ ] Author and contributors credited in `README.md` +- [ ] Optional: `CITATION.cff` file for academic citation + +--- + +> You can duplicate this file for each new package or use it as a GitHub issue template for release checklists. diff --git a/CITATION.cff b/CITATION.cff index fe3f8e0..78d9c81 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -5,14 +5,15 @@ message: "If you use this software, please cite it using the metadata below." preferred-citation: type: software title: "gen_surv" - version: "1.0.3" + version: "1.0.9" url: "https://github.com/DiogoRibeiro7/genSurvPy" authors: - family-names: Ribeiro given-names: Diogo + alias: DiogoRibeiro7 orcid: "https://orcid.org/0009-0001-2022-7072" affiliation: "ESMAD - Instituto PolitΓ©cnico do Porto" email: "dfr@esmad.ipp.pt" license: "MIT" - date-released: "2024-01-01" + date-released: "2025-08-02" diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 07dc4b6..aba48fc 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -5,9 +5,14 @@ Thank you for taking the time to contribute to **gen_surv**! 
This document provides guidelines for contributing.

## Getting Started

1. Fork the repository and create your feature branch from `main`.
-2. Install dependencies with `poetry install`.
-3. Ensure the test suite passes with `poetry run pytest`.
-4. If you add a feature or fix a bug, update `CHANGELOG.md` accordingly.
+2. Install dependencies with `poetry install --with dev`.
+   This installs all packages needed for development, including
+   the optional dependency `scikit-survival`.
+   On Debian/Ubuntu you may need `build-essential gfortran libopenblas-dev`
+   to build it.
+3. Run `pre-commit install` to enable style checks and execute them with `pre-commit run --all-files`.
+4. Ensure the test suite passes with `poetry run pytest`.
+5. If you add a feature or fix a bug, update `CHANGELOG.md` accordingly.

## Version Consistency

diff --git a/LICENCE b/LICENSE
similarity index 100%
rename from LICENCE
rename to LICENSE
diff --git a/README.md b/README.md
index b74802b..df4beab 100644
--- a/README.md
+++ b/README.md
@@ -1,157 +1,142 @@
 # gen_surv

-![Coverage](https://codecov.io/gh/DiogoRibeiro7/genSurvPy/branch/main/graph/badge.svg)
-[![Docs](https://readthedocs.org/projects/gensurvpy/badge/?version=stable)](https://gensurvpy.readthedocs.io/en/stable/)
-![PyPI](https://img.shields.io/pypi/v/gen_surv)
-![Tests](https://github.com/DiogoRibeiro7/genSurvPy/actions/workflows/test.yml/badge.svg)
-![Python](https://img.shields.io/pypi/pyversions/gen_surv)
+[![Coverage][cov-badge]][cov-link]
+[![Docs][docs-badge]][docs-link]
+[![PyPI][pypi-badge]][pypi-link]
+[![Tests][ci-badge]][ci-link]
+[![Python][py-badge]][pypi-link]
+
+[cov-badge]: https://codecov.io/gh/DiogoRibeiro7/genSurvPy/branch/main/graph/badge.svg
+[cov-link]: https://app.codecov.io/gh/DiogoRibeiro7/genSurvPy
+[docs-badge]: https://readthedocs.org/projects/gensurvpy/badge/?version=latest
+[docs-link]: https://gensurvpy.readthedocs.io/en/latest/
+[pypi-badge]: https://img.shields.io/pypi/v/gen_surv
+[pypi-link]: https://pypi.org/project/gen-surv/
+[ci-badge]: https://github.com/DiogoRibeiro7/genSurvPy/actions/workflows/ci.yml/badge.svg
+[ci-link]: https://github.com/DiogoRibeiro7/genSurvPy/actions/workflows/ci.yml
+[py-badge]: https://img.shields.io/pypi/pyversions/gen_surv
+
+**gen_surv** is a Python library for simulating survival data under a wide range of statistical models. Inspired by the R package [genSurv](https://cran.r-project.org/package=genSurv), it offers a unified interface for generating realistic datasets for research, teaching and benchmarking.
+---
-**gen_surv** is a Python package for simulating survival data under a variety of models, inspired by the R package [`genSurv`](https://cran.r-project.org/package=genSurv). It supports data generation for:
+## Features
-- Cox Proportional Hazards Models (CPHM)
-- Continuous-Time Markov Models (CMM)
-- Time-Dependent Covariate Models (TDCM)
-- Time-Homogeneous Hidden Markov Models (THMM)
+- Cox proportional hazards model (CPHM)
+- Accelerated failure time models (log-normal, log-logistic, Weibull)
+- Continuous-time multi-state Markov model (CMM)
+- Time-dependent covariate model (TDCM)
+- Time-homogeneous hidden Markov model (THMM)
+- Mixture cure and piecewise exponential models
+- Competing risks generators (constant and Weibull hazards)
+- Visualization utilities for simulated datasets
+- Scikit-learn compatible data generator
+- Conversion helpers for scikit-survival and lifelines
+- Command-line interface and export utilities
+
+## Installation
+
+Requires Python 3.10 or later.
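+
+Because `scikit-survival` is optional, code that depends on it can probe for
+the module before importing it. A minimal sketch of that guard (the
+`require_sksurv` helper here is illustrative, not part of the package):
+
+```python
+from importlib.util import find_spec
+
+# scikit-survival installs the importable module `sksurv`
+HAS_SKSURV = find_spec("sksurv") is not None
+
+
+def require_sksurv() -> None:
+    """Illustrative guard; gen_surv's real error handling may differ."""
+    if not HAS_SKSURV:
+        raise ImportError(
+            "This feature requires the optional dependency scikit-survival: "
+            "pip install scikit-survival"
+        )
+```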
-## πŸ“¦ Installation +Install the latest release from PyPI: ```bash -poetry install +pip install gen-surv ``` -## ✨ Features - -- Consistent interface across models -- Censoring support (`uniform` or `exponential`) -- Easy integration with `pandas` and `NumPy` -- Suitable for benchmarking survival algorithms and teaching -- Accelerated Failure Time (Log-Normal) model generator -- Command-line interface powered by `Typer` - -## πŸ§ͺ Example -```python -from gen_surv import generate +To develop locally with all extras: -# CPHM -generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0) +```bash +git clone https://github.com/DiogoRibeiro7/genSurvPy.git +cd genSurvPy +# Install runtime and development dependencies +# (scikit-survival is optional but required for integration tests). +# On Debian/Ubuntu you may need `build-essential gfortran libopenblas-dev` to +# build scikit-survival. +poetry install --with dev +``` -# AFT Log-Normal -generate(model="aft_ln", n=100, beta=[0.5, -0.3], sigma=1.0, model_cens="exponential", cens_par=3.0) +Integration tests that rely on scikit-survival are automatically skipped if the package is not installed. -# CMM -generate(model="cmm", n=100, model_cens="exponential", cens_par=2.0, - qmat=[[0, 0.1], [0.05, 0]], p0=[1.0, 0.0]) +## Development Setup -# TDCM -generate(model="tdcm", n=100, dist="weibull", corr=0.5, - dist_par=[1, 2, 1, 2], model_cens="uniform", cens_par=1.0, - beta=[0.1, 0.2, 0.3], lam=1.0) +Before committing changes, install the pre-commit hooks: -# THMM -generate(model="thmm", n=100, qmat=[[0, 0.2, 0], [0.1, 0, 0.1], [0, 0.3, 0]], - emission_pars={"mu": [0.0, 1.0, 2.0], "sigma": [0.5, 0.5, 0.5]}, - p0=[1.0, 0.0, 0.0], model_cens="exponential", cens_par=3.0) +```bash +pre-commit install +pre-commit run --all-files ``` -## ⌨️ Command-Line Usage +## Usage -Install the package and use ``python -m gen_surv`` to generate datasets without -writing Python code: +### Python API -```bash -python -m gen_surv dataset aft_ln --n 100 > data.csv -``` - -## πŸ”§ API Overview - -| Function | Description | -|----------|-------------| -| `generate()` | Unified interface that calls any generator | -| `gen_cphm()` | Cox Proportional Hazards Model | -| `gen_cmm()` | Continuous-Time Multi-State Markov Model | -| `gen_tdcm()` | Time-Dependent Covariate Model | -| `gen_thmm()` | Time-Homogeneous Markov Model | -| `gen_aft_log_normal()` | Accelerated Failure Time Log-Normal | -| `sample_bivariate_distribution()` | Sample correlated Weibull or exponential times | -| `runifcens()` | Generate uniform censoring times | -| `rexpocens()` | Generate exponential censoring times | - - -```text -genSurvPy/ -β”œβ”€β”€ gen_surv/ # Pacote principal -β”‚ β”œβ”€β”€ __main__.py # Interface CLI via python -m -β”‚ β”œβ”€β”€ cphm.py -β”‚ β”œβ”€β”€ cmm.py -β”‚ β”œβ”€β”€ tdcm.py -β”‚ β”œβ”€β”€ thmm.py -β”‚ β”œβ”€β”€ censoring.py -β”‚ β”œβ”€β”€ bivariate.py -β”‚ β”œβ”€β”€ validate.py -β”‚ └── interface.py -β”œβ”€β”€ tests/ # Testes automatizados -β”‚ β”œβ”€β”€ test_cphm.py -β”‚ β”œβ”€β”€ test_cmm.py -β”‚ β”œβ”€β”€ test_tdcm.py -β”‚ β”œβ”€β”€ test_thmm.py -β”œβ”€β”€ examples/ # Exemplos de uso -β”‚ β”œβ”€β”€ run_aft.py -β”‚ β”œβ”€β”€ run_cmm.py -β”‚ β”œβ”€β”€ run_cphm.py -β”‚ β”œβ”€β”€ run_tdcm.py -β”‚ └── run_thmm.py -β”œβ”€β”€ docs/ # DocumentaΓ§Γ£o Sphinx -β”‚ β”œβ”€β”€ source/ -β”‚ └── ... 
-β”œβ”€β”€ scripts/ # Utilidades diversas -β”‚ └── check_version_match.py -β”œβ”€β”€ tasks.py # Tarefas automatizadas com Invoke -β”œβ”€β”€ TODO.md # Roadmap de desenvolvimento -β”œβ”€β”€ pyproject.toml # Configurado com Poetry -β”œβ”€β”€ README.md -β”œβ”€β”€ LICENCE -└── .gitignore +```python +from gen_surv import generate, export_dataset, to_sksurv +from gen_surv.visualization import plot_survival_curve + +# basic Cox proportional hazards data +sim = generate( + model="cphm", + n=100, + beta=0.5, + covariate_range=2.0, + model_cens="uniform", + cens_par=1.0, +) + +plot_survival_curve(sim) +export_dataset(sim, "survival_data.rds") + +# convert for scikit-survival +sks_dataset = to_sksurv(sim) ``` -## 🧠 License +See the [usage guide](https://gensurvpy.readthedocs.io/en/latest/getting_started.html) for more examples. -MIT License. See [LICENCE](LICENCE) for details. +### Command Line +Datasets can be generated without writing Python code: -## πŸ”– Release Process - -This project uses Git tags to manage releases. A GitHub Actions workflow -(`version-check.yml`) verifies that the version declared in `pyproject.toml` -matches the latest Git tag. If they diverge, the workflow fails and prompts a -correction before merging. Run `python scripts/check_version_match.py` locally -before creating a tag to catch issues early. +```bash +python -m gen_surv dataset cphm --n 1000 -o survival.csv +``` -## 🌟 Code of Conduct +## Supported Models -Please read our [Code of Conduct](CODE_OF_CONDUCT.md) to learn about the -expectations for participants in this project. +| Model | Description | +|-------|-------------| +| **CPHM** | Cox proportional hazards | +| **AFT** | Accelerated failure time (log-normal, log-logistic, Weibull) | +| **CMM** | Continuous-time multi-state Markov | +| **TDCM** | Time-dependent covariates | +| **THMM** | Time-homogeneous hidden Markov | +| **Competing Risks** | Multiple event types with cause-specific hazards | +| **Mixture Cure** | Models long-term survivors | +| **Piecewise Exponential** | Flexible baseline hazard via intervals | -## 🀝 Contributing +More details on each algorithm are available in the [Algorithms](https://gensurvpy.readthedocs.io/en/latest/algorithms.html) page. For additional background, see the [theory guide](https://gensurvpy.readthedocs.io/en/latest/theory.html). -Please read [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines on setting up your environment, running tests, and submitting pull requests. +## Documentation -## πŸ”§ Development Tasks +Full documentation is hosted on [Read the Docs](https://gensurvpy.readthedocs.io/en/latest/). It includes installation instructions, tutorials, API references and a bibliography. -Common project commands are defined in [`tasks.py`](tasks.py) and can be executed with [Invoke](https://www.pyinvoke.org/): +To build the docs locally: ```bash -poetry run inv -l # list available tasks -poetry run inv test # run the test suite +cd docs +make html ``` -## πŸ“‘ Citation +Open `build/html/index.html` in your browser to view the result. + +## License + +This project is licensed under the MIT License. See [LICENSE](LICENSE) for details. + +## Citation -If you use **gen_surv** in your work, please cite it using the metadata in -[`CITATION.cff`](CITATION.cff). Many reference managers can import this file -directly. +If you use **gen_surv** in your research, please cite the project using the metadata in [CITATION.cff](CITATION.cff). 
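+
+The simulated frames drop straight into standard survival libraries. A short
+sketch fitting a Cox model with [lifelines](https://lifelines.readthedocs.io/)
+on generated data (the `time`/`status` column names follow the convention in
+the docs; treating every remaining column as a covariate is an assumption
+about the CPHM output):
+
+```python
+from gen_surv import generate
+from lifelines import CoxPHFitter
+
+df = generate(
+    model="cphm",
+    n=500,
+    beta=0.5,
+    covariate_range=2.0,
+    model_cens="uniform",
+    cens_par=1.0,
+)
+
+# lifelines uses every column except the duration/event columns as covariates
+cph = CoxPHFitter()
+cph.fit(df, duration_col="time", event_col="status")
+cph.print_summary()
+```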
## Author diff --git a/TODO.md b/TODO.md index 91b4fe4..eb15799 100644 --- a/TODO.md +++ b/TODO.md @@ -1,103 +1,65 @@ -# TODO – Roadmap for gen_surv +# gen_surv Roadmap -This document outlines future enhancements, features, and ideas for improving the gen_surv package. +This document outlines the planned development priorities for future versions of gen_surv. This roadmap will be periodically updated based on user feedback, research needs, and community contributions. ---- +## Short-term Goals (v1.1.x) -## ✨ Priority Items +### Additional Statistical Models +- [ ] **Recurrent Events Model**: Generate data with multiple events per subject +- [ ] **Time-Varying Effects**: Support for non-proportional hazards with coefficients that change over time +- [ ] **Extended Competing Risks**: Allow for correlation between competing risks -- [βœ…] Add property-based tests using Hypothesis to cover edge cases -- [βœ…] Build a CLI for generating datasets from the terminal -- [ ] Expand documentation with multilingual support and more usage examples -- [ ] Implement Weibull and log-logistic AFT models and add visualization utilities -- [βœ…] Provide CITATION metadata for proper referencing -- [ ] Ensure all functions include Google-style docstrings with inline comments +### Visualization and Analysis +- [ ] **Enhanced Visualization Toolkit**: Add more plot types and customization options +- [ ] **Interactive Visualizations**: Add options using Plotly for interactive exploration +- [ ] **Data Quality Reports**: Generate reports on statistical properties of generated datasets ---- +### Usability Improvements +- [ ] **Dataset Catalog**: Pre-configured parameters to mimic classic survival datasets +- [ ] **Parameter Estimation**: Tools to estimate generation parameters from existing datasets +- [ ] **Extended CLI**: Add more command-line options for all models -## πŸ“¦ 1. Interface and UX +## Medium-term Goals (v1.2.x) -- [βœ…] Create a `generate(..., return_type="df" | "dict")` interface -- [βœ…] Add `__version__` using `importlib.metadata` or `poetry-dynamic-versioning` -- [βœ…] Build a CLI with `typer` or `click` -- [βœ…] Add example notebooks or scripts for each model (`examples/` folder) +### Advanced Statistical Models +- [ ] **Joint Longitudinal-Survival Models**: Generators for models that simultaneously handle longitudinal outcomes and time-to-event data +- [ ] **Frailty Models**: Support for shared and nested frailty models +- [ ] **Interval Censoring**: Support for interval-censored data generation ---- +### Technical Enhancements +- [ ] **Parallel Processing**: Multi-core support for faster generation of large datasets +- [ ] **Memory Optimization**: Streaming data generation for very large datasets +- [ ] **Performance Benchmarks**: Systematic benchmarking of data generation speed -## πŸ“š 2. 
Documentation +### Integration and Ecosystem +- [ ] **scikit-learn Extensions**: More scikit-learn compatible estimators and transformers +- [ ] **Stan/PyMC Integration**: Export data in formats suitable for Bayesian modeling +- [ ] **Dashboard**: Simple Streamlit app for data exploration and generation -- [βœ…] Add a "Model Comparison Guide" section (`index.md` + `theory.md`) -- [βœ…] Add "How It Works" sections for each model (`theory.md`) -- [βœ…] Include usage examples in index with real calls -- [ ] Optional: add multilingual docs using `sphinx-intl` +## Long-term Goals (v2.x) ---- +### Advanced Features +- [ ] **Bayesian Survival Models**: Generators for Bayesian survival analysis with various priors +- [ ] **Spatial Survival Models**: Generate survival data with spatial correlation +- [ ] **Survival Neural Networks**: Integration with deep learning approaches to survival analysis -## πŸ§ͺ 3. Testing and Quality +### Infrastructure and Performance +- [ ] **GPU Acceleration**: Optional GPU support for large dataset generation +- [ ] **JAX/Numba Implementation**: High-performance implementations of key algorithms +- [ ] **R Interface**: Create an R package that interfaces with gen_surv -- [βœ…] Add tests for each model (e.g., `test_tdcm.py`, `test_thmm.py`, `test_aft.py`) -- [βœ…] Add property-based tests with `hypothesis` -- [ ] Cover edge cases (e.g., invalid parameters, n=0, negative censoring) -- [ ] Run tests on multiple Python versions (CI matrix) +### Community and Documentation +- [ ] **Interactive Tutorials**: Using Jupyter Book or similar tools +- [ ] **Video Tutorials**: Short video demonstrations of key features +- [ ] **Case Studies**: Real-world examples showing how gen_surv can be used for teaching or research +- [ ] **User Showcase**: Gallery of research or teaching that uses gen_surv ---- +## How to Contribute -## 🧠 4. Advanced Models +We welcome contributions that help us achieve these roadmap goals! If you're interested in working on any of these features, please check the [CONTRIBUTING.md](CONTRIBUTING.md) file for guidelines and open an issue to discuss your approach before submitting a pull request. -- [ ] Add Piecewise Exponential Model support -- [ ] Add competing risks / multi-event simulation -- [βœ…] Implement parametric AFT models (log-normal) -- [ ] Implement parametric AFT models (log-logistic, weibull) -- [ ] Simulate time-varying hazards -- [ ] Add informative or covariate-dependent censoring +For suggesting new features or modifications to this roadmap, please open an issue with the "enhancement" tag. ---- +## Version History -## πŸ“Š 5. Visualization and Analysis - -- [ ] Create `plot_survival(df, model=...)` utilities -- [ ] Create `describe_survival(df)` summary helpers -- [ ] Export data to CSV / JSON / Feather - ---- - -## 🌍 6. Ecosystem Integration - -- [ ] Add a `GenSurvDataGenerator` compatible with `sklearn` -- [ ] Enable use with `lifelines`, `scikit-survival`, `sksurv` -- [ ] Export in R-compatible formats (.csv, .rds) - ---- - -## πŸ” 7. Other Ideas - -- [ ] Add performance benchmarks for each model -- [βœ…] Improve PyPI discoverability (added tags, keywords, docs) -- [ ] Create a Streamlit or Gradio live demo - ---- - -## 🧠 8. New Survival Models to Implement - -- [βœ…] Log-Normal AFT -- [ ] Log-Logistic AFT -- [ ] Weibull AFT -- [ ] Piecewise Exponential -- [ ] Competing Risks -- [ ] Recurrent Events -- [ ] Mixture Cure Model - ---- - -## 🧬 9. 
Advanced Data Simulation Features - -- [ ] Recurrent events (multiple events per individual) -- [ ] Frailty models (random effects) -- [ ] Time-varying hazard functions -- [ ] Multi-line start-stop formatted data -- [ ] Competing risks with cause-specific hazards -- [ ] Simulate violations of PH assumption -- [ ] Grouped / clustered data generation -- [ ] Mixed covariates: categorical, continuous, binary -- [ ] Joint models (longitudinal + survival outcome) -- [ ] Controlled scenarios for robustness tests +For a detailed history of past releases, please see our [CHANGELOG.md](CHANGELOG.md). diff --git a/benchmarks/README.md b/benchmarks/README.md new file mode 100644 index 0000000..aa3750b --- /dev/null +++ b/benchmarks/README.md @@ -0,0 +1,26 @@ +# Benchmarks + +Performance benchmarks for **genSurvPy** powered by the [`pytest-benchmark`](https://pytest-benchmark.readthedocs.io/en/latest/) plugin. + +## Running + +Install the optional `pytest-benchmark` dependency and execute: + +```bash +poetry run pytest benchmarks -q --benchmark-only +``` + +To run an individual benchmark module: + +```bash +poetry run pytest benchmarks/test_cmm_benchmark.py --benchmark-only +``` + +## Available benchmarks + +- Validation helpers (`test_validation_benchmark.py`) +- Time-dependent Cox model generation (`test_tdcm_benchmark.py`) +- Continuous-time Markov model generation (`test_cmm_benchmark.py`) +- Piecewise exponential generation (`test_piecewise_benchmark.py`) +- Cox proportional hazards model generation (`test_cphm_benchmark.py`) +- Mixture cure model generation (`test_mixture_benchmark.py`) diff --git a/benchmarks/test_cmm_benchmark.py b/benchmarks/test_cmm_benchmark.py new file mode 100644 index 0000000..a4f1aa9 --- /dev/null +++ b/benchmarks/test_cmm_benchmark.py @@ -0,0 +1,18 @@ +import pytest + +pytest.importorskip("pytest_benchmark") + +from gen_surv.cmm import gen_cmm + + +def test_cmm_generation_benchmark(benchmark): + benchmark( + gen_cmm, + n=1000, + model_cens="uniform", + cens_par=1.0, + beta=[0.1, 0.2, 0.3], + covariate_range=2.0, + rate=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + seed=42, + ) diff --git a/benchmarks/test_cphm_benchmark.py b/benchmarks/test_cphm_benchmark.py new file mode 100644 index 0000000..202f3c2 --- /dev/null +++ b/benchmarks/test_cphm_benchmark.py @@ -0,0 +1,17 @@ +import pytest + +pytest.importorskip("pytest_benchmark") + +from gen_surv.cphm import gen_cphm + + +def test_cphm_generation_benchmark(benchmark): + benchmark( + gen_cphm, + n=1000, + model_cens="uniform", + cens_par=1.0, + beta=0.5, + covariate_range=2.0, + seed=42, + ) diff --git a/benchmarks/test_mixture_benchmark.py b/benchmarks/test_mixture_benchmark.py new file mode 100644 index 0000000..80438a0 --- /dev/null +++ b/benchmarks/test_mixture_benchmark.py @@ -0,0 +1,14 @@ +import pytest + +pytest.importorskip("pytest_benchmark") + +from gen_surv.mixture import gen_mixture_cure + + +def test_mixture_generation_benchmark(benchmark): + benchmark( + gen_mixture_cure, + n=1000, + cure_fraction=0.3, + seed=42, + ) diff --git a/benchmarks/test_piecewise_benchmark.py b/benchmarks/test_piecewise_benchmark.py new file mode 100644 index 0000000..8e664af --- /dev/null +++ b/benchmarks/test_piecewise_benchmark.py @@ -0,0 +1,16 @@ +import pytest + +pytest.importorskip("pytest_benchmark") + +from gen_surv.piecewise import gen_piecewise_exponential + + +def test_piecewise_generation_benchmark(benchmark): + benchmark( + gen_piecewise_exponential, + n=1000, + breakpoints=[1.0, 2.0], + hazard_rates=[0.5, 0.3, 0.1], + 
n_covariates=3, + seed=42, + ) diff --git a/benchmarks/test_tdcm_benchmark.py b/benchmarks/test_tdcm_benchmark.py new file mode 100644 index 0000000..9509548 --- /dev/null +++ b/benchmarks/test_tdcm_benchmark.py @@ -0,0 +1,19 @@ +import pytest + +pytest.importorskip("pytest_benchmark") + +from gen_surv.tdcm import gen_tdcm + + +def test_tdcm_generation_benchmark(benchmark): + benchmark( + gen_tdcm, + n=1000, + dist="weibull", + corr=0.5, + dist_par=[1, 2, 1, 2], + model_cens="uniform", + cens_par=1.0, + beta=[0.1, 0.2, 0.3], + lam=1.0, + ) diff --git a/benchmarks/test_validation_benchmark.py b/benchmarks/test_validation_benchmark.py new file mode 100644 index 0000000..7cf50f9 --- /dev/null +++ b/benchmarks/test_validation_benchmark.py @@ -0,0 +1,11 @@ +import numpy as np +import pytest + +pytest.importorskip("pytest_benchmark") + +from gen_surv.validation import ensure_positive_sequence + + +def test_positive_sequence_benchmark(benchmark): + seq = np.random.rand(10000) + 1.0 + benchmark(ensure_positive_sequence, seq, "seq") diff --git a/binder/requirements.txt b/binder/requirements.txt new file mode 100644 index 0000000..2df98f0 --- /dev/null +++ b/binder/requirements.txt @@ -0,0 +1,2 @@ +-e . +jupyterlab diff --git a/docs/requirements.txt b/docs/requirements.txt index 98a3c62..a0990cf 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,2 +1,8 @@ -sphinx -myst-parser +sphinx>=6.0 +myst-parser>=1.0.0,<4.0.0 +sphinx-rtd-theme +sphinx-autodoc-typehints +sphinx-copybutton +sphinx-design +linkify-it-py>=2.0 +matplotlib diff --git a/docs/source/_static/custom.css b/docs/source/_static/custom.css new file mode 100644 index 0000000..22f503a --- /dev/null +++ b/docs/source/_static/custom.css @@ -0,0 +1,56 @@ +/* Custom styling for gen_surv documentation */ + +/* Improve code block appearance */ +.highlight { + background: #f8f9fa; + border: 1px solid #e9ecef; + border-radius: 4px; + padding: 10px; + margin: 10px 0; +} + +/* Style admonitions */ +.admonition { + border-radius: 6px; + margin: 1em 0; + padding: 1em; +} + +.admonition.tip { + border-left: 4px solid #28a745; + background-color: #d4edda; +} + +.admonition.note { + border-left: 4px solid #007bff; + background-color: #cce5ff; +} + +.admonition.warning { + border-left: 4px solid #ffc107; + background-color: #fff3cd; +} + +/* Table styling */ +table.docutils { + border-collapse: collapse; + margin: 1em 0; + width: 100%; +} + +table.docutils th, +table.docutils td { + border: 1px solid #ddd; + padding: 8px; + text-align: left; +} + +table.docutils th { + background-color: #f8f9fa; + font-weight: bold; +} + +/* Math styling */ +.math { + font-size: 1.1em; +} diff --git a/docs/source/algorithms.md b/docs/source/algorithms.md new file mode 100644 index 0000000..d5bcb70 --- /dev/null +++ b/docs/source/algorithms.md @@ -0,0 +1,48 @@ +--- +orphan: true +--- + +# Algorithm Overview + +This page provides a short description of each model implemented in **gen_surv**. For mathematical details see {doc}`theory`. + +## Cox Proportional Hazards Model (CPHM) +The hazard at time $t$ is proportional to a baseline hazard multiplied by the exponential of covariate effects. +It is widely used for modelling relative risks under the proportional hazards assumption. +See {ref}`Cox1972` in the {doc}`bibliography` for the seminal paper. + +## Accelerated Failure Time Models (AFT) +These parametric models directly relate covariates to survival time. 
+gen_surv includes log-normal, log-logistic and Weibull variants allowing different baseline distributions. +They are convenient when the effect of covariates accelerates or decelerates event times. + +## Continuous-Time Multi-State Markov Model (CMM) +Transitions between states are governed by a generator matrix. +This model is suited for illness-death and other multi-state processes where state occupancy changes continuously over time. +The mathematical formulation follows the counting-process approach of Andersen et al. {ref}`Andersen1993`. + +## Time-Dependent Covariate Model (TDCM) +Extends the Cox model to covariates that vary during follow-up. +Covariates are simulated in a piecewise fashion with optional correlation across segments. + +## Time-Homogeneous Hidden Markov Model (THMM) +Handles processes with unobserved states that emit observable values. +The latent transitions follow a homogeneous Markov chain while emissions are Gaussian. +For background on these models see Zucchini et al. {ref}`Zucchini2017`. + +## Competing Risks +Allows multiple failure types with cause-specific hazards. +gen_surv supports constant and Weibull hazards for each cause. +The subdistribution approach of Fine and Gray {ref}`FineGray1999` is commonly used for analysis. + +## Mixture Cure Model +Assumes a proportion of individuals will never experience the event. +A logistic component determines who is cured, while uncured subjects follow an exponential failure distribution. +Mixture cure models were introduced by Farewell {ref}`Farewell1982`. + +## Piecewise Exponential Model +Approximates complex hazard shapes by dividing follow-up time into intervals with constant hazard within each interval. +This yields a flexible baseline hazard while remaining computationally simple. + +For additional reading on these methods please see the {doc}`bibliography`. + diff --git a/docs/source/api/index.md b/docs/source/api/index.md new file mode 100644 index 0000000..dc3c722 --- /dev/null +++ b/docs/source/api/index.md @@ -0,0 +1,89 @@ +--- +orphan: true +--- + +# API Reference + +Complete documentation for all gen_surv functions and classes. + +```{note} +The `to_sksurv` helper and related tests rely on the optional +dependency `scikit-survival`. Install it with `poetry install --with dev` +or `pip install scikit-survival` to enable this functionality. 
+``` + +## Core Interface + +::: gen_surv.interface + options: + members: true + undoc-members: true + show-inheritance: true + +## Model Generators + +### Cox Proportional Hazards Model +::: gen_surv.cphm + options: + members: true + undoc-members: true + show-inheritance: true + +### Accelerated Failure Time Models +::: gen_surv.aft + options: + members: true + undoc-members: true + show-inheritance: true + +### Continuous-Time Markov Models +::: gen_surv.cmm + options: + members: true + undoc-members: true + show-inheritance: true + +### Time-Dependent Covariate Models +::: gen_surv.tdcm + options: + members: true + undoc-members: true + show-inheritance: true + +### Time-Homogeneous Markov Models +::: gen_surv.thmm + options: + members: true + undoc-members: true + show-inheritance: true + +## Utility Functions + +### Censoring Functions +::: gen_surv.censoring + options: + members: true + undoc-members: true + show-inheritance: true + +### Bivariate Distributions +::: gen_surv.bivariate + options: + members: true + undoc-members: true + show-inheritance: true + +### Validation Functions +::: gen_surv.validation + options: + members: true + undoc-members: true + show-inheritance: true + +### Command Line Interface +::: gen_surv.cli + options: + members: true + undoc-members: true + show-inheritance: true + diff --git a/docs/source/bibliography.md b/docs/source/bibliography.md new file mode 100644 index 0000000..5a5285a --- /dev/null +++ b/docs/source/bibliography.md @@ -0,0 +1,59 @@ +--- +orphan: true +--- + +# References + +Below is a selection of references covering the statistical models implemented in **gen_surv**. + +(Cox1972)= +## Cox (1972) +Cox, D. R. (1972). Regression Models and Life-Tables. *Journal of the Royal Statistical Society: Series B*, 34(2), 187-220. + +(Farewell1982)= +## Farewell (1982) +Farewell, V.T. (1982). The Use of Mixture Models for the Analysis of Survival Data with Long-Term Survivors. *Biometrics*, 38(4), 1041-1046. + +(FineGray1999)= +## Fine and Gray (1999) +Fine, J.P., & Gray, R.J. (1999). A Proportional Hazards Model for the Subdistribution of a Competing Risk. *Journal of the American Statistical Association*, 94(446), 496-509. + +(Andersen1993)= +## Andersen et al. (1993) +Andersen, P.K., Borgan, Ø., Gill, R.D., & Keiding, N. (1993). *Statistical Models Based on Counting Processes*. Springer. + +(Zucchini2017)= +## Zucchini et al. (2017) +Zucchini, W., MacDonald, I.L., & Langrock, R. (2017). *Hidden Markov Models for Time Series*. Chapman and Hall/CRC. + +(KleinMoeschberger2003)= +## Klein and Moeschberger (2003) +Klein, J.P., & Moeschberger, M.L. (2003). *Survival Analysis: Techniques for Censored and Truncated Data*. Springer. + +(KalbfleischPrentice2002)= +## Kalbfleisch and Prentice (2002) +Kalbfleisch, J.D., & Prentice, R.L. (2002). *The Statistical Analysis of Failure Time Data*. Wiley. + +(CookLawless2007)= +## Cook and Lawless (2007) +Cook, R.J., & Lawless, J.F. (2007). *The Statistical Analysis of Recurrent Events*. Springer. + +(KaplanMeier1958)= +## Kaplan and Meier (1958) +Kaplan, E.L., & Meier, P. (1958). Nonparametric Estimation from Incomplete Observations. *Journal of the American Statistical Association*, 53(282), 457-481. +(TherneauGrambsch2000)= +## Therneau and Grambsch (2000) +Therneau, T.M., & Grambsch, P.M. (2000). *Modeling Survival Data: Extending the Cox Model*. Springer. + +(FlemingHarrington1991)= +## Fleming and Harrington (1991) +Fleming, T.R., & Harrington, D.P. (1991). *Counting Processes and Survival Analysis*. 
Wiley. + +(Collett2015)= +## Collett (2015) +Collett, D. (2015). *Modelling Survival Data in Medical Research*. CRC Press. + +(KleinbaumKlein2012)= +## Kleinbaum and Klein (2012) +Kleinbaum, D.G., & Klein, M. (2012). *Survival Analysis: A Self-Learning Text*. Springer. + diff --git a/docs/source/changelog.md b/docs/source/changelog.md new file mode 100644 index 0000000..011713d --- /dev/null +++ b/docs/source/changelog.md @@ -0,0 +1,9 @@ +--- +orphan: true +--- + +# Changelog + +```{include} ../../CHANGELOG.md +:relative-docs: true +``` diff --git a/docs/source/conf.py b/docs/source/conf.py index 6ac6534..df8af80 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,42 +1,92 @@ -# Configuration file for the Sphinx documentation builder. -# -# For the full list of built-in configuration values, see the documentation: -# https://www.sphinx-doc.org/en/master/usage/configuration.html - -# -- Project information ----------------------------------------------------- -# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information import os -import sys -sys.path.insert(0, os.path.abspath('../../gen_surv')) - -project = 'gen_surv' -copyright = '2025, Diogo Ribeiro' -author = 'Diogo Ribeiro' -release = '1.0.3' +from datetime import datetime +from importlib import metadata -# -- General configuration --------------------------------------------------- -# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration +# Project information +project = "gen_surv" +copyright = f"{datetime.now().year}, Diogo Ribeiro" +author = "Diogo Ribeiro" +release = metadata.version("gen_surv") +version = release +# General configuration extensions = [ "sphinx.ext.autodoc", "sphinx.ext.napoleon", - "myst_parser", "sphinx.ext.viewcode", - "sphinx.ext.autosectionlabel", + "sphinx.ext.intersphinx", + "sphinx.ext.autosummary", + "sphinx.ext.githubpages", + "myst_parser", + "sphinx_copybutton", + "sphinx_design", + "sphinx_autodoc_typehints", ] -autosectionlabel_prefix_document = True +# MyST Parser configuration +myst_enable_extensions = [ + "colon_fence", + "deflist", + "html_admonition", + "html_image", + "linkify", + "replacements", + "smartquotes", + "substitution", + "tasklist", +] + +# Autodoc configuration +autodoc_default_options = { + "members": True, + "member-order": "bysource", + "special-members": "__init__", + "undoc-members": True, + "exclude-members": "__weakref__", +} + +# Autosummary +autosummary_generate = True + +# Napoleon settings +napoleon_google_docstring = True +napoleon_numpy_docstring = True +napoleon_include_init_with_doc = False +napoleon_include_private_with_doc = False -# Point to index.md or index.rst as the root document -master_doc = "index" +# Intersphinx mapping +intersphinx_mapping = { + "python": ("https://docs.python.org/3/", None), + "numpy": ("https://numpy.org/doc/stable/", None), + "pandas": ("https://pandas.pydata.org/docs/", None), +} -templates_path = ['_templates'] -exclude_patterns = [] +# Disable fetching remote inventories when network access is unavailable +if os.environ.get("SKIP_INTERSPHINX", "1") == "1": + intersphinx_mapping = {} -# -- Options for HTML output ------------------------------------------------- -# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output +# HTML theme options +html_theme = "sphinx_rtd_theme" +html_theme_options = { + "canonical_url": "https://gensurvpy.readthedocs.io/", + "analytics_id": "", + "logo_only": False, + "prev_next_buttons_location": "bottom", + 
"style_external_links": False, + "style_nav_header_background": "#2980B9", + "collapse_navigation": True, + "sticky_navigation": True, + "navigation_depth": 4, + "includehidden": True, + "titles_only": False, +} -html_theme = 'alabaster' -html_static_path = ['_static'] +html_static_path = ["_static"] +html_css_files = ["custom.css"] +# Output file base name for HTML help builder +htmlhelp_basename = "gensurvdoc" +# Copy button configuration +copybutton_prompt_text = r">>> |\.\.\. |\$ |In \[\d*\]: | {2,5}\.\.\.: | {5,8}: " +copybutton_prompt_is_regexp = True diff --git a/docs/source/contributing.md b/docs/source/contributing.md new file mode 100644 index 0000000..82121b3 --- /dev/null +++ b/docs/source/contributing.md @@ -0,0 +1,9 @@ +--- +orphan: true +--- + +# Contributing + +```{include} ../../CONTRIBUTING.md +:relative-docs: true +``` diff --git a/docs/source/examples/cmm.md b/docs/source/examples/cmm.md new file mode 100644 index 0000000..a389544 --- /dev/null +++ b/docs/source/examples/cmm.md @@ -0,0 +1,28 @@ +# Continuous-Time Multi-State Markov Model (CMM) + +[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/DiogoRibeiro7/genSurvPy/HEAD?urlpath=lab/tree/examples/notebooks/cmm.ipynb) + +Visualize transition times from the CMM generator: + +```python +import numpy as np +import matplotlib.pyplot as plt +from gen_surv import generate + +np.random.seed(0) + +df = generate( + model="cmm", + n=200, + model_cens="exponential", + cens_par=2.0, + beta=[0.1, 0.2, 0.3], + covariate_range=1.0, + rate=[0.1, 1.0, 0.2, 1.0, 0.1, 1.0], +) + +plt.hist(df["stop"], bins=20, color="#4C72B0") +plt.xlabel("Time") +plt.ylabel("Frequency") +plt.title("CMM Transition Times") +``` diff --git a/docs/source/examples/index.md b/docs/source/examples/index.md new file mode 100644 index 0000000..14003cf --- /dev/null +++ b/docs/source/examples/index.md @@ -0,0 +1,25 @@ +--- +orphan: true +--- + +# Examples + +Real-world examples and use cases for gen_surv. + +```{note} +Some examples may require optional packages such as +`scikit-survival`. Install them with `poetry install --with dev` or +`pip install scikit-survival` before running these examples. +``` + +Each example page includes a [Binder](https://mybinder.org/) badge so you can +launch the corresponding notebook and experiment interactively in your +browser. 
+ +```{toctree} +:maxdepth: 2 + +tdcm +cmm +thmm +``` diff --git a/docs/source/examples/tdcm.md b/docs/source/examples/tdcm.md new file mode 100644 index 0000000..cc027af --- /dev/null +++ b/docs/source/examples/tdcm.md @@ -0,0 +1,30 @@ +# Time-Dependent Covariate Model (TDCM) + +[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/DiogoRibeiro7/genSurvPy/HEAD?urlpath=lab/tree/examples/notebooks/tdcm.ipynb) + +A basic visualization of event times produced by the TDCM generator: + +```python +import numpy as np +import matplotlib.pyplot as plt +from gen_surv import generate + +np.random.seed(0) + +df = generate( + model="tdcm", + n=200, + dist="weibull", + corr=0.5, + dist_par=[1, 2, 1, 2], + model_cens="uniform", + cens_par=1.0, + beta=[0.1, 0.2, 0.3], + lam=1.0, +) + +plt.hist(df["stop"], bins=20, color="#4C72B0") +plt.xlabel("Time") +plt.ylabel("Frequency") +plt.title("TDCM Event Times") +``` diff --git a/docs/source/examples/thmm.md b/docs/source/examples/thmm.md new file mode 100644 index 0000000..01fb17f --- /dev/null +++ b/docs/source/examples/thmm.md @@ -0,0 +1,28 @@ +# Time-Homogeneous Hidden Markov Model (THMM) + +[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/DiogoRibeiro7/genSurvPy/HEAD?urlpath=lab/tree/examples/notebooks/thmm.ipynb) + +An example of event times generated by the THMM: + +```python +import numpy as np +import matplotlib.pyplot as plt +from gen_surv import generate + +np.random.seed(0) + +df = generate( + model="thmm", + n=200, + model_cens="exponential", + cens_par=3.0, + beta=[0.1, 0.2, 0.3], + covariate_range=1.0, + rate=[0.2, 0.1, 0.3], +) + +plt.hist(df["time"], bins=20, color="#4C72B0") +plt.xlabel("Time") +plt.ylabel("Frequency") +plt.title("THMM Event Times") +``` diff --git a/docs/source/getting_started.md b/docs/source/getting_started.md new file mode 100644 index 0000000..7788093 --- /dev/null +++ b/docs/source/getting_started.md @@ -0,0 +1,88 @@ +--- +orphan: true +--- + +# Getting Started + +This guide will help you install gen_surv and generate your first survival dataset. + +## Installation + +### From PyPI (Recommended) + +```bash +pip install gen-surv +``` + +### From Source + +```bash +git clone https://github.com/DiogoRibeiro7/genSurvPy.git +cd genSurvPy +poetry install +``` + +```{note} +Some features and tests rely on optional packages such as +`scikit-survival`. Install them with `poetry install --with dev` or +`pip install scikit-survival` (additional system libraries may be +required). 
+``` + +## Basic Usage + +The main entry point is the `generate()` function: + +```python +from gen_surv import generate + +# Generate Cox proportional hazards data + df = generate( + model="cphm", # Model type + n=100, # Sample size + beta=0.5, # Covariate effect + covariate_range=2.0, # Covariate range + model_cens="uniform", # Censoring type + cens_par=3.0 # Censoring parameter + ) + +print(df.head()) +``` + +## Understanding the Output + +All models return a pandas DataFrame with at least these columns: + +- `time`: Observed event or censoring time +- `status`: Event indicator (1 = event, 0 = censored) +- Additional columns depend on the specific model + +## Command Line Usage + +Generate datasets directly from the terminal: + +```bash +# Generate CPHM data and save to CSV +python -m gen_surv dataset cphm --n 1000 -o survival_data.csv + +# Print AFT data to stdout +python -m gen_surv dataset aft_ln --n 500 +``` + +## Next Steps + +- Explore the {doc}`tutorials/index` for detailed examples +- Check the {doc}`api/index` for complete function documentation +- Read about the {doc}`theory` behind each model + +## Building the Documentation + +To preview the documentation locally run: + +```bash +cd docs +make html +``` + +More details about our Read the Docs configuration can be found in {doc}`rtd`. + diff --git a/docs/source/index.md b/docs/source/index.md index 5a22562..d3f5ce1 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -1,100 +1,142 @@ -# gen_surv +# gen_surv: Survival Data Simulation in Python -**gen_surv** is a Python package for simulating survival data under various models, inspired by the R package `genSurv`. +[![Documentation Status](https://readthedocs.org/projects/gensurvpy/badge/?version=latest)](https://gensurvpy.readthedocs.io/en/latest/?badge=latest) +[![PyPI version](https://badge.fury.io/py/gen-surv.svg)](https://badge.fury.io/py/gen-surv) +[![Python versions](https://img.shields.io/pypi/pyversions/gen-surv.svg)](https://pypi.org/project/gen-surv/) -It includes generators for: +**gen_surv** is a comprehensive Python package for simulating survival data under various statistical models, inspired by the R package `genSurv`. It provides a unified interface for generating synthetic survival datasets that are essential for: -- **Cox Proportional Hazards Models (CPHM)** -- **Continuous-Time Markov Models (CMM)** -- **Time-Dependent Covariate Models (TDCM)** -- **Time-Homogeneous Hidden Markov Models (THMM)** -- **Accelerated Failure Time (AFT) Log-Normal Models** +- **Research**: Testing new survival analysis methods +- **Education**: Teaching survival analysis concepts +- **Benchmarking**: Comparing different survival models +- **Validation**: Testing statistical software implementations -Key functions include `generate()`, `gen_cphm()`, `gen_cmm()`, `gen_tdcm()`, -`gen_thmm()`, `gen_aft_log_normal()`, `sample_bivariate_distribution()`, -`runifcens()`, and `rexpocens()`. +```{admonition} Quick Start +:class: tip ---- - -See the [Getting Started](usage) guide for installation instructions. - -## πŸ“š Modules - -```{toctree} -:maxdepth: 2 -:caption: Contents +Install with pip: +```bash +pip install gen-surv +``` -usage -modules -theory +Generate your first dataset: +```python +from gen_surv import generate +df = generate(model="cphm", n=100, beta=0.5, covariate_range=2.0) +``` ``` +```{note} +The `to_sksurv` helper and related tests require the optional +dependency `scikit-survival`. 
Install it with `poetry install --with dev` +or `pip install scikit-survival` if you need this functionality. +``` -# πŸš€ Usage Example +## Supported Models -```python -from gen_surv import generate +| Model | Description | Use Case | +|-------|-------------|----------| +| **CPHM** | Cox Proportional Hazards | Standard survival regression | +| **AFT** | Accelerated Failure Time | Non-proportional hazards | +| **CMM** | Continuous-Time Markov | Multi-state processes | +| **TDCM** | Time-Dependent Covariates | Dynamic risk factors | +| **THMM** | Time-Homogeneous Markov | Hidden state processes | +| **Competing Risks** | Multiple event types | Cause-specific hazards | +| **Mixture Cure** | Long-term survivors | Logistic cure fraction | +| **Piecewise Exponential** | Piecewise constant hazard | Flexible baseline | -# CPHM -generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0) +## Algorithm Descriptions -# AFT Log-Normal -generate(model="aft_ln", n=100, beta=[0.5, -0.3], sigma=1.0, model_cens="exponential", cens_par=3.0) +For a brief summary of each statistical model see {doc}`algorithms`. Mathematical +details and notation are provided on the {doc}`theory` page. -# CMM -generate(model="cmm", n=100, model_cens="exponential", cens_par=2.0, - qmat=[[0, 0.1], [0.05, 0]], p0=[1.0, 0.0]) +## Documentation Contents -# TDCM -generate(model="tdcm", n=100, dist="weibull", corr=0.5, - dist_par=[1, 2, 1, 2], model_cens="uniform", cens_par=1.0, - beta=[0.1, 0.2, 0.3], lam=1.0) +```{toctree} +:maxdepth: 2 -# THMM -generate(model="thmm", n=100, qmat=[[0, 0.2, 0], [0.1, 0, 0.1], [0, 0.3, 0]], - emission_pars={"mu": [0.0, 1.0, 2.0], "sigma": [0.5, 0.5, 0.5]}, - p0=[1.0, 0.0, 0.0], model_cens="exponential", cens_par=3.0) +getting_started +tutorials/index +api/index +theory +algorithms +examples/index +troubleshooting +rtd +contributing +changelog +bibliography ``` -## ⌨️ Command-Line Usage - -Generate datasets directly from the terminal: +## Quick Examples -```bash -python -m gen_surv dataset aft_ln --n 100 > data.csv +### Cox Proportional Hazards Model +```python +import gen_surv as gs + +# Basic CPHM with uniform censoring +df = gs.generate( + model="cphm", + n=500, + beta=0.5, + covariate_range=2.0, + model_cens="uniform", + cens_par=3.0 +) ``` -## Repository Layout - -```text -genSurvPy/ -β”œβ”€β”€ gen_surv/ -β”‚ └── ... 
-β”œβ”€β”€ tests/
-β”œβ”€β”€ examples/
-β”œβ”€β”€ docs/
-β”œβ”€β”€ scripts/
-β”œβ”€β”€ tasks.py
-└── TODO.md
+### Accelerated Failure Time Model
+```python
+# AFT with log-normal distribution
+df = gs.generate(
+    model="aft_ln",
+    n=200,
+    beta=[0.5, -0.3, 0.2],
+    sigma=1.0,
+    model_cens="exponential",
+    cens_par=2.0
+)
 ```
 
-## πŸ”— Project Links
+### Multi-State Markov Model
+```python
+# Two-state Markov model (the 2x2 qmat gives one transition rate in each direction)
+df = gs.generate(
+    model="cmm",
+    n=300,
+    qmat=[[0, 0.1], [0.05, 0]],
+    p0=[1.0, 0.0],
+    model_cens="uniform",
+    cens_par=5.0
+)
+```
-
-- [Source Code](https://github.com/DiogoRibeiro7/genSurvPy)
-- [License](https://github.com/DiogoRibeiro7/genSurvPy/blob/main/LICENCE)
-- [Code of Conduct](https://github.com/DiogoRibeiro7/genSurvPy/blob/main/CODE_OF_CONDUCT.md)
+## Key Features
 
+- **Unified Interface**: Single `generate()` function for all models
+- **Flexible Censoring**: Support for uniform and exponential censoring
+- **Rich Parameterization**: Extensive customization options
+- **Command-Line Interface**: Generate datasets from terminal
+- **Comprehensive Validation**: Input parameter checking
+- **Educational Focus**: Clear mathematical documentation
 
 ## Citation
 
-If you use **gen_surv** in your work, please cite it using the metadata in
-[CITATION.cff](../../CITATION.cff).
+If you use gen_surv in your research, please cite:
 
-## Author
+```bibtex
+@software{ribeiro2025gensurvpy,
+  title = {gen_surv: Survival Data Simulation in Python},
+  author = {Diogo Ribeiro},
+  year = {2025},
+  url = {https://github.com/DiogoRibeiro7/genSurvPy},
+  version = {1.0.9}
+}
+```
 
-**Diogo Ribeiro** β€” [ESMAD - Instituto PolitΓ©cnico do Porto](https://esmad.ipp.pt)
+## License
 
-- ORCID:
-- Professional email:
-- Personal email:
+MIT License. See [LICENSE](https://github.com/DiogoRibeiro7/genSurvPy/blob/main/LICENSE) for details.
 
+For foundational papers related to these models, see the {doc}`bibliography`.
+Information on building the docs is provided in the {doc}`rtd` page.
diff --git a/docs/source/modules.md b/docs/source/modules.md
index 114a344..f90396b 100644
--- a/docs/source/modules.md
+++ b/docs/source/modules.md
@@ -1,51 +1,66 @@
+---
+orphan: true
+---
+
 # API Reference
 
 ::: gen_surv.cphm
     options:
       members: true
       undoc-members: true
+      show-inheritance: true
 
 ::: gen_surv.cmm
    options:
      members: true
      undoc-members: true
+      show-inheritance: true
 
 ::: gen_surv.tdcm
    options:
      members: true
      undoc-members: true
+      show-inheritance: true
 
 ::: gen_surv.thmm
    options:
      members: true
      undoc-members: true
+      show-inheritance: true
 
 ::: gen_surv.interface
    options:
      members: true
      undoc-members: true
+      show-inheritance: true
 
 ::: gen_surv.aft
    options:
      members: true
      undoc-members: true
+      show-inheritance: true
 
 ::: gen_surv.bivariate
    options:
      members: true
      undoc-members: true
+      show-inheritance: true
 
 ::: gen_surv.censoring
    options:
      members: true
      undoc-members: true
+      show-inheritance: true
 
 ::: gen_surv.cli
    options:
      members: true
      undoc-members: true
+      show-inheritance: true
 
-::: gen_surv.validate
+::: gen_surv.validation
    options:
      members: true
      undoc-members: true
+      show-inheritance: true
+
diff --git a/docs/source/rtd.md b/docs/source/rtd.md
new file mode 100644
index 0000000..7eac430
--- /dev/null
+++ b/docs/source/rtd.md
@@ -0,0 +1,20 @@
+---
+orphan: true
+---
+
+# Read the Docs
+
+This project uses [Read the Docs](https://readthedocs.org/) to host its documentation. The site is automatically rebuilt whenever changes are pushed to the repository.
+
+Our build configuration is defined in `.readthedocs.yml`. 
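+As a rough orientation, a Read the Docs v2 configuration of this shape might
+look like the following; this is a hypothetical sketch for illustration, not
+the repository's actual file, and the `python.install` section in particular
+is an assumption:
+
+```yaml
+# Hypothetical sketch of a .readthedocs.yml (v2 schema); the real file may differ
+version: 2
+build:
+  os: ubuntu-22.04
+  tools:
+    python: "3.11"
+python:
+  install:
+    - method: pip
+      path: .
+sphinx:
+  configuration: docs/source/conf.py
+```
+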
It installs the package with the `docs` dependency group and builds the Sphinx docs using Python 3.11. + +## Building Locally + +To preview the documentation on your machine, run: + +```bash +cd docs +make html +``` + +Open `build/html/index.html` in your browser to view the result. diff --git a/docs/source/theory.md b/docs/source/theory.md index 957dce1..1101034 100644 --- a/docs/source/theory.md +++ b/docs/source/theory.md @@ -1,3 +1,7 @@ +--- +orphan: true +--- + # πŸ“˜ Mathematical Foundations of `gen_surv` This page presents the mathematical formulation behind the survival models implemented in the `gen_surv` package. @@ -6,7 +10,9 @@ This page presents the mathematical formulation behind the survival models imple ## 1. Cox Proportional Hazards Model (CPHM) -The hazard function conditioned on covariates is: +This semi-parametric approach models the hazard as a baseline component +multiplied by an exponential term involving the covariates. The hazard +function conditioned on covariates is: $$ h(t \mid X) = h_0(t) \exp(X \\beta) @@ -39,7 +45,8 @@ $$ ## 2. Time-Dependent Covariate Model (TDCM) -A generalization of CPHM where covariates change over time: +This extension of the Cox model allows covariate values to vary during +follow-up, accommodating exposures or treatments that change over time: $$ h(t \mid Z(t)) = h_0(t) \\exp(Z(t) \\beta) @@ -51,7 +58,8 @@ In this package, piecewise covariate values are simulated with dependence across ## 3. Continuous-Time Multi-State Markov Model (CMM) -Markov model with generator matrix \( Q \). The transition probability matrix is given by: +This framework captures transitions between a finite set of states where +waiting times are exponentially distributed. With generator matrix \( Q \), the transition probability matrix is given by: $$ P(t) = \\exp(Qt) @@ -66,7 +74,9 @@ Where: ## 4. Time-Homogeneous Hidden Markov Model (THMM) -This model simulates observed states with unobserved latent state transitions. +This model handles situations where the process evolves through unobserved +states that generate the observed responses. It simulates observed states with +latent transitions. Let: @@ -84,7 +94,9 @@ $$ ## 5. Accelerated Failure Time (AFT) Models -AFT models assume that the effect of covariates accelerates or decelerates time to event directly, rather than the hazard. +These fully parametric models relate covariates to the logarithm of the +survival time. They assume the effect of a covariate speeds up or slows down the +event time directly, rather than acting on the hazard. ### Log-Normal AFT @@ -116,5 +128,25 @@ This model is especially useful when the proportional hazards assumption is not All models support censoring: -- **Uniform:** \( C_i \\sim U(0, \\text{cens\\_par}) \) -- **Exponential:** \( C_i \\sim \\text{Exp}(\\text{cens\\_par}) \) +- **Uniform:** \( C_i \sim U(0, \text{cens\_par}) \) +- **Exponential:** \( C_i \sim \text{Exp}(\text{cens\_par}) \) + +## 6. Competing Risks Models + +These models handle scenarios where several distinct failure types can occur. +Each cause has its own hazard function, and the observed status indicates which +event occurred (1, 2, ...). The package includes constant-hazard and +Weibull-hazard versions. + +## 7. Mixture Cure Models + +These models posit that a subset of the population is cured and will never +experience the event of interest. 
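+In the standard formulation, if \( \pi \) denotes the probability of being
+cured and \( S_u(t) \) the survival function of the uncured, the population
+survival function is
+
+$$
+S(t) = \pi + (1 - \pi) S_u(t)
+$$
+
+so the survival curve plateaus at \( \pi \) for large \( t \).
+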
The generator mixes a logistic cure component
+with an exponential hazard for the uncured, returning a ``cured`` indicator
+column alongside the usual time and status.
+
+## 8. Piecewise Exponential Model
+
+Here the baseline hazard is assumed constant within each of several
+user-specified intervals. This allows flexible hazard shapes over time while
+remaining easy to simulate.
diff --git a/docs/source/troubleshooting.md b/docs/source/troubleshooting.md
new file mode 100644
index 0000000..767de20
--- /dev/null
+++ b/docs/source/troubleshooting.md
@@ -0,0 +1,36 @@
+# Troubleshooting
+
+Common issues and how to resolve them when using gen_surv.
+
+## ModuleNotFoundError: No module named 'gen_surv'
+
+Ensure the package is installed. If you're running from source, install it in editable mode:
+
+```bash
+pip install -e .
+```
+
+or with Poetry:
+
+```bash
+poetry install
+```
+
+## Validation errors when generating data
+
+Many generators validate their inputs and raise ``ValidationError`` with
+context such as ``while validating inputs for model 'cphm'``. Verify that
+numeric parameters are within the allowed ranges and that sequences have the
+correct length.
+
+## Inconsistent results between runs
+
+Most generators accept a ``seed`` parameter. Set it for reproducibility:
+
+```python
+from gen_surv import generate
+
+df = generate(model="cphm", n=100, beta=0.5, covariate_range=2.0,
+              model_cens="uniform", cens_par=1.0, seed=42)
+```
+
diff --git a/docs/source/tutorials/basic_usage.md b/docs/source/tutorials/basic_usage.md
new file mode 100644
index 0000000..434e600
--- /dev/null
+++ b/docs/source/tutorials/basic_usage.md
@@ -0,0 +1,130 @@
+# Basic Usage Tutorial
+
+This tutorial covers the fundamentals of generating survival data with gen_surv.
+
+## Your First Dataset
+
+Let's start with the simplest case, generating data from a Cox proportional hazards model:
+
+```python
+from gen_surv import generate
+import pandas as pd
+
+# Generate basic CPHM data
+df = generate(
+    model="cphm",
+    n=200,
+    beta=0.7,
+    covariate_range=1.5,
+    model_cens="exponential",
+    cens_par=2.0,
+    seed=42  # For reproducibility
+)
+
+print(f"Dataset shape: {df.shape}")
+print(f"Event rate: {df['status'].mean():.2%}")
+print("\nFirst 5 rows:")
+print(df.head())
+```
+
+## Understanding Parameters
+
+### Common Parameters
+
+All models share these parameters:
+
+- `n`: Sample size (number of individuals)
+- `model_cens`: Censoring type ("uniform" or "exponential")
+- `cens_par`: Censoring distribution parameter
+- `seed`: Random seed for reproducibility
+
+### Model-Specific Parameters
+
+Each model has unique parameters. 
For CPHM: + +- `beta`: Covariate effect (hazard ratio = exp(beta)) +- `covariate_range`: Range for uniform covariate generation [0, covariate_range] + +## Censoring Mechanisms + +gen_surv supports two censoring types: + +### Uniform Censoring +```python +# Censoring times uniformly distributed on [0, cens_par] +df_uniform = generate( + model="cphm", + n=100, + beta=0.5, + covariate_range=2.0, + model_cens="uniform", + cens_par=3.0 +) +``` + +### Exponential Censoring +```python +# Censoring times exponentially distributed with mean cens_par +df_exponential = generate( + model="cphm", + n=100, + beta=0.5, + covariate_range=2.0, + model_cens="exponential", + cens_par=2.0 +) +``` + +## Exploring Your Data + +Basic data exploration: + +```python +import matplotlib.pyplot as plt +import seaborn as sns + +# Event rate by covariate level +fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4)) + +# Histogram of survival times +ax1.hist(df['time'], bins=20, alpha=0.7, edgecolor='black') +ax1.set_xlabel('Time') +ax1.set_ylabel('Frequency') +ax1.set_title('Distribution of Observed Times') + +# Event rate vs covariate +df['covariate_bin'] = pd.cut(df['covariate'], bins=5) +event_rate = df.groupby('covariate_bin')['status'].mean() +event_rate.plot(kind='bar', ax=ax2, rot=45) +ax2.set_ylabel('Event Rate') +ax2.set_title('Event Rate by Covariate Level') + +plt.tight_layout() +plt.show() +``` + +## Additional Example: Mixture Cure Model + +The mixture cure model separates subjects into cured and susceptible groups. +Here's how to simulate data using this model: + +```python +from gen_surv import generate + +df_mixture = generate( + model="mixture_cure", + n=200, + cure_fraction=0.3, + betas_survival=[0.8, -0.4], + betas_cure=[-0.6, 0.2], + seed=123, +) + +print(df_mixture[["time", "status", "cured"]].head()) +``` + +## Next Steps + +- Try different models (model_comparison) +- Learn advanced features (advanced_features) +- See integration examples (integration_examples) diff --git a/docs/source/tutorials/index.md b/docs/source/tutorials/index.md new file mode 100644 index 0000000..bc3af37 --- /dev/null +++ b/docs/source/tutorials/index.md @@ -0,0 +1,13 @@ +--- +orphan: true +--- + +# Tutorials + +Step-by-step guides for using gen_surv effectively. + +```{toctree} +:maxdepth: 2 + +basic_usage +``` diff --git a/docs/source/usage.md b/docs/source/usage.md index e3f856e..69afc64 100644 --- a/docs/source/usage.md +++ b/docs/source/usage.md @@ -1,3 +1,7 @@ +--- +orphan: true +--- + # Getting Started This page offers a quick introduction to installing and using **gen_surv**. @@ -17,10 +21,20 @@ This will create a virtual environment and install all required packages. Generate datasets directly in Python: ```python -from gen_surv import generate +from gen_surv import export_dataset, generate # Cox Proportional Hazards example -generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0) +df = generate( + model="cphm", + n=100, + model_cens="uniform", + cens_par=1.0, + beta=0.5, + covariate_range=2.0, +) + +# Save to RDS for use in R +export_dataset(df, "simulated_data.rds") ``` You can also generate data from the command line: @@ -31,3 +45,45 @@ python -m gen_surv dataset aft_ln --n 100 > data.csv For a full description of available models and parameters, see the API reference. + +## Building the Documentation + +Documentation is written using [Sphinx](https://www.sphinx-doc.org). 
To build the HTML pages locally run: + +```bash +cd docs +make html +``` + +The generated files will be available under `docs/build/html`. + +## Scikit-learn Integration + +You can wrap the generator in a transformer compatible with scikit-learn: + +```python +from gen_surv import GenSurvDataGenerator + +est = GenSurvDataGenerator("cphm", n=10, beta=0.5, covariate_range=1.0) +df = est.fit_transform() +``` + +## Lifelines and scikit-survival + +Datasets generated with **gen_surv** can be directly used with +[lifelines](https://lifelines.readthedocs.io). For +[scikit-survival](https://scikit-survival.readthedocs.io) you can convert the +DataFrame using ``to_sksurv``: + +```{note} +The ``to_sksurv`` helper requires the optional dependency +``scikit-survival``. Install it with `poetry install --with dev` or +``pip install scikit-survival``. +``` + +```python +from gen_surv import to_sksurv + +struct = to_sksurv(df) +``` + diff --git a/examples/notebooks/cmm.ipynb b/examples/notebooks/cmm.ipynb new file mode 100644 index 0000000..303b3ff --- /dev/null +++ b/examples/notebooks/cmm.ipynb @@ -0,0 +1,45 @@ +{ + "cells": [ + { + "cell_type": "code", + "metadata": {}, + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "from gen_surv import generate\n", + "\n", + "np.random.seed(0)\n", + "df = generate(\n", + " model=\"cmm\",\n", + " n=200,\n", + " model_cens=\"exponential\",\n", + " cens_par=2.0,\n", + " beta=[0.1, 0.2, 0.3],\n", + " covariate_range=1.0,\n", + " rate=[0.1, 1.0, 0.2, 1.0, 0.1, 1.0],\n", + ")\n", + "df.head()\n", + "plt.hist(df['stop'], bins=20, color='#4C72B0')\n", + "plt.xlabel('Time')\n", + "plt.ylabel('Frequency')\n", + "plt.title('CMM Transition Times')\n", + "plt.show()\n" + ], + "outputs": [], + "execution_count": null + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/tdcm.ipynb b/examples/notebooks/tdcm.ipynb new file mode 100644 index 0000000..0f9a91f --- /dev/null +++ b/examples/notebooks/tdcm.ipynb @@ -0,0 +1,47 @@ +{ + "cells": [ + { + "cell_type": "code", + "metadata": {}, + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "from gen_surv import generate\n", + "\n", + "np.random.seed(0)\n", + "df = generate(\n", + " model=\"tdcm\",\n", + " n=200,\n", + " dist=\"weibull\",\n", + " corr=0.5,\n", + " dist_par=[1, 2, 1, 2],\n", + " model_cens=\"uniform\",\n", + " cens_par=1.0,\n", + " beta=[0.1, 0.2, 0.3],\n", + " lam=1.0,\n", + ")\n", + "df.head()\n", + "plt.hist(df['stop'], bins=20, color='#4C72B0')\n", + "plt.xlabel('Time')\n", + "plt.ylabel('Frequency')\n", + "plt.title('TDCM Event Times')\n", + "plt.show()\n" + ], + "outputs": [], + "execution_count": null + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/thmm.ipynb b/examples/notebooks/thmm.ipynb new file mode 100644 index 0000000..c46463c --- /dev/null +++ b/examples/notebooks/thmm.ipynb @@ -0,0 +1,45 @@ +{ + "cells": [ + { + "cell_type": "code", + "metadata": {}, + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "from gen_surv import generate\n", + "\n", + "np.random.seed(0)\n", + "df = generate(\n", + " 
model=\"thmm\",\n", + " n=200,\n", + " model_cens=\"exponential\",\n", + " cens_par=3.0,\n", + " beta=[0.1, 0.2, 0.3],\n", + " covariate_range=1.0,\n", + " rate=[0.2, 0.1, 0.3],\n", + ")\n", + "df.head()\n", + "plt.hist(df['time'], bins=20, color='#4C72B0')\n", + "plt.xlabel('Time')\n", + "plt.ylabel('Frequency')\n", + "plt.title('THMM Event Times')\n", + "plt.show()\n" + ], + "outputs": [], + "execution_count": null + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/run_aft.py b/examples/run_aft.py index 5ba6388..c2be370 100644 --- a/examples/run_aft.py +++ b/examples/run_aft.py @@ -1,6 +1,7 @@ -import sys import os -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) +import sys + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from gen_surv.interface import generate # Generate synthetic survival data using Log-Normal AFT model @@ -11,7 +12,7 @@ sigma=1.0, model_cens="exponential", cens_par=3.0, - seed=123 + seed=123, ) print(df.head()) diff --git a/examples/run_aft_weibull.py b/examples/run_aft_weibull.py new file mode 100644 index 0000000..b213472 --- /dev/null +++ b/examples/run_aft_weibull.py @@ -0,0 +1,94 @@ +""" +Example demonstrating Weibull AFT model and visualization capabilities. +""" + +import os +import sys + +import matplotlib.pyplot as plt +import pandas as pd + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +from gen_surv import generate +from gen_surv.visualization import ( + describe_survival, + plot_covariate_effect, + plot_hazard_comparison, + plot_survival_curve, +) + +# 1. Generate data from different models for comparison +models = { + "Weibull AFT (shape=0.5)": generate( + model="aft_weibull", + n=200, + beta=[0.5, -0.3], + shape=0.5, # Decreasing hazard + scale=2.0, + model_cens="uniform", + cens_par=5.0, + seed=42, + ), + "Weibull AFT (shape=1.0)": generate( + model="aft_weibull", + n=200, + beta=[0.5, -0.3], + shape=1.0, # Constant hazard + scale=2.0, + model_cens="uniform", + cens_par=5.0, + seed=42, + ), + "Weibull AFT (shape=2.0)": generate( + model="aft_weibull", + n=200, + beta=[0.5, -0.3], + shape=2.0, # Increasing hazard + scale=2.0, + model_cens="uniform", + cens_par=5.0, + seed=42, + ), +} + +# Print sample data +print("Sample data from Weibull AFT model (shape=2.0):") +print(models["Weibull AFT (shape=2.0)"].head()) +print("\n") + +# 2. Compare survival curves from different models +fig1, ax1 = plot_survival_curve( + data=pd.concat([df.assign(_model=name) for name, df in models.items()]), + group_col="_model", + title="Comparing Survival Curves with Different Weibull Shapes", +) +plt.savefig("survival_curve_comparison.png", dpi=300, bbox_inches="tight") + +# 3. Compare hazard functions +fig2, ax2 = plot_hazard_comparison( + models=models, title="Comparing Hazard Functions with Different Weibull Shapes" +) +plt.savefig("hazard_comparison.png", dpi=300, bbox_inches="tight") + +# 4. Visualize covariate effect on survival +fig3, ax3 = plot_covariate_effect( + data=models["Weibull AFT (shape=2.0)"], + covariate_col="X0", + n_groups=3, + title="Effect of X0 Covariate on Survival", +) +plt.savefig("covariate_effect.png", dpi=300, bbox_inches="tight") + +# 5. 
Summary statistics +for name, df in models.items(): + print(f"Summary for {name}:") + summary = describe_survival(df) + print(summary) + print("\n") + +print("Plots saved to current directory.") + +# Show plots if running interactively +if __name__ == "__main__": + plt.show() diff --git a/examples/run_cmm.py b/examples/run_cmm.py index 590b7a1..52a11ff 100644 --- a/examples/run_cmm.py +++ b/examples/run_cmm.py @@ -1,6 +1,7 @@ -import sys import os -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) +import sys + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from gen_surv import generate @@ -11,7 +12,7 @@ cens_par=2.0, qmat=[[0, 0.1], [0.05, 0]], p0=[1.0, 0.0], - seed=42 + seed=42, ) print(df.head()) diff --git a/examples/run_competing_risks.py b/examples/run_competing_risks.py new file mode 100644 index 0000000..a22871a --- /dev/null +++ b/examples/run_competing_risks.py @@ -0,0 +1,143 @@ +""" +Example demonstrating the Competing Risks models and visualization. +""" + +import os +import sys + +import matplotlib.pyplot as plt +import numpy as np + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +from gen_surv import generate +from gen_surv.competing_risks import ( + cause_specific_cumulative_incidence, + gen_competing_risks, + gen_competing_risks_weibull, +) +from gen_surv.summary import compare_survival_datasets, summarize_survival_dataset + + +def plot_cause_specific_cumulative_incidence(df, time_points=None, figsize=(10, 6)): + """Plot the cause-specific cumulative incidence functions.""" + if time_points is None: + max_time = df["time"].max() + time_points = np.linspace(0, max_time, 100) + + # Get unique causes (excluding censoring) + causes = sorted([c for c in df["status"].unique() if c > 0]) + + # Create the plot + fig, ax = plt.subplots(figsize=figsize) + + for cause in causes: + cif = cause_specific_cumulative_incidence(df, time_points, cause=cause) + ax.plot(cif["time"], cif["incidence"], label=f"Cause {cause}") + + # Add overlay showing number of subjects at each time + time_bins = np.linspace(0, df["time"].max(), 10) + event_counts = np.histogram(df.loc[df["status"] > 0, "time"], bins=time_bins)[0] + + # Add a secondary y-axis for event counts + ax2 = ax.twinx() + ax2.bar( + time_bins[:-1], + event_counts, + width=time_bins[1] - time_bins[0], + alpha=0.2, + color="gray", + align="edge", + ) + ax2.set_ylabel("Number of events") + ax2.grid(False) + + # Format the main plot + ax.set_xlabel("Time") + ax.set_ylabel("Cumulative Incidence") + ax.set_title("Cause-Specific Cumulative Incidence Functions") + ax.legend() + ax.grid(alpha=0.3) + + return fig, ax + + +# 1. Generate data with 2 competing risks +print("Generating data with exponential hazards...") +data_exponential = gen_competing_risks( + n=500, + n_risks=2, + baseline_hazards=[0.5, 0.3], + betas=[[0.8, -0.5], [0.2, 0.7]], + model_cens="uniform", + cens_par=2.0, + seed=42, +) + +# 2. Generate data with Weibull hazards (different shapes) +print("Generating data with Weibull hazards...") +data_weibull = gen_competing_risks_weibull( + n=500, + n_risks=2, + shape_params=[0.8, 1.5], # Decreasing vs increasing hazard + scale_params=[2.0, 3.0], + betas=[[0.8, -0.5], [0.2, 0.7]], + model_cens="uniform", + cens_par=2.0, + seed=42, +) + +# 3. 
Print summary statistics for both datasets +print("\nSummary of Exponential Hazards dataset:") +summarize_survival_dataset(data_exponential) + +print("\nSummary of Weibull Hazards dataset:") +summarize_survival_dataset(data_weibull) + +# 4. Compare event distributions +print("\nEvent distribution (Exponential Hazards):") +print(data_exponential["status"].value_counts()) + +print("\nEvent distribution (Weibull Hazards):") +print(data_weibull["status"].value_counts()) + +# 5. Plot cause-specific cumulative incidence functions +print("\nPlotting cumulative incidence functions...") +time_points = np.linspace(0, 5, 100) + +fig1, ax1 = plot_cause_specific_cumulative_incidence( + data_exponential, time_points=time_points, figsize=(10, 6) +) +plt.title("Cumulative Incidence Functions (Exponential Hazards)") +plt.savefig("cr_exponential_cif.png", dpi=300, bbox_inches="tight") + +fig2, ax2 = plot_cause_specific_cumulative_incidence( + data_weibull, time_points=time_points, figsize=(10, 6) +) +plt.title("Cumulative Incidence Functions (Weibull Hazards)") +plt.savefig("cr_weibull_cif.png", dpi=300, bbox_inches="tight") + +# 6. Demonstrate using the unified generate() interface +print("\nUsing the unified generate() interface:") +data_unified = generate( + model="competing_risks", + n=100, + n_risks=2, + baseline_hazards=[0.5, 0.3], + betas=[[0.8, -0.5], [0.2, 0.7]], + model_cens="uniform", + cens_par=2.0, + seed=42, +) +print(data_unified.head()) + +# 7. Compare datasets +print("\nComparing datasets:") +comparison = compare_survival_datasets( + {"Exponential": data_exponential, "Weibull": data_weibull} +) +print(comparison) + +# Show plots if running interactively +if __name__ == "__main__": + plt.show() diff --git a/examples/run_cphm.py b/examples/run_cphm.py index c02b01b..ebcc138 100644 --- a/examples/run_cphm.py +++ b/examples/run_cphm.py @@ -1,6 +1,7 @@ -import sys import os -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) +import sys + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from gen_surv import generate @@ -10,8 +11,8 @@ model_cens="uniform", cens_par=1.0, beta=0.5, - covar=2.0, - seed=42 + covariate_range=2.0, + seed=42, ) print(df.head()) diff --git a/examples/run_tdcm.py b/examples/run_tdcm.py index dd5204c..c05ccf9 100644 --- a/examples/run_tdcm.py +++ b/examples/run_tdcm.py @@ -1,6 +1,7 @@ -import sys import os -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) +import sys + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from gen_surv import generate @@ -14,7 +15,7 @@ cens_par=1.0, beta=[0.1, 0.2, 0.3], lam=1.0, - seed=42 + seed=42, ) print(df.head()) diff --git a/examples/run_thmm.py b/examples/run_thmm.py index 73721ad..038699d 100644 --- a/examples/run_thmm.py +++ b/examples/run_thmm.py @@ -1,6 +1,7 @@ -import sys import os -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) +import sys + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from gen_surv import generate @@ -12,7 +13,7 @@ p0=[1.0, 0.0, 0.0], model_cens="exponential", cens_par=3.0, - seed=42 + seed=42, ) print(df.head()) diff --git a/gen_surv/__init__.py b/gen_surv/__init__.py index 8939886..65f9723 100644 --- a/gen_surv/__init__.py +++ b/gen_surv/__init__.py @@ -1,17 +1,97 @@ """Top-level package for ``gen_surv``. 
-This module exposes the :func:`generate` function and provides access to the -package version via ``__version__``. +This module exposes the main functions and provides access to the package version. """ from importlib.metadata import PackageNotFoundError, version +from .aft import gen_aft_log_logistic, gen_aft_log_normal, gen_aft_weibull + +# Helper functions +from .bivariate import sample_bivariate_distribution +from .censoring import ( + CensoringModel, + GammaCensoring, + LogNormalCensoring, + WeibullCensoring, + rexpocens, + rgammacens, + rlognormcens, + runifcens, + rweibcens, +) +from .cmm import gen_cmm +from .competing_risks import gen_competing_risks, gen_competing_risks_weibull + +# Individual generators +from .cphm import gen_cphm +from .export import export_dataset +from .integration import to_sksurv + +# Main interface from .interface import generate +from .mixture import cure_fraction_estimate, gen_mixture_cure +from .piecewise import gen_piecewise_exponential +from .sklearn_adapter import GenSurvDataGenerator +from .tdcm import gen_tdcm +from .thmm import gen_thmm + +# Visualization tools (requires matplotlib and lifelines) +try: + from .visualization import describe_survival # noqa: F401 + from .visualization import plot_covariate_effect # noqa: F401 + from .visualization import plot_hazard_comparison # noqa: F401 + from .visualization import plot_survival_curve # noqa: F401 + + _has_visualization = True +except ImportError: + _has_visualization = False try: __version__ = version("gen_surv") except PackageNotFoundError: # pragma: no cover - fallback when package not installed __version__ = "0.0.0" -__all__ = ["generate", "__version__"] +__all__ = [ + # Main interface + "generate", + "__version__", + # Individual generators + "gen_cphm", + "gen_cmm", + "gen_tdcm", + "gen_thmm", + "gen_aft_log_normal", + "gen_aft_weibull", + "gen_aft_log_logistic", + "gen_competing_risks", + "gen_competing_risks_weibull", + "gen_mixture_cure", + "cure_fraction_estimate", + "gen_piecewise_exponential", + # Helpers + "sample_bivariate_distribution", + "runifcens", + "rexpocens", + "rweibcens", + "rlognormcens", + "rgammacens", + "WeibullCensoring", + "LogNormalCensoring", + "GammaCensoring", + "CensoringModel", + "export_dataset", + "to_sksurv", + "GenSurvDataGenerator", +] +# Add visualization tools to __all__ if available +if _has_visualization: + __all__.extend( + [ + "plot_survival_curve", + "plot_hazard_comparison", + "plot_covariate_effect", + "describe_survival", + ] + ) diff --git a/gen_surv/_covariates.py b/gen_surv/_covariates.py new file mode 100644 index 0000000..1d46c9d --- /dev/null +++ b/gen_surv/_covariates.py @@ -0,0 +1,222 @@ +"""Utilities for generating covariate matrices with validation.""" + +from typing import Literal, Sequence + +import numpy as np +from numpy.random import Generator +from numpy.typing import NDArray + +from .validation import ( + LengthError, + ListOfListsError, + NumericSequenceError, + ParameterError, + ensure_numeric_sequence, + ensure_positive, + ensure_sequence_length, +) + +_CovParams = dict[str, float] + + +def set_covariate_params( + covariate_dist: Literal["normal", "uniform", "binary"], + covariate_params: _CovParams | None, +) -> _CovParams: + """Return covariate distribution parameters with defaults filled in. + + Parameters + ---------- + covariate_dist : {"normal", "uniform", "binary"} + Distribution used to sample covariates. + covariate_params : dict[str, float], optional + Parameters specific to the chosen distribution. 
Missing keys are + populated with sensible defaults. + + Returns + ------- + dict[str, float] + Completed parameter dictionary for ``covariate_dist``. + """ + if covariate_dist == "normal": + if covariate_params is None: + return {"mean": 0.0, "std": 1.0} + if {"mean", "std"} <= covariate_params.keys(): + return covariate_params + raise ParameterError( + "covariate_params", + covariate_params, + "must include 'mean' and 'std'", + ) + if covariate_dist == "uniform": + if covariate_params is None: + return {"low": 0.0, "high": 1.0} + if {"low", "high"} <= covariate_params.keys(): + return covariate_params + raise ParameterError( + "covariate_params", + covariate_params, + "must include 'low' and 'high'", + ) + if covariate_dist == "binary": + if covariate_params is None: + return {"p": 0.5} + if "p" in covariate_params: + return covariate_params + raise ParameterError("covariate_params", covariate_params, "must include 'p'") + raise ParameterError( + "covariate_dist", + covariate_dist, + "unsupported covariate distribution; choose from 'normal', 'uniform', or 'binary'", + ) + + +def _get_float(params: _CovParams, key: str, default: float) -> float: + val = params.get(key, default) + if not isinstance(val, (int, float)): + raise ParameterError(f"covariate_params['{key}']", val, "must be a number") + return float(val) + + +def generate_covariates( + n: int, + n_covariates: int, + covariate_dist: Literal["normal", "uniform", "binary"], + covariate_params: _CovParams, + rng: Generator, +) -> NDArray[np.float64]: + """Generate covariate matrix according to the specified distribution. + + Parameters + ---------- + n : int + Number of samples to generate. + n_covariates : int + Number of covariate columns. + covariate_dist : {"normal", "uniform", "binary"} + Distribution used to sample covariates. + covariate_params : dict[str, float] + Parameters specific to ``covariate_dist``. + rng : Generator + Random number generator used for sampling. + + Returns + ------- + NDArray[np.float64] + Matrix of shape ``(n, n_covariates)`` containing sampled covariates. + """ + if covariate_dist == "normal": + std = _get_float(covariate_params, "std", 1.0) + ensure_positive(std, "covariate_params['std']") + mean = _get_float(covariate_params, "mean", 0.0) + return rng.normal(mean, std, size=(n, n_covariates)) + if covariate_dist == "uniform": + low = _get_float(covariate_params, "low", 0.0) + high = _get_float(covariate_params, "high", 1.0) + if high <= low: + raise ParameterError( + "covariate_params['high']", + high, + "must be greater than 'low'", + ) + return rng.uniform(low, high, size=(n, n_covariates)) + if covariate_dist == "binary": + p = _get_float(covariate_params, "p", 0.5) + if not 0 <= p <= 1: + raise ParameterError( + "covariate_params['p']", + p, + "must be between 0 and 1", + ) + return rng.binomial(1, p, size=(n, n_covariates)).astype(float) + raise ParameterError( + "covariate_dist", + covariate_dist, + "unsupported covariate distribution; choose from 'normal', 'uniform', or 'binary'", + ) + + +def prepare_betas( + betas: Sequence[float] | None, + n_covariates: int, + rng: Generator, + *, + name: str = "betas", + enforce_length: bool = False, +) -> tuple[NDArray[np.float64], int]: + """Return coefficient array, generating defaults when needed. + + Parameters + ---------- + betas : sequence of float, optional + Coefficient values. If ``None``, random values are generated. + n_covariates : int + Expected number of coefficients when generating defaults. 
+ rng : Generator + Random number generator used when ``betas`` is ``None``. + name : str, optional + Name used in error messages. + enforce_length : bool, optional + If ``True``, raise an error when ``betas`` does not have exactly + ``n_covariates`` elements. + + Returns + ------- + tuple[NDArray[np.float64], int] + A tuple containing the coefficient array and the number of + covariates represented by it. + """ + if betas is None: + return rng.normal(0, 0.5, size=n_covariates), n_covariates + ensure_numeric_sequence(betas, name) + arr = np.asarray(betas, dtype=float) + if enforce_length and len(arr) != n_covariates: + raise LengthError(name, len(arr), n_covariates) + return arr, len(arr) + + +def prepare_betas_matrix( + betas: Sequence[Sequence[float]] | None, + n_risks: int, + n_covariates: int, + rng: Generator, + *, + name: str = "betas", +) -> tuple[NDArray[np.float64], int]: + """Return coefficient matrix for multiple risks. + + Parameters + ---------- + betas : sequence of sequence of float, optional + Coefficient matrix where each sub-sequence corresponds to a risk. + Random values are generated when ``None``. + n_risks : int + Number of competing risks. + n_covariates : int + Number of covariates per risk. + rng : Generator + Random number generator used when ``betas`` is ``None``. + name : str, optional + Name used in error messages. + + Returns + ------- + tuple[NDArray[np.float64], int] + A tuple containing the coefficient matrix of shape + ``(n_risks, n_covariates)`` and the number of covariates. + """ + if betas is None: + return rng.normal(0, 0.5, size=(n_risks, n_covariates)), n_covariates + if not isinstance(betas, Sequence) or any( + not isinstance(b, Sequence) for b in betas + ): + raise ListOfListsError(name, betas) + arr = np.asarray(betas, dtype=float) + ensure_sequence_length(arr, n_risks, name) + for j in range(n_risks): + ensure_numeric_sequence(arr[j], f"{name}[{j}]") + nonfinite = np.where(~np.isfinite(arr[j]))[0] + if nonfinite.size: + idx = int(nonfinite[0]) + raise NumericSequenceError(f"{name}[{j}]", arr[j][idx], idx) + return arr, arr.shape[1] diff --git a/gen_surv/aft.py b/gen_surv/aft.py index 5c85fb9..2ad814a 100644 --- a/gen_surv/aft.py +++ b/gen_surv/aft.py @@ -1,46 +1,251 @@ +""" +Accelerated Failure Time (AFT) models including Weibull, Log-Normal, and Log-Logistic distributions. +""" + +from typing import List, Literal + import numpy as np import pandas as pd +from .censoring import rexpocens, runifcens +from .validation import ( + validate_gen_aft_log_logistic_inputs, + validate_gen_aft_log_normal_inputs, + validate_gen_aft_weibull_inputs, +) -def gen_aft_log_normal(n, beta, sigma, model_cens, cens_par, seed=None): + +def gen_aft_log_normal( + n: int, + beta: List[float], + sigma: float, + model_cens: Literal["uniform", "exponential"], + cens_par: float, + seed: int | None = None, +) -> pd.DataFrame: """ Simulate survival data under a Log-Normal Accelerated Failure Time (AFT) model. 
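+
+    The simulated times satisfy ``log T = X @ beta + sigma * epsilon`` with
+    ``epsilon`` standard normal, so covariates act multiplicatively on the
+    time scale.
+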
-    Parameters:
-    - n (int): Number of individuals
-    - beta (list of float): Coefficients for covariates
-    - sigma (float): Standard deviation of the log-error term
-    - model_cens (str): 'uniform' or 'exponential'
-    - cens_par (float): Parameter for censoring distribution
-    - seed (int, optional): Random seed
+    Parameters
+    ----------
+    n : int
+        Number of individuals
+    beta : list of float
+        Coefficients for covariates
+    sigma : float
+        Standard deviation of the log-error term
+    model_cens : {"uniform", "exponential"}
+        Censoring mechanism
+    cens_par : float
+        Parameter for censoring distribution
+    seed : int, optional
+        Random seed for reproducibility
+
+    Returns
+    -------
+    pd.DataFrame
+        DataFrame with columns ['id', 'time', 'status', 'X0', ..., 'Xp']
 
-    Returns:
-    - pd.DataFrame: DataFrame with columns ['id', 'time', 'status', 'X0', ..., 'Xp']
+    Examples
+    --------
+    >>> from gen_surv.aft import gen_aft_log_normal
+    >>> df = gen_aft_log_normal(
+    ...     n=100,
+    ...     beta=[0.5, -0.3],
+    ...     sigma=1.0,
+    ...     model_cens="uniform",
+    ...     cens_par=2.0,
+    ...     seed=42,
+    ... )
+    >>> df.head()
     """
-    if seed is not None:
-        np.random.seed(seed)
+    rng = np.random.default_rng(seed)
+    validate_gen_aft_log_normal_inputs(n, beta, sigma, model_cens, cens_par)
 
     p = len(beta)
-    X = np.random.normal(size=(n, p))
-    epsilon = np.random.normal(loc=0.0, scale=sigma, size=n)
+    X = rng.normal(size=(n, p))
+    epsilon = rng.normal(loc=0.0, scale=sigma, size=n)
     log_T = X @ np.array(beta) + epsilon
     T = np.exp(log_T)
 
-    if model_cens == "uniform":
-        C = np.random.uniform(0, cens_par, size=n)
-    elif model_cens == "exponential":
-        C = np.random.exponential(scale=cens_par, size=n)
-    else:
-        raise ValueError("model_cens must be 'uniform' or 'exponential'")
+    rfunc = runifcens if model_cens == "uniform" else rexpocens
+    C = rfunc(n, cens_par, rng)
+
+    observed_time = np.minimum(T, C)
+    status = (T <= C).astype(int)
+
+    data = pd.DataFrame({"id": np.arange(n), "time": observed_time, "status": status})
+
+    for j in range(p):
+        data[f"X{j}"] = X[:, j]
+
+    return data
+
+
+def gen_aft_weibull(
+    n: int,
+    beta: List[float],
+    shape: float,
+    scale: float,
+    model_cens: Literal["uniform", "exponential"],
+    cens_par: float,
+    seed: int | None = None,
+) -> pd.DataFrame:
+    """
+    Simulate survival data under a Weibull Accelerated Failure Time (AFT) model.
+
+    The Weibull AFT model has survival function:
+        S(t|X) = exp(-(t/scale)^shape * exp(X*beta))
+
+    Parameters
+    ----------
+    n : int
+        Number of individuals
+    beta : list of float
+        Coefficients for covariates
+    shape : float
+        Weibull shape parameter (k > 0)
+    scale : float
+        Weibull scale parameter (Ξ» > 0)
+    model_cens : {"uniform", "exponential"}
+        Censoring mechanism
+    cens_par : float
+        Parameter for censoring distribution
+    seed : int, optional
+        Random seed for reproducibility
+
+    Returns
+    -------
+    pd.DataFrame
+        DataFrame with columns ['id', 'time', 'status', 'X0', ..., 'Xp']
+
+    Examples
+    --------
+    >>> from gen_surv.aft import gen_aft_weibull
+    >>> df = gen_aft_weibull(
+    ...     n=100,
+    ...     beta=[0.5, -0.3],
+    ...     shape=1.2,
+    ...     scale=2.0,
+    ...     model_cens="uniform",
+    ...     cens_par=2.0,
+    ...     seed=42,
+    ... 
) + >>> df.head() + """ + rng = np.random.default_rng(seed) + validate_gen_aft_weibull_inputs(n, beta, shape, scale, model_cens, cens_par) + + p = len(beta) + X = rng.normal(size=(n, p)) + + # Linear predictor + eta = X @ np.array(beta) + + # Generate Weibull survival times + U = rng.uniform(size=n) + T = scale * (-np.log(U) * np.exp(-eta)) ** (1 / shape) + + # Generate censoring times + rfunc = runifcens if model_cens == "uniform" else rexpocens + C = rfunc(n, cens_par, rng) + + # Observed time is the minimum of event time and censoring time + observed_time = np.minimum(T, C) + status = (T <= C).astype(int) + + data = pd.DataFrame({"id": np.arange(n), "time": observed_time, "status": status}) + + for j in range(p): + data[f"X{j}"] = X[:, j] + + return data + + +def gen_aft_log_logistic( + n: int, + beta: List[float], + shape: float, + scale: float, + model_cens: Literal["uniform", "exponential"], + cens_par: float, + seed: int | None = None, +) -> pd.DataFrame: + """ + Simulate survival data under a Log-Logistic Accelerated Failure Time (AFT) model. + + The Log-Logistic AFT model has survival function: + S(t|X) = 1 / (1 + (t/scale)^shape * exp(X*beta)) + + Log-logistic distribution is useful when the hazard rate first increases and then decreases. + + Parameters + ---------- + n : int + Number of individuals + beta : list of float + Coefficients for covariates + shape : float + Log-logistic shape parameter (Ξ± > 0) + scale : float + Log-logistic scale parameter (Ξ² > 0) + model_cens : {"uniform", "exponential"} + Censoring mechanism + cens_par : float + Parameter for censoring distribution + seed : int, optional + Random seed for reproducibility + + Returns + ------- + pd.DataFrame + DataFrame with columns ['id', 'time', 'status', 'X0', ..., 'Xp'] + + Examples + -------- + >>> from gen_surv.aft import gen_aft_log_logistic + >>> df = gen_aft_log_logistic( + ... n=100, + ... beta=[0.5, -0.3], + ... shape=1.2, + ... scale=2.0, + ... model_cens="uniform", + ... cens_par=2.0, + ... seed=42, + ... 
) + >>> df.head() + """ + rng = np.random.default_rng(seed) + validate_gen_aft_log_logistic_inputs(n, beta, shape, scale, model_cens, cens_par) + + p = len(beta) + X = rng.normal(size=(n, p)) + + # Linear predictor + eta = X @ np.array(beta) + + # Generate Log-Logistic survival times + U = rng.uniform(size=n) + + # Inverse CDF method: S(t) = 1/(1 + (t/scale)^shape) + # so t = scale * (1/S - 1)^(1/shape) + # For random U ~ Uniform(0,1), we can use U as 1-S + # t = scale * (1/(1-U) - 1)^(1/shape) * exp(-eta/shape) + # simplifies to: t = scale * (U/(1-U))^(1/shape) * exp(-eta/shape) + + # Avoid numerical issues near 1 + U = np.clip(U, 0.001, 0.999) + T = scale * (U / (1 - U)) ** (1 / shape) * np.exp(-eta / shape) + + # Generate censoring times + rfunc = runifcens if model_cens == "uniform" else rexpocens + C = rfunc(n, cens_par, rng) + # Observed time is the minimum of event time and censoring time observed_time = np.minimum(T, C) status = (T <= C).astype(int) - data = pd.DataFrame({ - "id": np.arange(n), - "time": observed_time, - "status": status - }) + data = pd.DataFrame({"id": np.arange(n), "time": observed_time, "status": status}) for j in range(p): data[f"X{j}"] = X[:, j] diff --git a/gen_surv/bivariate.py b/gen_surv/bivariate.py index 070fbd0..84940b2 100644 --- a/gen_surv/bivariate.py +++ b/gen_surv/bivariate.py @@ -1,38 +1,70 @@ +from typing import Sequence + import numpy as np +from numpy.typing import NDArray + +from .validate import validate_dg_biv_inputs + +_CHI2_SCALE = 0.5 +_CLIP_EPS = 1e-10 -def sample_bivariate_distribution(n, dist, corr, dist_par): - """ - Generate samples from a bivariate distribution with specified correlation. - Parameters: - - n (int): Number of samples - - dist (str): 'weibull' or 'exponential' - - corr (float): Correlation coefficient between [-1, 1] - - dist_par (list): Parameters for the marginals +def sample_bivariate_distribution( + n: int, dist: str, corr: float, dist_par: Sequence[float] +) -> NDArray[np.float64]: + """Draw correlated samples from Weibull or exponential marginals. - Returns: - - np.ndarray of shape (n, 2) + Parameters + ---------- + n : int + Number of samples to generate. + dist : {"weibull", "exponential"} + Type of marginal distributions. + corr : float + Correlation coefficient. + dist_par : Sequence[float] + Distribution parameters ``[a1, b1, a2, b2]`` for the Weibull case or + ``[lambda1, lambda2]`` for the exponential case. + + Returns + ------- + NDArray[np.float64] + Array of shape ``(n, 2)`` with the sampled pairs. + + Examples + -------- + >>> from gen_surv.bivariate import sample_bivariate_distribution + >>> sample_bivariate_distribution( + ... 3, + ... "weibull", + ... 0.3, + ... [1.0, 2.0, 1.5, 2.5], + ... ) # doctest: +ELLIPSIS + array([[...], [...], [...]]) + + Raises + ------ + ValidationError + If ``dist`` is unsupported or ``dist_par`` has an invalid length. 
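+
+    Notes
+    -----
+    Dependence is induced by transforming correlated standard normals with
+    ``1 - exp(-z**2 / 2)`` (a chi-squared-style approximation), so the
+    correlation of the returned pairs only approximates ``corr``.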
""" - if dist not in {"weibull", "exponential"}: - raise ValueError("Only 'weibull' and 'exponential' distributions are supported.") + + validate_dg_biv_inputs(n, dist, corr, dist_par) # Step 1: Generate correlated standard normals using Cholesky mean = [0, 0] cov = [[1, corr], [corr, 1]] z = np.random.multivariate_normal(mean, cov, size=n) - u = 1 - np.exp(-0.5 * z**2) # transform normals to uniform via chi-squared approx - u = np.clip(u, 1e-10, 1 - 1e-10) # avoid infs in tails + u = 1 - np.exp( + -_CHI2_SCALE * z**2 + ) # transform normals to uniform via chi-squared approx + u = np.clip(u, _CLIP_EPS, 1 - _CLIP_EPS) # avoid infs in tails # Step 2: Transform to marginals if dist == "exponential": - if len(dist_par) != 2: - raise ValueError("Exponential distribution requires 2 positive rate parameters.") x1 = -np.log(1 - u[:, 0]) / dist_par[0] x2 = -np.log(1 - u[:, 1]) / dist_par[1] - elif dist == "weibull": - if len(dist_par) != 4: - raise ValueError("Weibull distribution requires 4 positive parameters [a1, b1, a2, b2].") + else: # dist == "weibull" a1, b1, a2, b2 = dist_par x1 = (-np.log(1 - u[:, 0]) / a1) ** (1 / b1) x2 = (-np.log(1 - u[:, 1]) / a2) ** (1 / b2) diff --git a/gen_surv/censoring.py b/gen_surv/censoring.py index 3228825..055839e 100644 --- a/gen_surv/censoring.py +++ b/gen_surv/censoring.py @@ -1,27 +1,257 @@ +from typing import Protocol + import numpy as np +from numpy.random import Generator, default_rng +from numpy.typing import NDArray + + +class CensoringFunc(Protocol): + """Protocol for censoring time generators.""" + + def __call__( + self, size: int, cens_par: float, rng: Generator | None = None + ) -> NDArray[np.float64]: + """Generate ``size`` censoring times given ``cens_par``.""" + ... + + +class CensoringModel(Protocol): + """Protocol for class-based censoring generators.""" + + def __call__(self, size: int) -> NDArray[np.float64]: + """Generate ``size`` censoring times.""" + ... + + +def runifcens( + size: int, cens_par: float, rng: Generator | None = None +) -> NDArray[np.float64]: + """Generate uniform censoring times. + + Parameters + ---------- + size : int + Number of samples. + cens_par : float + Upper bound for the uniform distribution. + rng : Generator, optional + Random number generator to use. If ``None``, a default generator is + created. + + Returns + ------- + NDArray[np.float64] + Array of censoring times. + """ + r = default_rng() if rng is None else rng + return r.uniform(0, cens_par, size) + + +def rexpocens( + size: int, cens_par: float, rng: Generator | None = None +) -> NDArray[np.float64]: + """Generate exponential censoring times. -def runifcens(size: int, cens_par: float) -> np.ndarray: + Parameters + ---------- + size : int + Number of samples. + cens_par : float + Mean of the exponential distribution. + rng : Generator, optional + Random number generator to use. If ``None``, a default generator is + created. + + Returns + ------- + NDArray[np.float64] + Array of censoring times. """ - Generate uniform censoring times. + r = default_rng() if rng is None else rng + return r.exponential(scale=cens_par, size=size) + + +def rweibcens( + size: int, scale: float, shape: float, rng: Generator | None = None +) -> NDArray[np.float64]: + """Generate Weibull-distributed censoring times. - Parameters: - - size (int): Number of samples. - - cens_par (float): Upper bound for uniform distribution. + Parameters + ---------- + size : int + Number of samples. + scale : float + Scale parameter of the Weibull distribution. 
+ shape : float + Shape parameter of the Weibull distribution. + rng : Generator, optional + Random number generator to use. If ``None``, a default generator is + created. - Returns: - - np.ndarray of censoring times. + Returns + ------- + NDArray[np.float64] + Array of censoring times. """ - return np.random.uniform(0, cens_par, size) + r = default_rng() if rng is None else rng + return r.weibull(shape, size) * scale -def rexpocens(size: int, cens_par: float) -> np.ndarray: + +def rlognormcens( + size: int, mean: float, sigma: float, rng: Generator | None = None +) -> NDArray[np.float64]: + """Generate log-normal-distributed censoring times. + + Parameters + ---------- + size : int + Number of samples. + mean : float + Mean of the underlying normal distribution. + sigma : float + Standard deviation of the underlying normal distribution. + rng : Generator, optional + Random number generator to use. If ``None``, a default generator is + created. + + Returns + ------- + NDArray[np.float64] + Array of censoring times. """ - Generate exponential censoring times. + r = default_rng() if rng is None else rng + return r.lognormal(mean, sigma, size) + + +def rgammacens( + size: int, shape: float, scale: float, rng: Generator | None = None +) -> NDArray[np.float64]: + """Generate Gamma-distributed censoring times. - Parameters: - - size (int): Number of samples. - - cens_par (float): Mean of exponential distribution. + Parameters + ---------- + size : int + Number of samples. + shape : float + Shape parameter of the Gamma distribution. + scale : float + Scale parameter of the Gamma distribution. + rng : Generator, optional + Random number generator to use. If ``None``, a default generator is + created. - Returns: - - np.ndarray of censoring times. + Returns + ------- + NDArray[np.float64] + Array of censoring times. """ - return np.random.exponential(scale=cens_par, size=size) + r = default_rng() if rng is None else rng + return r.gamma(shape, scale, size) + + +class WeibullCensoring: + """Class-based generator for Weibull censoring times.""" + + def __init__(self, scale: float, shape: float) -> None: + """Store Weibull scale and shape parameters. + + Parameters + ---------- + scale : float + Scale parameter of the Weibull distribution. + shape : float + Shape parameter of the Weibull distribution. + """ + self.scale = scale + self.shape = shape + + def __call__(self, size: int, rng: Generator | None = None) -> NDArray[np.float64]: + """Generate ``size`` censoring times from a Weibull distribution. + + Parameters + ---------- + size : int + Number of samples. + rng : Generator, optional + Random number generator to use. If ``None``, a default generator is + created. + + Returns + ------- + NDArray[np.float64] + Array of censoring times. + """ + r = default_rng() if rng is None else rng + return r.weibull(self.shape, size) * self.scale + + +class LogNormalCensoring: + """Class-based generator for log-normal censoring times.""" + + def __init__(self, mean: float, sigma: float) -> None: + """Store log-normal parameters. + + Parameters + ---------- + mean : float + Mean of the underlying normal distribution. + sigma : float + Standard deviation of the underlying normal distribution. + """ + self.mean = mean + self.sigma = sigma + + def __call__(self, size: int, rng: Generator | None = None) -> NDArray[np.float64]: + """Generate ``size`` censoring times from a log-normal distribution. + + Parameters + ---------- + size : int + Number of samples. 
+ rng : Generator, optional + Random number generator to use. If ``None``, a default generator is + created. + + Returns + ------- + NDArray[np.float64] + Array of censoring times. + """ + r = default_rng() if rng is None else rng + return r.lognormal(self.mean, self.sigma, size) + + +class GammaCensoring: + """Class-based generator for Gamma censoring times.""" + + def __init__(self, shape: float, scale: float) -> None: + """Store Gamma distribution parameters. + + Parameters + ---------- + shape : float + Shape parameter of the Gamma distribution. + scale : float + Scale parameter of the Gamma distribution. + """ + self.shape = shape + self.scale = scale + + def __call__(self, size: int, rng: Generator | None = None) -> NDArray[np.float64]: + """Generate ``size`` censoring times from a Gamma distribution. + + Parameters + ---------- + size : int + Number of samples. + rng : Generator, optional + Random number generator to use. If ``None``, a default generator is + created. + + Returns + ------- + NDArray[np.float64] + Array of censoring times. + """ + r = default_rng() if rng is None else rng + return r.gamma(self.shape, self.scale, size) diff --git a/gen_surv/cli.py b/gen_surv/cli.py index 542ea51..e704f4e 100644 --- a/gen_surv/cli.py +++ b/gen_surv/cli.py @@ -1,36 +1,232 @@ -import csv -from typing import Optional +""" +Command-line interface for gen_surv. + +This module provides a command-line interface for generating survival data +using the gen_surv package. +""" + +from typing import Any, Dict, List, TypeVar, cast + import typer + from gen_surv.interface import generate +from gen_surv.validation import ValidationError app = typer.Typer(help="Generate synthetic survival datasets.") + @app.command() def dataset( model: str = typer.Argument( - ..., help="Model to simulate [cphm, cmm, tdcm, thmm, aft_ln]" + ..., + help=( + "Model to simulate [cphm, cmm, tdcm, thmm, aft_ln, aft_weibull, aft_log_logistic, competing_risks, competing_risks_weibull, mixture_cure, piecewise_exponential]" + ), ), n: int = typer.Option(100, help="Number of samples"), - output: Optional[str] = typer.Option( + model_cens: str = typer.Option( + "uniform", help="Censoring model: 'uniform' or 'exponential'" + ), + cens_par: float = typer.Option(1.0, help="Censoring parameter"), + beta: List[float] = typer.Option( + [0.5], + help="Regression coefficient(s). 
Provide multiple values for multi-parameter models.", + ), + covariate_range: float | None = typer.Option( + 2.0, + "--covariate-range", + "--covar", + help="Upper bound for covariate values (for CPHM, CMM, THMM)", + ), + sigma: float | None = typer.Option( + 1.0, help="Standard deviation parameter (for log-normal AFT)" + ), + shape: float | None = typer.Option(1.5, help="Shape parameter (for Weibull AFT)"), + scale: float | None = typer.Option(2.0, help="Scale parameter (for Weibull AFT)"), + n_risks: int = typer.Option(2, help="Number of competing risks"), + baseline_hazards: List[float] = typer.Option( + [], help="Baseline hazards for competing risks" + ), + shape_params: List[float] = typer.Option( + [], help="Shape parameters for Weibull competing risks" + ), + scale_params: List[float] = typer.Option( + [], help="Scale parameters for Weibull competing risks" + ), + cure_fraction: float | None = typer.Option( + None, help="Cure fraction for mixture cure model" + ), + baseline_hazard: float | None = typer.Option( + None, help="Baseline hazard for mixture cure model" + ), + breakpoints: List[float] = typer.Option( + [], help="Breakpoints for piecewise exponential model" + ), + hazard_rates: List[float] = typer.Option( + [], help="Hazard rates for piecewise exponential model" + ), + seed: int | None = typer.Option(None, help="Random seed for reproducibility"), + output: str | None = typer.Option( None, "-o", help="Output CSV file. Prints to stdout if omitted." ), ) -> None: """Generate survival data and optionally save to CSV. - Args: - model: Identifier of the generator to use. - n: Number of samples to create. - output: Optional path to save the CSV file. + Examples: + # Generate data from CPHM model + $ gen_surv dataset cphm --n 100 --beta 0.5 --covariate-range 2.0 -o cphm_data.csv - Returns: - None + # Generate data from Weibull AFT model + $ gen_surv dataset aft_weibull --n 200 --beta 0.5 --beta -0.3 --shape 1.5 --scale 2.0 -o aft_data.csv """ - df = generate(model=model, n=n) + # Helper to unwrap Typer Option defaults when function is called directly + from typer.models import OptionInfo + + T = TypeVar("T") + + def _val(v: T | OptionInfo) -> T: + return v if not isinstance(v, OptionInfo) else cast(T, v.default) + + # Prepare arguments based on the selected model + model_str: str = _val(model) + kwargs: Dict[str, Any] = { + "model": model_str, + "n": _val(n), + "model_cens": _val(model_cens), + "cens_par": _val(cens_par), + "seed": _val(seed), + } + + # Add model-specific parameters + if model_str in ["cphm", "cmm", "thmm"]: + # These models use a single beta and covariate range + beta_values = cast(List[float], _val(beta)) + kwargs["beta"] = beta_values[0] if len(beta_values) > 0 else 0.5 + kwargs["covariate_range"] = _val(covariate_range) + + elif model_str == "aft_ln": + # Log-normal AFT model uses beta list and sigma + kwargs["beta"] = _val(beta) + kwargs["sigma"] = _val(sigma) + + elif model_str == "aft_weibull": + # Weibull AFT model uses beta list, shape, and scale + kwargs["beta"] = _val(beta) + kwargs["shape"] = _val(shape) + kwargs["scale"] = _val(scale) + + elif model_str == "aft_log_logistic": + kwargs["beta"] = _val(beta) + kwargs["shape"] = _val(shape) + kwargs["scale"] = _val(scale) + + elif model_str == "competing_risks": + kwargs["n_risks"] = _val(n_risks) + if _val(baseline_hazards): + kwargs["baseline_hazards"] = _val(baseline_hazards) + if _val(beta): + kwargs["betas"] = [_val(beta) for _ in range(_val(n_risks))] + + elif model_str == "competing_risks_weibull": 
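+        # Forward the optional Weibull shape/scale lists only when supplied;
+        # the single --beta vector is reused as the coefficients of each risk.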
+ kwargs["n_risks"] = _val(n_risks) + if _val(shape_params): + kwargs["shape_params"] = _val(shape_params) + if _val(scale_params): + kwargs["scale_params"] = _val(scale_params) + if _val(beta): + kwargs["betas"] = [_val(beta) for _ in range(_val(n_risks))] + + elif model_str == "mixture_cure": + if _val(cure_fraction) is not None: + kwargs["cure_fraction"] = _val(cure_fraction) + if _val(baseline_hazard) is not None: + kwargs["baseline_hazard"] = _val(baseline_hazard) + kwargs["betas_survival"] = _val(beta) + kwargs["betas_cure"] = _val(beta) + + elif model_str == "piecewise_exponential": + kwargs["breakpoints"] = _val(breakpoints) + kwargs["hazard_rates"] = _val(hazard_rates) + kwargs["betas"] = _val(beta) + + # Generate the data + try: + df = generate(**kwargs) + except ValidationError as exc: + typer.echo(f"Input error: {exc}") + raise typer.Exit(1) + + # Output the data if output: df.to_csv(output, index=False) typer.echo(f"Saved dataset to {output}") else: typer.echo(df.to_csv(index=False)) + +@app.command() +def visualize( + input_file: str = typer.Argument( + ..., help="Input CSV file containing survival data" + ), + time_col: str = typer.Option("time", help="Column containing time/duration values"), + status_col: str = typer.Option( + "status", help="Column containing event indicator (1=event, 0=censored)" + ), + group_col: str | None = typer.Option(None, help="Column to use for stratification"), + output: str = typer.Option("survival_plot.png", help="Output image file"), +) -> None: + """Visualize survival data from a CSV file. + + Examples: + # Generate a Kaplan-Meier plot from a CSV file + $ gen_surv visualize data.csv --time-col time --status-col status -o km_plot.png + + # Generate a stratified plot using a grouping variable + $ gen_surv visualize data.csv --group-col X0 -o stratified_plot.png + """ + try: + import matplotlib.pyplot as plt + import pandas as pd + + from gen_surv.visualization import plot_survival_curve + except ImportError: + typer.echo( + "Error: Visualization requires matplotlib and lifelines. " + "Install them with: pip install matplotlib lifelines" + ) + raise typer.Exit(1) + + # Load the data + try: + data = pd.read_csv(input_file) + except Exception as e: + typer.echo(f"Error loading CSV file: {str(e)}") + raise typer.Exit(1) + + # Check required columns + if time_col not in data.columns: + typer.echo(f"Error: Time column '{time_col}' not found in data") + raise typer.Exit(1) + + if status_col not in data.columns: + typer.echo(f"Error: Status column '{status_col}' not found in data") + raise typer.Exit(1) + + if group_col is not None and group_col not in data.columns: + typer.echo(f"Error: Group column '{group_col}' not found in data") + raise typer.Exit(1) + + # Create the plot + fig, ax = plot_survival_curve( + data=data, time_col=time_col, status_col=status_col, group_col=group_col + ) + + # Save the plot + plt.savefig(output, dpi=300, bbox_inches="tight") + plt.close(fig) + typer.echo(f"Plot saved to {output}") + + if __name__ == "__main__": app() diff --git a/gen_surv/cmm.py b/gen_surv/cmm.py index 689f2db..560641c 100644 --- a/gen_surv/cmm.py +++ b/gen_surv/cmm.py @@ -1,70 +1,135 @@ -import pandas as pd -import numpy as np -from gen_surv.validate import validate_gen_cmm_inputs -from gen_surv.censoring import runifcens, rexpocens - - -def generate_event_times(z1: float, beta: list, rate: list) -> dict: - """ - Generate event times for a continuous-time multi-state Markov model. 
- - Parameters: - - z1 (float): Covariate value - - beta (list of float): List of 3 beta coefficients - - rate (list of float): List of 6 transition rate parameters - - Returns: - - dict: {'t12': float, 't13': float, 't23': float} - """ - u = np.random.uniform() - t12 = (-np.log(1 - u) / (rate[0] * np.exp(beta[0] * z1)))**(1 / rate[1]) - - u = np.random.uniform() - t13 = (-np.log(1 - u) / (rate[2] * np.exp(beta[1] * z1)))**(1 / rate[3]) +from typing import Sequence, TypedDict - u = np.random.uniform() - t23 = (-np.log(1 - u) / (rate[4] * np.exp(beta[2] * z1)))**(1 / rate[5]) - - return {"t12": t12, "t13": t13, "t23": t23} +import numpy as np +import pandas as pd -def gen_cmm(n, model_cens, cens_par, beta, covar, rate): +from gen_surv.censoring import CensoringFunc, rexpocens, runifcens +from gen_surv.validation import validate_gen_cmm_inputs + + +class EventTimes(TypedDict): + t12: float + t13: float + t23: float + + +def generate_event_times( + z1: float, + beta: Sequence[float], + rate: Sequence[float], + rng: np.random.Generator | None = None, +) -> EventTimes: + """Generate event times for a continuous-time multi-state Markov model. + + Parameters + ---------- + z1 : float + Covariate value. + beta : Sequence[float] + List of 3 beta coefficients. + rate : Sequence[float] + List of 6 transition rate parameters. + rng : np.random.Generator, optional + Random number generator to use. Defaults to ``None`` which creates a new generator. + + Returns + ------- + EventTimes + Dictionary with keys ``'t12'``, ``'t13'``, and ``'t23'``. + + Examples + -------- + >>> from gen_surv.cmm import generate_event_times + >>> ev = generate_event_times(0.2, [0.1, -0.2, 0.3], + ... [0.5, 1.0, 0.7, 1.2, 0.4, 1.5]) + >>> sorted(ev.keys()) + ['t12', 't13', 't23'] """ - Generate survival data using a continuous-time Markov model (CMM). - - Parameters: - - n (int): Number of individuals. - - model_cens (str): "uniform" or "exponential". - - cens_par (float): Parameter for censoring. - - beta (list): Regression coefficients (length 3). - - covar (float): Covariate range (uniformly sampled from [0, covar]). - - rate (list): Transition rates (length 6). - - Returns: - - pd.DataFrame with columns: id, start, stop, status, covariate, transition + rng = np.random.default_rng() if rng is None else rng + + u = rng.uniform(size=3) + rate_arr = np.asarray(rate).reshape(3, 2) + beta_arr = np.asarray(beta) + t = (-np.log(1 - u) / (rate_arr[:, 0] * np.exp(beta_arr * z1))) ** ( + 1 / rate_arr[:, 1] + ) + + return {"t12": float(t[0]), "t13": float(t[1]), "t23": float(t[2])} + + +def gen_cmm( + n: int, + model_cens: str, + cens_par: float, + beta: Sequence[float], + covariate_range: float, + rate: Sequence[float], + seed: int | None = None, +) -> pd.DataFrame: + """Generate survival data using a continuous-time Markov model (CMM). + + Parameters + ---------- + n : int + Number of individuals. + model_cens : str + ``"uniform"`` or ``"exponential"``. + cens_par : float + Parameter for censoring. + beta : Sequence[float] + Regression coefficients (length 3). + covariate_range : float + Upper bound for the covariate values. + rate : Sequence[float] + Transition rates (length 6). + seed : int, optional + Random seed for reproducibility. + + Returns + ------- + pd.DataFrame + DataFrame with columns: ``id``, ``start``, ``stop``, ``status``, ``X0``, ``transition``. + + Examples + -------- + >>> from gen_surv.cmm import gen_cmm + >>> df = gen_cmm( + ... n=50, + ... model_cens="uniform", + ... cens_par=2.0, + ... 
beta=[0.3, -0.2, 0.1], + ... covariate_range=1.0, + ... rate=[0.1, 1.0, 0.2, 1.2, 0.3, 1.5], + ... seed=42, + ... ) + >>> df.head() """ - validate_gen_cmm_inputs(n, model_cens, cens_par, beta, covar, rate) - - rfunc = runifcens if model_cens == "uniform" else rexpocens - rows = [] - - for k in range(n): - z1 = np.random.uniform(0, covar) - c = rfunc(1, cens_par)[0] - events = generate_event_times(z1, beta, rate) - - t12, t13, t23 = events["t12"], events["t13"], events["t23"] - min_event_time = min(t12, t13, c) - - if min_event_time < c: - if t12 <= t13: - transition = 1 # 1 -> 2 - rows.append([k + 1, 0, t12, 1, z1, transition]) - else: - transition = 2 # 1 -> 3 - rows.append([k + 1, 0, t13, 1, z1, transition]) - else: - # Censored before any event - rows.append([k + 1, 0, c, 0, z1, np.nan]) - - return pd.DataFrame(rows, columns=["id", "start", "stop", "status", "covariate", "transition"]) - + validate_gen_cmm_inputs(n, model_cens, cens_par, beta, covariate_range, rate) + + rng = np.random.default_rng(seed) + rfunc: CensoringFunc = runifcens if model_cens == "uniform" else rexpocens + + z1 = rng.uniform(0, covariate_range, size=n) + c = rfunc(n, cens_par, rng) + + u = rng.uniform(size=(3, n)) + t12 = (-np.log(1 - u[0]) / (rate[0] * np.exp(beta[0] * z1))) ** (1 / rate[1]) + t13 = (-np.log(1 - u[1]) / (rate[2] * np.exp(beta[1] * z1))) ** (1 / rate[3]) + + first_event = np.minimum(t12, t13) + censored = first_event >= c + + status = (~censored).astype(int) + transition = np.where(censored, np.nan, np.where(t12 <= t13, 1, 2)) + stop = np.where(censored, c, first_event) + + return pd.DataFrame( + { + "id": np.arange(1, n + 1), + "start": np.zeros(n), + "stop": stop, + "status": status, + "X0": z1, + "transition": transition, + } + ) diff --git a/gen_surv/competing_risks.py b/gen_surv/competing_risks.py new file mode 100644 index 0000000..9b4e42f --- /dev/null +++ b/gen_surv/competing_risks.py @@ -0,0 +1,611 @@ +""" +Competing Risks models for survival data simulation. + +This module provides functions to generate survival data with +competing risks under different hazard specifications. +""" + +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Union + +import numpy as np +import pandas as pd + +from ._covariates import generate_covariates, prepare_betas_matrix, set_covariate_params +from .censoring import rexpocens, runifcens +from .validation import ( + ParameterError, + ensure_positive_sequence, + ensure_sequence_length, + validate_competing_risks_inputs, +) + +if TYPE_CHECKING: # pragma: no cover - used only for type hints + from matplotlib.axes import Axes + from matplotlib.figure import Figure + + +def gen_competing_risks( + n: int, + n_risks: int = 2, + baseline_hazards: Union[List[float], np.ndarray] | None = None, + betas: Union[List[List[float]], np.ndarray] | None = None, + covariate_dist: Literal["normal", "uniform", "binary"] = "normal", + covariate_params: Dict[str, float] | None = None, + max_time: float | None = 10.0, + model_cens: Literal["uniform", "exponential"] = "uniform", + cens_par: float = 5.0, + seed: int | None = None, +) -> pd.DataFrame: + """ + Generate survival data with competing risks. + + Parameters + ---------- + n : int + Number of subjects. + n_risks : int, default=2 + Number of competing risks. + baseline_hazards : list of float or array, optional + Baseline hazard rates for each risk. If None, uses [0.5, 0.3, ...] + with decreasing values for subsequent risks. 
+ betas : list of list of float or array, optional + Coefficients for covariates, one list per risk. + Shape should be (n_risks, n_covariates). + If None, generates random coefficients. + covariate_dist : {"normal", "uniform", "binary"}, default="normal" + Distribution to generate covariates from. + covariate_params : dict, optional + Parameters for covariate distribution: + - "normal": {"mean": float, "std": float} + - "uniform": {"low": float, "high": float} + - "binary": {"p": float} + If None, uses defaults based on distribution. + max_time : float, optional, default=10.0 + Maximum simulation time. Set to None for no limit. + model_cens : {"uniform", "exponential"}, default="uniform" + Censoring mechanism. + cens_par : float, default=5.0 + Parameter for censoring distribution. + seed : int, optional + Random seed for reproducibility. + + Returns + ------- + pd.DataFrame + DataFrame with columns: + - "id": Subject identifier + - "time": Time to event or censoring + - "status": Event indicator (0=censored, 1,2,...=competing events) + - "X0", "X1", ...: Covariates + + Examples + -------- + >>> from gen_surv.competing_risks import gen_competing_risks + >>> + >>> # Simple example with 2 competing risks + >>> df = gen_competing_risks( + ... n=100, + ... n_risks=2, + ... baseline_hazards=[0.5, 0.3], + ... betas=[[0.8, -0.5], [0.2, 0.7]], + ... seed=42 + ... ) + >>> + >>> # Distribution of event types + >>> df["status"].value_counts() + """ + rng = np.random.default_rng(seed) + + validate_competing_risks_inputs( + n, + n_risks, + baseline_hazards, + betas, + covariate_dist, + max_time, + model_cens, + cens_par, + ) + + # Set default baseline hazards if not provided + if baseline_hazards is None: + baseline_hazards = np.array([0.5 / (i + 1) for i in range(n_risks)]) + else: + baseline_hazards = np.asarray(baseline_hazards, dtype=float) + + covariate_params = set_covariate_params(covariate_dist, covariate_params) + n_covariates = 2 + betas, n_covariates = prepare_betas_matrix(betas, n_risks, n_covariates, rng) + X = generate_covariates(n, n_covariates, covariate_dist, covariate_params, rng) + + # Calculate linear predictors for each risk + linear_predictors = np.zeros((n, n_risks)) + for j in range(n_risks): + linear_predictors[:, j] = X @ betas[j] + + # Calculate hazard rates + hazard_rates = np.zeros_like(linear_predictors) + for j in range(n_risks): + hazard_rates[:, j] = baseline_hazards[j] * np.exp(linear_predictors[:, j]) + + # Generate event times for each risk + event_times = np.zeros((n, n_risks)) + for j in range(n_risks): + # Use exponential distribution with rate = hazard + event_times[:, j] = rng.exponential(1 / hazard_rates[:, j]) + + # Generate censoring times + rfunc = runifcens if model_cens == "uniform" else rexpocens + cens_times = rfunc(n, cens_par, rng) + + # Find the minimum time for each subject (first event or censoring) + min_event_times = np.min(event_times, axis=1) + observed_times = np.minimum(min_event_times, cens_times) + + # Determine event type (0 = censored, 1...n_risks = event type) + status = np.zeros(n, dtype=int) + for i in range(n): + if min_event_times[i] <= cens_times[i]: + # Find which risk occurred first + risk_index = np.argmin(event_times[i]) + status[i] = risk_index + 1 # 1-based indexing for event types + + # Ensure at least two event types are present for small n + if len(np.unique(status)) <= 1 and n_risks > 1: + status[0] = 1 + if n > 1: + status[1] = 2 + + # Cap times at max_time if specified + if max_time is not None: + over_max = observed_times 
> max_time + observed_times[over_max] = max_time + status[over_max] = 0 # Censored if beyond max_time + + # Create DataFrame + data = pd.DataFrame({"id": np.arange(n), "time": observed_times, "status": status}) + + # Add covariates + for j in range(n_covariates): + data[f"X{j}"] = X[:, j] + + return data + + +def gen_competing_risks_weibull( + n: int, + n_risks: int = 2, + shape_params: Union[List[float], np.ndarray] | None = None, + scale_params: Union[List[float], np.ndarray] | None = None, + betas: Union[List[List[float]], np.ndarray] | None = None, + covariate_dist: Literal["normal", "uniform", "binary"] = "normal", + covariate_params: Dict[str, float] | None = None, + max_time: float | None = 10.0, + model_cens: Literal["uniform", "exponential"] = "uniform", + cens_par: float = 5.0, + seed: int | None = None, +) -> pd.DataFrame: + """ + Generate survival data with competing risks using Weibull hazards. + + Parameters + ---------- + n : int + Number of subjects. + n_risks : int, default=2 + Number of competing risks. + shape_params : list of float or array, optional + Shape parameters for Weibull distribution, one per risk. + If None, uses [1.2, 0.8, ...] alternating values. + scale_params : list of float or array, optional + Scale parameters for Weibull distribution, one per risk. + If None, uses [2.0, 3.0, ...] increasing values. + betas : list of list of float or array, optional + Coefficients for covariates, one list per risk. + Shape should be (n_risks, n_covariates). + If None, generates random coefficients. + covariate_dist : {"normal", "uniform", "binary"}, default="normal" + Distribution to generate covariates from. + covariate_params : dict, optional + Parameters for covariate distribution: + - "normal": {"mean": float, "std": float} + - "uniform": {"low": float, "high": float} + - "binary": {"p": float} + If None, uses defaults based on distribution. + max_time : float, optional, default=10.0 + Maximum simulation time. Set to None for no limit. + model_cens : {"uniform", "exponential"}, default="uniform" + Censoring mechanism. + cens_par : float, default=5.0 + Parameter for censoring distribution. + seed : int, optional + Random seed for reproducibility. + + Returns + ------- + pd.DataFrame + DataFrame with columns: + - "id": Subject identifier + - "time": Time to event or censoring + - "status": Event indicator (0=censored, 1,2,...=competing events) + - "X0", "X1", ...: Covariates + + Examples + -------- + >>> from gen_surv.competing_risks import gen_competing_risks_weibull + >>> + >>> # Example with 2 competing risks with different shapes + >>> df = gen_competing_risks_weibull( + ... n=100, + ... n_risks=2, + ... shape_params=[0.8, 1.5], # Decreasing vs increasing hazard + ... scale_params=[2.0, 3.0], + ... betas=[[0.8, -0.5], [0.2, 0.7]], + ... seed=42 + ... 
) + """ + rng = np.random.default_rng(seed) + + validate_competing_risks_inputs( + n, + n_risks, + None, + betas, + covariate_dist, + max_time, + model_cens, + cens_par, + ) + + # Set default shape and scale parameters if not provided + if shape_params is None: + shape_params = np.array([1.2 if i % 2 == 0 else 0.8 for i in range(n_risks)]) + else: + shape_params = np.asarray(shape_params, dtype=float) + ensure_sequence_length(shape_params, n_risks, "shape_params") + ensure_positive_sequence(shape_params, "shape_params") + + if scale_params is None: + scale_params = np.array([2.0 + i for i in range(n_risks)]) + else: + scale_params = np.asarray(scale_params, dtype=float) + ensure_sequence_length(scale_params, n_risks, "scale_params") + ensure_positive_sequence(scale_params, "scale_params") + + covariate_params = set_covariate_params(covariate_dist, covariate_params) + n_covariates = 2 + betas, n_covariates = prepare_betas_matrix(betas, n_risks, n_covariates, rng) + X = generate_covariates(n, n_covariates, covariate_dist, covariate_params, rng) + + # Calculate linear predictors for each risk + linear_predictors = np.zeros((n, n_risks)) + for j in range(n_risks): + linear_predictors[:, j] = X @ betas[j] + + # Generate event times for each risk using Weibull distribution + event_times = np.zeros((n, n_risks)) + for j in range(n_risks): + # Adjust the scale parameter using the linear predictor + adjusted_scale = scale_params[j] * np.exp( + -linear_predictors[:, j] / shape_params[j] + ) + + # Generate random uniform between 0 and 1 + u = rng.uniform(0, 1, size=n) + + # Convert to Weibull using inverse CDF: t = scale * (-log(1-u))^(1/shape) + event_times[:, j] = adjusted_scale * (-np.log(1 - u)) ** (1 / shape_params[j]) + + # Generate censoring times + rfunc = runifcens if model_cens == "uniform" else rexpocens + cens_times = rfunc(n, cens_par, rng) + + # Find the minimum time for each subject (first event or censoring) + min_event_times = np.min(event_times, axis=1) + observed_times = np.minimum(min_event_times, cens_times) + + # Determine event type (0 = censored, 1...n_risks = event type) + status = np.zeros(n, dtype=int) + for i in range(n): + if min_event_times[i] <= cens_times[i]: + # Find which risk occurred first + risk_index = np.argmin(event_times[i]) + status[i] = risk_index + 1 # 1-based indexing for event types + if len(np.unique(status)) <= 1 and n_risks > 1: + status[0] = 1 + if n > 1: + status[1] = 2 + + # Cap times at max_time if specified + if max_time is not None: + over_max = observed_times > max_time + observed_times[over_max] = max_time + status[over_max] = 0 # Censored if beyond max_time + + # Create DataFrame + data = pd.DataFrame({"id": np.arange(n), "time": observed_times, "status": status}) + + # Add covariates + for j in range(n_covariates): + data[f"X{j}"] = X[:, j] + + return data + + +def cause_specific_cumulative_incidence( + data: pd.DataFrame, + time_points: Union[List[float], np.ndarray], + time_col: str = "time", + status_col: str = "status", + cause: int = 1, +) -> pd.DataFrame: + """ + Calculate the cause-specific cumulative incidence function at specified time points. + + Parameters + ---------- + data : pd.DataFrame + DataFrame with competing risks data. + time_points : list of float or array + Time points at which to calculate the cumulative incidence. + time_col : str, default="time" + Name of the time column. + status_col : str, default="status" + Name of the status column (0=censored, 1,2,...=competing events). 
+ cause : int, default=1 + The cause/event type for which to calculate the incidence. + + Returns + ------- + pd.DataFrame + DataFrame with time points and corresponding cumulative incidence values. + + Notes + ----- + The cumulative incidence function for cause j is defined as: + F_j(t) = P(T <= t, cause = j) + + This is the probability of experiencing the event of type j before time t. + """ + # Validate the cause value + unique_causes = set(data[status_col].unique()) - {0} # Exclude censoring + if cause not in unique_causes: + raise ParameterError( + "cause", cause, f"not found in the data. Available causes: {unique_causes}" + ) + + sorted_data = data.sort_values(by=time_col).copy() + times = sorted_data[time_col].to_numpy() + status = sorted_data[status_col].to_numpy() + + unique_times, idx = np.unique(times, return_index=True) + counts = np.diff(np.append(idx, len(times))) + at_risk = len(times) - idx + + d_all = np.zeros_like(unique_times, dtype=int) + d_cause = np.zeros_like(unique_times, dtype=int) + inverse = np.repeat(np.arange(len(unique_times)), counts) + for i, s in enumerate(status): + if s > 0: + d_all[inverse[i]] += 1 + if s == cause: + d_cause[inverse[i]] += 1 + + surv = 1.0 + cif_vals = np.zeros_like(unique_times, dtype=float) + ci = 0.0 + for i, t in enumerate(unique_times): + prev_surv = surv + surv *= 1 - d_all[i] / at_risk[i] + ci += prev_surv * d_cause[i] / at_risk[i] + cif_vals[i] = ci + + result = [] + for t in time_points: + if t <= 0: + result.append({"time": t, "incidence": 0.0}) + elif t >= unique_times[-1]: + result.append({"time": t, "incidence": cif_vals[-1]}) + else: + idx = np.searchsorted(unique_times, t, side="right") - 1 + result.append({"time": t, "incidence": cif_vals[idx]}) + + return pd.DataFrame(result) + + +def competing_risks_summary( + data: pd.DataFrame, + time_col: str = "time", + status_col: str = "status", + covariate_cols: list[str] | None = None, +) -> dict[str, Any]: + """ + Provide a summary of a competing risks dataset. + + Parameters + ---------- + data : pd.DataFrame + DataFrame with competing risks data. + time_col : str, default="time" + Name of the time column. + status_col : str, default="status" + Name of the status column (0=censored, 1,2,...=competing events). + covariate_cols : list of str, optional + List of covariate columns to include in the summary. + If None, all columns except time_col and status_col are considered. + + Returns + ------- + Dict[str, Any] + Dictionary with summary statistics. 
+ + Examples + -------- + >>> from gen_surv.competing_risks import gen_competing_risks, competing_risks_summary + >>> + >>> # Generate data + >>> df = gen_competing_risks(n=100, n_risks=3, seed=42) + >>> + >>> # Get summary + >>> summary = competing_risks_summary(df) + >>> print(f"Number of events by cause: {summary['events_by_cause']}") + >>> print(f"Median time to first event: {summary['median_time']}") + """ + # Determine covariate columns if not provided + if covariate_cols is None: + covariate_cols = [ + col for col in data.columns if col not in [time_col, status_col, "id"] + ] + + # Basic counts + n_subjects = len(data) + n_events = (data[status_col] > 0).sum() + n_censored = n_subjects - n_events + censoring_rate = n_censored / n_subjects + + # Events by cause + causes = sorted(data[data[status_col] > 0][status_col].unique()) + events_by_cause = {} + for cause in causes: + n_cause = (data[status_col] == cause).sum() + events_by_cause[int(cause)] = { + "count": int(n_cause), + "proportion": float(n_cause / n_subjects), + "proportion_of_events": float(n_cause / n_events) if n_events > 0 else 0, + } + + # Time statistics + time_stats = { + "min": float(data[time_col].min()), + "max": float(data[time_col].max()), + "median": float(data[time_col].median()), + "mean": float(data[time_col].mean()), + } + + # Median time to each type of event + median_time_by_cause = {} + for cause in causes: + cause_times = data[data[status_col] == cause][time_col] + if not cause_times.empty: + median_time_by_cause[int(cause)] = float(cause_times.median()) + + # Covariate statistics + covariate_stats: dict[str, dict[str, float | int | dict[str, float]]] = {} + for col in covariate_cols: + col_data = data[col] + + # Check if numeric + if pd.api.types.is_numeric_dtype(col_data): + covariate_stats[col] = { + "mean": float(col_data.mean()), + "median": float(col_data.median()), + "std": float(col_data.std()), + "min": float(col_data.min()), + "max": float(col_data.max()), + } + else: + # Categorical statistics + value_counts = col_data.value_counts(normalize=True).to_dict() + covariate_stats[col] = { + "categories": len(value_counts), + "distribution": {str(k): float(v) for k, v in value_counts.items()}, + } + + # Compile final summary + summary = { + "n_subjects": n_subjects, + "n_events": n_events, + "n_censored": n_censored, + "censoring_rate": censoring_rate, + "n_causes": len(causes), + "causes": list(map(int, causes)), + "events_by_cause": events_by_cause, + "time_stats": time_stats, + "median_time_by_cause": median_time_by_cause, + "covariate_stats": covariate_stats, + } + + return summary + + +def plot_cause_specific_hazards( + data: pd.DataFrame, + time_points: np.ndarray | None = None, + time_col: str = "time", + status_col: str = "status", + bandwidth: float = 0.5, + figsize: tuple[float, float] = (10, 6), +) -> tuple["Figure", "Axes"]: + """ + Plot cause-specific hazard functions. + + Parameters + ---------- + data : pd.DataFrame + DataFrame with competing risks data. + time_points : array, optional + Time points at which to estimate hazards. + If None, uses 100 equally spaced points from 0 to max time. + time_col : str, default="time" + Name of the time column. + status_col : str, default="status" + Name of the status column (0=censored, 1,2,...=competing events). + bandwidth : float, default=0.5 + Bandwidth for kernel density estimation. + figsize : tuple, default=(10, 6) + Figure size (width, height) in inches. + + Returns + ------- + tuple + Figure and axes objects. 
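+
+    Examples
+    --------
+    A sketch (assumes matplotlib and scipy are installed; the plot call is
+    skipped under doctest):
+
+    >>> from gen_surv.competing_risks import gen_competing_risks
+    >>> df = gen_competing_risks(n=200, n_risks=2, seed=1)
+    >>> fig, ax = plot_cause_specific_hazards(df)  # doctest: +SKIP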
+
+    Notes
+    -----
+    This function requires matplotlib and scipy.
+    """
+    try:
+        import matplotlib.pyplot as plt
+        from scipy.stats import gaussian_kde
+    except ImportError:
+        raise ImportError(
+            "This function requires matplotlib and scipy. "
+            "Install them with: pip install matplotlib scipy"
+        )
+
+    # Determine time points if not provided
+    if time_points is None:
+        max_time = data[time_col].max()
+        time_points = np.linspace(0, max_time, 100)
+
+    # Get unique causes (excluding censoring)
+    causes = sorted([c for c in data[status_col].unique() if c > 0])
+
+    # Create plot
+    fig, ax = plt.subplots(figsize=figsize)
+
+    times_sorted = np.sort(data[time_col].to_numpy())
+    total = len(times_sorted)
+
+    # Plot hazard for each cause
+    for cause in causes:
+        cause_data = data[data[status_col] == cause]
+        if len(cause_data) < 5:
+            continue
+
+        kde = gaussian_kde(cause_data[time_col], bw_method=bandwidth)
+
+        at_risk = total - np.searchsorted(times_sorted, time_points, side="left")
+        at_risk = np.maximum(at_risk, 1)
+
+        # The kernel density integrates to one over this cause's event times,
+        # so scale by the number of cause-specific events (not the full
+        # sample size) before dividing by the number still at risk.
+        hazard = kde(time_points) * len(cause_data) / at_risk
+        ax.plot(time_points, hazard, label=f"Cause {cause}")
+
+    # Format plot
+    ax.set_xlabel("Time")
+    ax.set_ylabel("Hazard Rate")
+    ax.set_title("Cause-Specific Hazard Functions")
+    ax.legend()
+    ax.grid(alpha=0.3)
+
+    return fig, ax
diff --git a/gen_surv/cphm.py b/gen_surv/cphm.py
index 97f52a6..0f09c6d 100644
--- a/gen_surv/cphm.py
+++ b/gen_surv/cphm.py
@@ -1,28 +1,59 @@
+"""
+Cox Proportional Hazards Model (CPHM) data generation.
+
+This module provides functions to generate survival data following the
+Cox Proportional Hazards Model with various censoring mechanisms.
+"""
+
+from typing import Literal
+
 import numpy as np
 import pandas as pd
-from gen_surv.validate import validate_gen_cphm_inputs
-from gen_surv.censoring import runifcens, rexpocens
+from numpy.typing import NDArray
+
+from gen_surv.censoring import CensoringFunc, rexpocens, runifcens
+from gen_surv.validation import validate_gen_cphm_inputs
 
-def generate_cphm_data(n, rfunc, cens_par, beta, covariate_range):
+
+def generate_cphm_data(
+    n: int,
+    rfunc: CensoringFunc,
+    cens_par: float,
+    beta: float,
+    covariate_range: float,
+    seed: int | None = None,
+) -> NDArray[np.float64]:
     """
     Generate data from a Cox Proportional Hazards Model (CPHM).
 
-    Parameters:
-    - n (int): Number of samples to generate.
-    - rfunc (callable): Function to generate censoring times, must accept (size, cens_par).
-    - cens_par (float): Parameter passed to the censoring function.
-    - beta (float): Coefficient for the covariate.
-    - covar (float): Range for the covariate (uniformly sampled from [0, covar]).
+    Parameters
+    ----------
+    n : int
+        Number of samples to generate.
+    rfunc : callable
+        Function to generate censoring times, must accept (size, cens_par).
+    cens_par : float
+        Parameter passed to the censoring function.
+    beta : float
+        Coefficient for the covariate.
+    covariate_range : float
+        Range for the covariate (uniformly sampled from [0, covariate_range]).
+    seed : int, optional
+        Random seed for reproducibility.
- Returns: - - np.ndarray: Array with shape (n, 3): [time, status, covariate] + Returns + ------- + NDArray[np.float64] + Array with shape ``(n, 3)``: ``[time, status, X0]`` """ - data = np.zeros((n, 3)) + rng = np.random.default_rng(seed) + + data: NDArray[np.float64] = np.zeros((n, 3), dtype=float) for k in range(n): - z = np.random.uniform(0, covariate_range) - c = rfunc(1, cens_par)[0] - x = np.random.exponential(scale=1 / np.exp(beta * z)) + z = rng.uniform(0, covariate_range) + c = rfunc(1, cens_par, rng)[0] + x = rng.exponential(scale=1 / np.exp(beta * z)) time = min(x, c) status = int(x <= c) @@ -32,27 +63,53 @@ def generate_cphm_data(n, rfunc, cens_par, beta, covariate_range): return data -def gen_cphm(n: int, model_cens: str, cens_par: float, beta: float, covar: float) -> pd.DataFrame: +def gen_cphm( + n: int, + model_cens: Literal["uniform", "exponential"], + cens_par: float, + beta: float, + covariate_range: float, + seed: int | None = None, +) -> pd.DataFrame: """ - Convenience wrapper to generate CPHM survival data. + Generate survival data following a Cox Proportional Hazards Model. - Parameters: - - n (int): Number of observations. - - model_cens (str): "uniform" or "exponential". - - cens_par (float): Parameter for the censoring model. - - beta (float): Coefficient for the covariate. - - covar (float): Covariate range (uniform between 0 and covar). + Parameters + ---------- + n : int + Number of observations. + model_cens : {"uniform", "exponential"} + Type of censoring mechanism. + cens_par : float + Parameter for the censoring model. + beta : float + Coefficient for the covariate. + covariate_range : float + Upper bound for the covariate values (uniform between 0 and covariate_range). + seed : int, optional + Random seed for reproducibility. - Returns: - - pd.DataFrame: Columns are ["time", "status", "covariate"] - """ - validate_gen_cphm_inputs(n, model_cens, cens_par, covar) + Returns + ------- + pd.DataFrame + DataFrame with columns ["time", "status", "X0"] + - time: observed event or censoring time + - status: event indicator (1=event, 0=censored) + - X0: predictor variable - rfunc = { - "uniform": runifcens, - "exponential": rexpocens - }[model_cens] + Examples + -------- + >>> from gen_surv.cphm import gen_cphm + >>> df = gen_cphm(n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0) + >>> df.head() + time status X0 + 0 0.23 1.0 1.42 + 1 0.78 0.0 0.89 + ... + """ + validate_gen_cphm_inputs(n, model_cens, cens_par, covariate_range) - data = generate_cphm_data(n, rfunc, cens_par, beta, covar) + rfunc = {"uniform": runifcens, "exponential": rexpocens}[model_cens] - return pd.DataFrame(data, columns=["time", "status", "covariate"]) + data = generate_cphm_data(n, rfunc, cens_par, beta, covariate_range, seed) + return pd.DataFrame(data, columns=["time", "status", "X0"]) diff --git a/gen_surv/export.py b/gen_surv/export.py new file mode 100644 index 0000000..526dc42 --- /dev/null +++ b/gen_surv/export.py @@ -0,0 +1,47 @@ +"""Data export utilities for gen_surv. + +This module provides helper functions to save generated +survival datasets in various formats. +""" + +from __future__ import annotations + +import os + +import pandas as pd +import pyreadr + +from .validation import ensure_in_choices + + +def export_dataset(df: pd.DataFrame, path: str, fmt: str | None = None) -> None: + """Save a DataFrame to disk. + + Parameters + ---------- + df : pd.DataFrame + DataFrame containing survival data. + path : str + File path to write to. 
The extension is used to infer the format + when ``fmt`` is ``None``. + fmt : {"csv", "json", "feather", "rds"}, optional + Format to use. If omitted, inferred from ``path``. + + Raises + ------ + ChoiceError + If the format is not one of the supported types. + """ + if fmt is None: + fmt = os.path.splitext(path)[1].lstrip(".").lower() + + ensure_in_choices(fmt, "fmt", {"csv", "json", "feather", "ft", "rds"}) + + if fmt == "csv": + df.to_csv(path, index=False) + elif fmt == "json": + df.to_json(path, orient="table") + elif fmt in {"feather", "ft"}: + df.reset_index(drop=True).to_feather(path) + elif fmt == "rds": + pyreadr.write_rds(path, df.reset_index(drop=True)) diff --git a/gen_surv/integration.py b/gen_surv/integration.py new file mode 100644 index 0000000..6967086 --- /dev/null +++ b/gen_surv/integration.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +import numpy as np +import pandas as pd +from numpy.typing import NDArray + + +def to_sksurv( + df: pd.DataFrame, time_col: str = "time", event_col: str = "status" +) -> NDArray[np.void]: + """Convert a DataFrame to a scikit-survival structured array. + + Parameters + ---------- + df : pd.DataFrame + DataFrame containing survival data. + time_col : str, default "time" + Column storing durations. + event_col : str, default "status" + Column storing event indicators (1=event, 0=censored). + + Returns + ------- + numpy.ndarray + Structured array suitable for scikit-survival estimators. + + Notes + ----- + The ``sksurv`` package is imported lazily inside the function. It must be + installed separately, for instance with ``pip install scikit-survival``. + """ + + try: + from sksurv.util import Surv + except ImportError as exc: # pragma: no cover - optional dependency + raise ImportError("scikit-survival is required for this feature.") from exc + + return Surv.from_dataframe(event_col, time_col, df) diff --git a/gen_surv/interface.py b/gen_surv/interface.py index 549ad45..c01aafd 100644 --- a/gen_surv/interface.py +++ b/gen_surv/interface.py @@ -3,41 +3,108 @@ Example: >>> from gen_surv import generate - >>> df = generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0) + >>> df = generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0) """ -from typing import Any +from collections.abc import Callable +from typing import Dict, Literal + import pandas as pd -from gen_surv.cphm import gen_cphm +from gen_surv.aft import gen_aft_log_logistic, gen_aft_log_normal, gen_aft_weibull from gen_surv.cmm import gen_cmm +from gen_surv.competing_risks import gen_competing_risks, gen_competing_risks_weibull +from gen_surv.cphm import gen_cphm +from gen_surv.mixture import gen_mixture_cure +from gen_surv.piecewise import gen_piecewise_exponential from gen_surv.tdcm import gen_tdcm from gen_surv.thmm import gen_thmm -from gen_surv.aft import gen_aft_log_normal +from .validation import ValidationError, ensure_in_choices + +# Type definitions for model names +ModelType = Literal[ + "cphm", + "cmm", + "tdcm", + "thmm", + "aft_ln", + "aft_weibull", + "aft_log_logistic", + "competing_risks", + "competing_risks_weibull", + "mixture_cure", + "piecewise_exponential", +] + + +# Interface for generator callables +DataGenerator = Callable[..., pd.DataFrame] -_model_map = { + +# Map model names to their generator functions +_model_map: Dict[ModelType, DataGenerator] = { "cphm": gen_cphm, "cmm": gen_cmm, "tdcm": gen_tdcm, "thmm": gen_thmm, "aft_ln": gen_aft_log_normal, + "aft_weibull": 
gen_aft_weibull, + "aft_log_logistic": gen_aft_log_logistic, + "competing_risks": gen_competing_risks, + "competing_risks_weibull": gen_competing_risks_weibull, + "mixture_cure": gen_mixture_cure, + "piecewise_exponential": gen_piecewise_exponential, } -def generate(model: str, **kwargs: Any) -> pd.DataFrame: +def generate(model: ModelType, **kwargs: object) -> pd.DataFrame: """Generate survival data from a specific model. - Args: - model: Name of the generator to run. Must be one of ``cphm``, ``cmm``, - ``tdcm``, ``thmm`` or ``aft_ln``. - **kwargs: Arguments forwarded to the chosen generator. + Parameters + ---------- + model : ModelType + Name of the generator to run. Must be one of ``cphm``, ``cmm``, + ``tdcm``, ``thmm``, ``aft_ln``, ``aft_weibull``, ``aft_log_logistic``, + ``competing_risks``, ``competing_risks_weibull``, ``mixture_cure``, + or ``piecewise_exponential``. + **kwargs + Arguments forwarded to the chosen generator. These vary by model. + + - cphm: n, model_cens, cens_par, beta, covariate_range + - cmm: n, model_cens, cens_par, beta, covariate_range, rate + - tdcm: n, dist, corr, dist_par, model_cens, cens_par, beta, lam + - thmm: n, model_cens, cens_par, beta, covariate_range, rate + - aft_ln: n, beta, sigma, model_cens, cens_par, seed + - aft_weibull: n, beta, shape, scale, model_cens, cens_par, seed + - aft_log_logistic: n, beta, shape, scale, model_cens, cens_par, seed + - competing_risks: n, n_risks, baseline_hazards, betas, covariate_dist, etc. + - competing_risks_weibull: n, n_risks, shape_params, scale_params, betas, etc. + - mixture_cure: n, cure_fraction, baseline_hazard, betas_survival, + betas_cure, etc. + - piecewise_exponential: n, breakpoints, hazard_rates, betas, etc. + + Returns + ------- + pd.DataFrame + Simulated survival data with columns specific to the chosen model. + All models include time/duration and status columns. - Returns: - pd.DataFrame: Simulated survival data. + Raises + ------ + ChoiceError + If an unknown model name is provided. + + Examples + -------- + >>> from gen_surv import generate + >>> df = generate(model="cphm", n=100, beta=0.5, covariate_range=2.0, + ... model_cens="uniform", cens_par=1.0) + >>> df.head() """ - model = model.lower() - if model not in _model_map: - raise ValueError(f"Unknown model '{model}'. Choose from {list(_model_map.keys())}.") - - return _model_map[model](**kwargs) + ensure_in_choices(model, "model", _model_map.keys()) + try: + return _model_map[model](**kwargs) + except ValidationError as exc: + exc.args = (f"{exc} (while validating inputs for model '{model}')",) + raise exc diff --git a/gen_surv/mixture.py b/gen_surv/mixture.py new file mode 100644 index 0000000..e19e3e5 --- /dev/null +++ b/gen_surv/mixture.py @@ -0,0 +1,264 @@ +""" +Mixture Cure Models for survival data simulation. + +This module provides functions to generate survival data with a cure fraction, +i.e., a proportion of subjects who are immune to the event of interest. 
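+
+Examples
+--------
+A minimal sketch using ``gen_mixture_cure`` (defined below):
+
+>>> from gen_surv.mixture import gen_mixture_cure
+>>> df = gen_mixture_cure(n=100, cure_fraction=0.3, seed=42)
+>>> {"time", "status", "cured"}.issubset(df.columns)
+True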
+""" + +from typing import Literal + +import numpy as np +import pandas as pd +from numpy.random import Generator +from numpy.typing import NDArray + +_TAIL_FRACTION: float = 0.1 +_SMOOTH_MIN_TAIL: int = 3 + +from ._covariates import generate_covariates, prepare_betas, set_covariate_params +from .censoring import rexpocens, runifcens +from .validation import ParameterError, ensure_positive, validate_gen_mixture_inputs + + +def _cure_status( + lp_cure: NDArray[np.float64], cure_fraction: float, rng: Generator +) -> NDArray[np.int64]: + cure_probs = 1 / ( + 1 + np.exp(-(np.log(cure_fraction / (1 - cure_fraction)) + lp_cure)) + ) + return rng.binomial(1, cure_probs).astype(np.int64) + + +def _survival_times( + cured: NDArray[np.int64], + lp_survival: NDArray[np.float64], + baseline_hazard: float, + max_time: float | None, + rng: Generator, +) -> NDArray[np.float64]: + n = cured.size + times = np.zeros(n, dtype=float) + non_cured = cured == 0 + adjusted_hazard = baseline_hazard * np.exp(lp_survival[non_cured]) + times[non_cured] = rng.exponential(scale=1 / adjusted_hazard) + if max_time is not None: + times[~non_cured] = max_time * 100 + else: + times[~non_cured] = np.inf + return times + + +def _apply_censoring( + survival_times: NDArray[np.float64], + model_cens: str, + cens_par: float, + max_time: float | None, + rng: Generator, +) -> tuple[NDArray[np.float64], NDArray[np.int64]]: + rfunc = runifcens if model_cens == "uniform" else rexpocens + cens_times = rfunc(len(survival_times), cens_par, rng) + observed = np.minimum(survival_times, cens_times) + status = (survival_times <= cens_times).astype(int) + if max_time is not None: + over_max = observed > max_time + observed[over_max] = max_time + status[over_max] = 0 + return observed, status + + +def gen_mixture_cure( + n: int, + cure_fraction: float, + baseline_hazard: float = 0.5, + betas_survival: list[float] | None = None, + betas_cure: list[float] | None = None, + n_covariates: int = 2, + covariate_dist: Literal["normal", "uniform", "binary"] = "normal", + covariate_params: dict[str, float] | None = None, + model_cens: Literal["uniform", "exponential"] = "uniform", + cens_par: float = 5.0, + max_time: float | None = 10.0, + seed: int | None = None, +) -> pd.DataFrame: + """ + Generate survival data with a cure fraction using a mixture cure model. + + Parameters + ---------- + n : int + Number of subjects. + cure_fraction : float + Baseline probability of being cured (immune to the event). + Should be between 0 and 1. + baseline_hazard : float, default=0.5 + Baseline hazard rate for the non-cured population. + betas_survival : list of float, optional + Coefficients for covariates in the survival component. + If None, generates random coefficients. + betas_cure : list of float, optional + Coefficients for covariates in the cure component. + If None, generates random coefficients. + n_covariates : int, default=2 + Number of covariates to generate if betas is None. + covariate_dist : {"normal", "uniform", "binary"}, default="normal" + Distribution to generate covariates from. + covariate_params : dict, optional + Parameters for covariate distribution: + - "normal": {"mean": float, "std": float} + - "uniform": {"low": float, "high": float} + - "binary": {"p": float} + If None, uses defaults based on distribution. + model_cens : {"uniform", "exponential"}, default="uniform" + Censoring mechanism. + cens_par : float, default=5.0 + Parameter for censoring distribution. + max_time : float, optional, default=10.0 + Maximum simulation time. 
Set to None for no limit. + seed : int, optional + Random seed for reproducibility. + + Returns + ------- + pd.DataFrame + DataFrame with columns: + - "id": Subject identifier + - "time": Time to event or censoring + - "status": Event indicator (1=event, 0=censored) + - "cured": Indicator of cure status (1=cured, 0=not cured) + - "X0", "X1", ...: Covariates + + Examples + -------- + >>> from gen_surv.mixture import gen_mixture_cure + >>> + >>> # Generate data with 30% baseline cure fraction + >>> df = gen_mixture_cure( + ... n=100, + ... cure_fraction=0.3, + ... betas_survival=[0.8, -0.5], + ... betas_cure=[-0.5, 0.8], + ... seed=42 + ... ) + >>> + >>> # Check cure proportion + >>> print(f"Cured subjects: {df['cured'].mean():.2%}") + """ + rng = np.random.default_rng(seed) + validate_gen_mixture_inputs( + n, + cure_fraction, + baseline_hazard, + n_covariates, + model_cens, + cens_par, + max_time, + covariate_dist, + ) + covariate_params = set_covariate_params(covariate_dist, covariate_params) + betas_survival_arr, n_covariates = prepare_betas( + betas_survival, n_covariates, rng, name="betas_survival" + ) + betas_cure_arr, _ = prepare_betas( + betas_cure, n_covariates, rng, name="betas_cure", enforce_length=True + ) + X = generate_covariates(n, n_covariates, covariate_dist, covariate_params, rng) + lp_survival = X @ betas_survival_arr + lp_cure = X @ betas_cure_arr + cured = _cure_status(lp_cure, cure_fraction, rng) + survival_times = _survival_times(cured, lp_survival, baseline_hazard, max_time, rng) + observed_times, status = _apply_censoring( + survival_times, model_cens, cens_par, max_time, rng + ) + + data = pd.DataFrame( + {"id": np.arange(n), "time": observed_times, "status": status, "cured": cured} + ) + + for j in range(n_covariates): + data[f"X{j}"] = X[:, j] + + return data + + +def cure_fraction_estimate( + data: pd.DataFrame, + time_col: str = "time", + status_col: str = "status", + bandwidth: float = 0.1, +) -> float: + """ + Estimate the cure fraction from observed data using non-parametric methods. + + Parameters + ---------- + data : pd.DataFrame + DataFrame with survival data. + time_col : str, default="time" + Name of the time column. + status_col : str, default="status" + Name of the status column (1=event, 0=censored). + bandwidth : float, default=0.1 + Bandwidth parameter for smoothing the tail of the survival curve. + + Returns + ------- + float + Estimated cure fraction. + + Notes + ----- + This function uses a non-parametric approach to estimate the cure fraction + based on the plateau of the survival curve. It may not be accurate for + small sample sizes or heavy censoring. 
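+
+    Examples
+    --------
+    A rough sanity check rather than an exact value (the estimate is
+    sample-dependent):
+
+    >>> from gen_surv.mixture import gen_mixture_cure, cure_fraction_estimate
+    >>> df = gen_mixture_cure(n=200, cure_fraction=0.3, seed=42)
+    >>> 0.0 <= cure_fraction_estimate(df) <= 1.0
+    True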
+ """ + if time_col not in data.columns or status_col not in data.columns: + missing = [c for c in (time_col, status_col) if c not in data.columns] + raise ParameterError( + "data", + data.columns.tolist(), + f"missing required column(s): {', '.join(missing)}", + ) + ensure_positive(bandwidth, "bandwidth") + # Sort data by time + sorted_data = data.sort_values(by=time_col).copy() + + # Calculate Kaplan-Meier estimate + times = sorted_data[time_col].values + status = sorted_data[status_col].values + n = len(times) + + if n == 0: + return 0.0 + + # Calculate survival function + survival = np.ones(n) + + for i in range(n): + if i > 0: + survival[i] = survival[i - 1] + + # Count subjects at risk at this time + at_risk = n - i + + if status[i] == 1: # Event + survival[i] *= 1 - 1 / at_risk + + # Estimate cure fraction as the plateau of the survival curve + # Use the last portion of the survival curve if enough data points + tail_size = max(int(n * _TAIL_FRACTION), 1) + tail_survival = survival[-tail_size:] + + # Apply smoothing if there are enough data points + if tail_size > _SMOOTH_MIN_TAIL: + # Use kernel smoothing + weights = np.exp( + -((np.arange(tail_size) - tail_size + 1) ** 2) + / (2 * bandwidth * tail_size) ** 2 + ) + weights = weights / weights.sum() + cure_fraction = float(np.sum(tail_survival * weights)) + else: + # Just use the last survival probability + cure_fraction = float(survival[-1]) + + return cure_fraction diff --git a/gen_surv/piecewise.py b/gen_surv/piecewise.py new file mode 100644 index 0000000..8e56b60 --- /dev/null +++ b/gen_surv/piecewise.py @@ -0,0 +1,286 @@ +""" +Piecewise Exponential survival models. + +This module provides functions for generating survival data from piecewise +exponential distributions with time-dependent hazards. +""" + +from typing import Literal + +import numpy as np +import pandas as pd +from numpy.typing import NDArray + +from ._covariates import generate_covariates, prepare_betas, set_covariate_params +from .censoring import rexpocens, runifcens +from .validation import validate_gen_piecewise_inputs, validate_piecewise_params + + +def gen_piecewise_exponential( + n: int, + breakpoints: list[float], + hazard_rates: list[float], + betas: list[float] | NDArray[np.float64] | None = None, + n_covariates: int = 2, + covariate_dist: Literal["normal", "uniform", "binary"] = "normal", + covariate_params: dict[str, float] | None = None, + model_cens: Literal["uniform", "exponential"] = "uniform", + cens_par: float = 5.0, + seed: int | None = None, +) -> pd.DataFrame: + """ + Generate survival data using a piecewise exponential distribution. + + Parameters + ---------- + n : int + Number of subjects. + breakpoints : list of float + Time points where hazard rates change. Must be in ascending order. + The first interval is [0, breakpoints[0]), the second is [breakpoints[0], breakpoints[1]), etc. + hazard_rates : list of float + Hazard rates for each interval. Length should be len(breakpoints) + 1. + betas : list or array, optional + Coefficients for covariates. If None, generates random coefficients. + n_covariates : int, default=2 + Number of covariates to generate if betas is None. + covariate_dist : {"normal", "uniform", "binary"}, default="normal" + Distribution to generate covariates from. + covariate_params : dict, optional + Parameters for covariate distribution: + - "normal": {"mean": float, "std": float} + - "uniform": {"low": float, "high": float} + - "binary": {"p": float} + If None, uses defaults based on distribution. 
+    model_cens : {"uniform", "exponential"}, default="uniform"
+        Censoring mechanism.
+    cens_par : float, default=5.0
+        Parameter for censoring distribution.
+    seed : int, optional
+        Random seed for reproducibility.
+
+    Returns
+    -------
+    pd.DataFrame
+        DataFrame with columns:
+        - "id": Subject identifier
+        - "time": Time to event or censoring
+        - "status": Event indicator (1=event, 0=censored)
+        - "X0", "X1", ...: Covariates
+
+    Examples
+    --------
+    >>> from gen_surv.piecewise import gen_piecewise_exponential
+    >>>
+    >>> # Generate data with 3 intervals (increasing hazard)
+    >>> df = gen_piecewise_exponential(
+    ...     n=100,
+    ...     breakpoints=[1.0, 3.0],
+    ...     hazard_rates=[0.2, 0.5, 1.0],
+    ...     betas=[0.8, -0.5],
+    ...     seed=42
+    ... )
+    """
+    rng = np.random.default_rng(seed)
+
+    validate_gen_piecewise_inputs(
+        n,
+        breakpoints,
+        hazard_rates,
+        n_covariates,
+        model_cens,
+        cens_par,
+        covariate_dist,
+    )
+    covariate_params = set_covariate_params(covariate_dist, covariate_params)
+
+    # Set default betas if not provided
+    betas, n_covariates = prepare_betas(betas, n_covariates, rng)
+
+    # Generate covariates
+    X = generate_covariates(n, n_covariates, covariate_dist, covariate_params, rng)
+
+    # Calculate linear predictor
+    linear_predictor = X @ betas
+
+    # Generate survival times using piecewise exponential distribution
+    survival_times = np.zeros(n)
+
+    for i in range(n):
+        # Adjust hazard rates by the covariate effect
+        adjusted_hazard_rates = [h * np.exp(linear_predictor[i]) for h in hazard_rates]
+
+        # Generate random uniform between 0 and 1
+        u = rng.uniform(0, 1)
+
+        # Calculate survival time using inverse CDF method for piecewise exponential
+        remaining_time = -np.log(u)  # Target cumulative hazard to consume
+        total_time = 0.0
+
+        # Start with the first interval [0, breakpoints[0])
+        interval_width = breakpoints[0]
+        hazard = adjusted_hazard_rates[0]
+        time_to_consume = remaining_time / hazard
+
+        if time_to_consume < interval_width:
+            # Event occurs in first interval
+            survival_times[i] = time_to_consume
+            continue
+
+        # Event occurs after first interval
+        total_time += interval_width
+        remaining_time -= hazard * interval_width
+
+        # Go through middle intervals [breakpoints[j-1], breakpoints[j])
+        for j in range(1, len(breakpoints)):
+            interval_width = breakpoints[j] - breakpoints[j - 1]
+            hazard = adjusted_hazard_rates[j]
+            time_to_consume = remaining_time / hazard
+
+            if time_to_consume < interval_width:
+                # Event occurs in this interval
+                survival_times[i] = total_time + time_to_consume
+                break
+
+            # Event occurs after this interval
+            total_time += interval_width
+            remaining_time -= hazard * interval_width
+        else:
+            # No event in any bounded interval (the loop finished without a
+            # break): the remainder falls in the open-ended last interval,
+            # governed by the final hazard rate.
+            hazard = adjusted_hazard_rates[-1]
+            survival_times[i] = total_time + remaining_time / hazard
+
+    # Generate censoring times
+    rfunc = runifcens if model_cens == "uniform" else rexpocens
+    cens_times = rfunc(n, cens_par, rng)
+
+    # Determine observed time and status
+    observed_times = np.minimum(survival_times, cens_times)
+    status = (survival_times <= cens_times).astype(int)
+
+    # Create DataFrame
+    data = pd.DataFrame({"id": np.arange(n), "time": observed_times, "status": status})
+
+    # Add covariates
+    for j in range(n_covariates):
+        data[f"X{j}"] = X[:, j]
+
+    return data
+
+
+def piecewise_hazard_function(
+    t: float | NDArray[np.float64],
+    breakpoints: list[float],
+    hazard_rates: list[float],
+) -> float | NDArray[np.float64]:
+    """
+    Calculate the hazard function value at time t for a piecewise exponential distribution.
+
+    Parameters
+    ----------
+    t : float or array
+        Time point(s) at which to evaluate the hazard function.
+    breakpoints : list of float
+        Time points where hazard rates change.
+    hazard_rates : list of float
+        Hazard rates for each interval.
+
+    Returns
+    -------
+    float or array
+        Hazard function value(s) at time t.
+    """
+    validate_piecewise_params(breakpoints, hazard_rates)
+
+    # Convert scalar input to array for consistent processing
+    scalar_input = np.isscalar(t)
+    t_array = np.atleast_1d(t)
+    result = np.zeros_like(t_array)
+
+    # Assign hazard rates based on time intervals
+    result[t_array < 0] = 0  # Hazard is 0 for negative times
+
+    # First interval: [0, breakpoints[0])
+    mask = (t_array >= 0) & (t_array < breakpoints[0])
+    result[mask] = hazard_rates[0]
+
+    # Middle intervals: [breakpoints[j-1], breakpoints[j])
+    for j in range(1, len(breakpoints)):
+        mask = (t_array >= breakpoints[j - 1]) & (t_array < breakpoints[j])
+        result[mask] = hazard_rates[j]
+
+    # Last interval: [breakpoints[-1], infinity)
+    mask = t_array >= breakpoints[-1]
+    result[mask] = hazard_rates[-1]
+
+    return result[0] if scalar_input else result
+
+
+def piecewise_survival_function(
+    t: float | NDArray[np.float64],
+    breakpoints: list[float],
+    hazard_rates: list[float],
+) -> float | NDArray[np.float64]:
+    """
+    Calculate the survival function at time t for a piecewise exponential distribution.
+
+    Parameters
+    ----------
+    t : float or array
+        Time point(s) at which to evaluate the survival function.
+    breakpoints : list of float
+        Time points where hazard rates change.
+    hazard_rates : list of float
+        Hazard rates for each interval.
+
+    Returns
+    -------
+    float or array
+        Survival function value(s) at time t.
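+
+    Examples
+    --------
+    A quick consistency check against the closed form S(t) = exp(-H(t))
+    (rates and breakpoints here are arbitrary):
+
+    >>> import numpy as np
+    >>> from gen_surv.piecewise import piecewise_survival_function
+    >>> s = piecewise_survival_function(2.0, [1.0, 3.0], [0.2, 0.5, 1.0])
+    >>> bool(np.isclose(s, np.exp(-(0.2 * 1.0 + 0.5 * 1.0))))
+    True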
+ """ + validate_piecewise_params(breakpoints, hazard_rates) + + # Convert scalar input to array for consistent processing + scalar_input = np.isscalar(t) + t_array = np.atleast_1d(t) + result = np.ones_like(t_array) + + # For each time point, calculate the survival function + for i, ti in enumerate(t_array): + if ti <= 0: + continue # Survival probability is 1 at time 0 or earlier + + cumulative_hazard = 0.0 + + # First interval: [0, min(ti, breakpoints[0])) + first_interval_end = min(ti, breakpoints[0]) if breakpoints else ti + cumulative_hazard += hazard_rates[0] * first_interval_end + + if ti <= breakpoints[0]: + result[i] = np.exp(-cumulative_hazard) + continue + + # Middle intervals: [breakpoints[j-1], min(ti, breakpoints[j])) + for j in range(1, len(breakpoints)): + if ti <= breakpoints[j - 1]: + break + + interval_start = breakpoints[j - 1] + interval_end = min(ti, breakpoints[j]) + interval_width = interval_end - interval_start + + cumulative_hazard += hazard_rates[j] * interval_width + + if ti <= breakpoints[j]: + break + + # Last interval: [breakpoints[-1], ti) + if ti > breakpoints[-1]: + last_interval_width = ti - breakpoints[-1] + cumulative_hazard += hazard_rates[-1] * last_interval_width + + result[i] = np.exp(-cumulative_hazard) + + return result[0] if scalar_input else result diff --git a/gen_surv/sklearn_adapter.py b/gen_surv/sklearn_adapter.py new file mode 100644 index 0000000..ce27c9d --- /dev/null +++ b/gen_surv/sklearn_adapter.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Protocol + +import pandas as pd + +from .interface import ModelType, generate +from .validation import ensure_in_choices + + +class BaseEstimatorProto(Protocol): + """Protocol capturing the minimal scikit-learn estimator interface.""" + + def get_params(self, deep: bool = ...) -> dict[str, object]: ... + + def set_params(self, **params: object) -> "BaseEstimatorProto": ... 
+ + +if TYPE_CHECKING: # pragma: no cover - import for type checkers only + from sklearn.base import BaseEstimator as SklearnBase +else: # pragma: no cover - runtime import with fallback + try: + from sklearn.base import BaseEstimator as SklearnBase + except Exception: + + class SklearnBase: # noqa: D401 - simple runtime stub + """Minimal stub if scikit-learn is not installed.""" + + def get_params(self, deep: bool = True) -> dict[str, object]: + return {} + + def set_params(self, **params: object) -> "SklearnBase": + return self + + +class GenSurvDataGenerator(SklearnBase, BaseEstimatorProto): + """Scikit-learn compatible wrapper around :func:`gen_surv.generate`.""" + + def __init__( + self, model: ModelType, return_type: str = "df", **kwargs: object + ) -> None: + ensure_in_choices(return_type, "return_type", {"df", "dict"}) + self.model = model + self.return_type = return_type + self.kwargs = kwargs + + def fit( + self, X: pd.DataFrame | None = None, y: pd.Series | None = None + ) -> "GenSurvDataGenerator": + return self + + def transform( + self, X: pd.DataFrame | None = None + ) -> pd.DataFrame | dict[str, list[object]]: + df = generate(self.model, **self.kwargs) + if self.return_type == "df": + return df + if self.return_type == "dict": + return df.to_dict(orient="list") + raise AssertionError("Unreachable due to validation") + + def fit_transform( + self, + X: pd.DataFrame | None = None, + y: pd.Series | None = None, + **fit_params: object, + ) -> pd.DataFrame | dict[str, list[object]]: + return self.fit(X, y).transform(X) diff --git a/gen_surv/summary.py b/gen_surv/summary.py new file mode 100644 index 0000000..1150793 --- /dev/null +++ b/gen_surv/summary.py @@ -0,0 +1,508 @@ +""" +Utilities for summarizing and validating survival datasets. + +This module provides functions to summarize survival data, +check data quality, and identify potential issues. +""" + +from typing import Any + +import pandas as pd + +from .validation import ParameterError + + +def summarize_survival_dataset( + data: pd.DataFrame, + time_col: str = "time", + status_col: str = "status", + id_col: str | None = None, + covariate_cols: list[str] | None = None, + verbose: bool = True, +) -> dict[str, Any]: + """ + Generate a comprehensive summary of a survival dataset. + + Parameters + ---------- + data : pd.DataFrame + DataFrame containing survival data. + time_col : str, default="time" + Name of the column containing time-to-event values. + status_col : str, default="status" + Name of the column containing event indicators (1=event, 0=censored). + id_col : str, optional + Name of the column containing subject identifiers. + covariate_cols : list of str, optional + List of column names to include as covariates in the summary. + If None, all columns except time_col, status_col, and id_col are considered. + verbose : bool, default=True + Whether to print the summary to console. + + Returns + ------- + dict[str, Any] + Dictionary containing all summary statistics. + + Examples + -------- + >>> from gen_surv import generate + >>> from gen_surv.summary import summarize_survival_dataset + >>> + >>> # Generate example data + >>> df = generate(model="cphm", n=100, model_cens="uniform", + ... 
cens_par=1.0, beta=0.5, covariate_range=2.0) + >>> + >>> # Summarize the dataset + >>> summary = summarize_survival_dataset(df) + """ + # Validate input columns + for col in [time_col, status_col]: + if col not in data.columns: + raise ParameterError("column", col, "not found in data") + + if id_col is not None and id_col not in data.columns: + raise ParameterError("id_col", id_col, "not found in data") + + # Determine covariate columns + if covariate_cols is None: + exclude_cols = {time_col, status_col} + if id_col is not None: + exclude_cols.add(id_col) + covariate_cols = [col for col in data.columns if col not in exclude_cols] + else: + missing_cols = [col for col in covariate_cols if col not in data.columns] + if missing_cols: + raise ParameterError("covariate_cols", missing_cols, "not found in data") + + # Basic dataset information + n_subjects = len(data) + if id_col is not None: + n_unique_ids = data[id_col].nunique() + else: + n_unique_ids = n_subjects + + # Event information + n_events = data[status_col].sum() + n_censored = n_subjects - n_events + event_rate = n_events / n_subjects + + # Time statistics + time_min = data[time_col].min() + time_max = data[time_col].max() + time_mean = data[time_col].mean() + time_median = data[time_col].median() + + # Data quality checks + n_missing_time = data[time_col].isna().sum() + n_missing_status = data[status_col].isna().sum() + n_negative_time = (data[time_col] < 0).sum() + n_invalid_status = data[~data[status_col].isin([0, 1])].shape[0] + + # Covariate summaries + covariate_stats = {} + for col in covariate_cols: + col_data = data[col] + is_numeric = pd.api.types.is_numeric_dtype(col_data) + + if is_numeric: + covariate_stats[col] = { + "type": "numeric", + "min": col_data.min(), + "max": col_data.max(), + "mean": col_data.mean(), + "median": col_data.median(), + "std": col_data.std(), + "missing": col_data.isna().sum(), + "unique_values": col_data.nunique(), + } + else: + # Categorical/string + covariate_stats[col] = { + "type": "categorical", + "n_categories": col_data.nunique(), + "top_categories": col_data.value_counts().head(5).to_dict(), + "missing": col_data.isna().sum(), + } + + # Compile the summary + summary = { + "dataset_info": { + "n_subjects": n_subjects, + "n_unique_ids": n_unique_ids, + "n_covariates": len(covariate_cols), + }, + "event_info": { + "n_events": n_events, + "n_censored": n_censored, + "event_rate": event_rate, + }, + "time_info": { + "min": time_min, + "max": time_max, + "mean": time_mean, + "median": time_median, + }, + "data_quality": { + "missing_time": n_missing_time, + "missing_status": n_missing_status, + "negative_time": n_negative_time, + "invalid_status": n_invalid_status, + "overall_quality": ( + "good" + if ( + n_missing_time + + n_missing_status + + n_negative_time + + n_invalid_status + ) + == 0 + else "issues_detected" + ), + }, + "covariates": covariate_stats, + } + + # Print summary if requested + if verbose: + _print_summary(summary, time_col, status_col, id_col, covariate_cols) + + return summary + + +def check_survival_data_quality( + data: pd.DataFrame, + time_col: str = "time", + status_col: str = "status", + id_col: str | None = None, + min_time: float = 0.0, + max_time: float | None = None, + status_values: list[int] | None = None, + fix_issues: bool = False, +) -> tuple[pd.DataFrame, dict[str, Any]]: + """ + Check for common issues in survival data and optionally fix them. + + Parameters + ---------- + data : pd.DataFrame + DataFrame containing survival data. 
+    time_col : str, default="time"
+        Name of the column containing time-to-event values.
+    status_col : str, default="status"
+        Name of the column containing event indicators.
+    id_col : str, optional
+        Name of the column containing subject identifiers.
+    min_time : float, default=0.0
+        Minimum acceptable value for time column.
+    max_time : float, optional
+        Maximum acceptable value for time column.
+    status_values : list of int, optional
+        List of valid status values. Default is [0, 1].
+    fix_issues : bool, default=False
+        Whether to attempt fixing issues (returns a modified DataFrame).
+
+    Returns
+    -------
+    tuple[pd.DataFrame, dict[str, Any]]
+        Tuple containing the (possibly fixed) DataFrame and an issues report.
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> from gen_surv import generate
+    >>> from gen_surv.summary import check_survival_data_quality
+    >>>
+    >>> # Generate example data with some issues
+    >>> df = generate(model="cphm", n=100, model_cens="uniform",
+    ...               cens_par=1.0, beta=0.5, covariate_range=2.0)
+    >>> # Introduce some issues
+    >>> df.loc[0, "time"] = np.nan
+    >>> df.loc[1, "status"] = 2  # Invalid status
+    >>>
+    >>> # Check and fix issues
+    >>> fixed_df, issues = check_survival_data_quality(df, fix_issues=True)
+    >>> print(issues)
+    """
+    if status_values is None:
+        status_values = [0, 1]
+
+    # Make a copy to avoid modifying the original
+    if fix_issues:
+        data = data.copy()
+
+    # Initialize issues report
+    issues: dict[str, dict[str, int | None]] = {
+        "missing_data": {"time": 0, "status": 0, "id": 0 if id_col else None},
+        "invalid_values": {
+            "negative_time": 0,
+            "excessive_time": 0,
+            "invalid_status": 0,
+        },
+        "duplicates": {"duplicate_rows": 0, "duplicate_ids": 0 if id_col else None},
+        "modifications": {"rows_dropped": 0, "values_fixed": 0},
+    }
+
+    # Check for missing values
+    issues["missing_data"]["time"] = data[time_col].isna().sum()
+    issues["missing_data"]["status"] = data[status_col].isna().sum()
+    if id_col:
+        issues["missing_data"]["id"] = data[id_col].isna().sum()
+
+    # Check for invalid values
+    issues["invalid_values"]["negative_time"] = (data[time_col] < min_time).sum()
+    if max_time is not None:
+        issues["invalid_values"]["excessive_time"] = (data[time_col] > max_time).sum()
+    issues["invalid_values"]["invalid_status"] = data[
+        ~data[status_col].isin(status_values)
+    ].shape[0]
+
+    # Check for duplicates
+    issues["duplicates"]["duplicate_rows"] = data.duplicated().sum()
+    if id_col:
+        issues["duplicates"]["duplicate_ids"] = data[id_col].duplicated().sum()
+
+    # Fix issues if requested
+    if fix_issues:
+        original_rows = len(data)
+        modified_values = 0
+
+        # Handle missing values
+        data = data.dropna(subset=[time_col, status_col])
+
+        # Clamp times below the allowed minimum; with the default
+        # min_time=0.0 this also repairs negative times, which would
+        # otherwise be reported above but never fixed
+        mask = data[time_col] < min_time
+        if mask.any():
+            data.loc[mask, time_col] = min_time
+            modified_values += mask.sum()
+
+        if max_time is not None:
+            # Cap excessively large times
+            mask = data[time_col] > max_time
+            if mask.any():
+                data.loc[mask, time_col] = max_time
+                modified_values += mask.sum()
+
+        # Fix invalid status values
+        mask = ~data[status_col].isin(status_values)
+        if mask.any():
+            # Default to censored (0) for invalid status
+            data.loc[mask, status_col] = 0
+            modified_values += mask.sum()
+
+        # Remove duplicates
+        data = data.drop_duplicates()
+
+        # Update modification counts
+        issues["modifications"]["rows_dropped"] = original_rows - len(data)
+        issues["modifications"]["values_fixed"] = modified_values
+
+    return data, issues
+
+
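+# --- Illustrative sketch (an addition for exposition, not original code) ---
+# A hedged example of consuming the issues report returned by
+# ``check_survival_data_quality``: flatten the nested counts to decide
+# whether any problem remains. ``None`` entries (checks skipped because no
+# ``id_col`` was supplied) are ignored, and the "modifications" section is
+# excluded because it records fixes rather than problems. The helper name
+# ``_total_issues`` is hypothetical.
+def _total_issues(issues: dict[str, Any]) -> int:
+    """Sum all problem counts in a nested issues report (sketch)."""
+    total = 0
+    for section, counts in issues.items():
+        if section == "modifications":
+            continue  # fixes applied, not problems found
+        for value in counts.values():
+            if value is not None:
+                total += int(value)
+    return total
+
+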
+def _print_summary( + summary: dict[str, Any], + time_col: str, + status_col: str, + id_col: str | None, + covariate_cols: list[str], +) -> None: + """ + Print a formatted summary of survival data. + + Parameters + ---------- + summary : dict[str, Any] + Summary dictionary from summarize_survival_dataset. + time_col : str + Name of the time column. + status_col : str + Name of the status column. + id_col : str, optional + Name of the ID column. + covariate_cols : list[str] + List of covariate column names. + """ + print("=" * 60) + print("SURVIVAL DATASET SUMMARY") + print("=" * 60) + + # Dataset info + print("\nDATASET INFORMATION:") + print(f" Subjects: {summary['dataset_info']['n_subjects']}") + if id_col: + print(f" Unique IDs: {summary['dataset_info']['n_unique_ids']}") + print(f" Covariates: {summary['dataset_info']['n_covariates']}") + + # Event info + print("\nEVENT INFORMATION:") + print( + f" Events: {summary['event_info']['n_events']} " + + f"({summary['event_info']['event_rate']:.1%})" + ) + print( + f" Censored: {summary['event_info']['n_censored']} " + + f"({1 - summary['event_info']['event_rate']:.1%})" + ) + + # Time info + print(f"\nTIME VARIABLE ({time_col}):") + print( + f" Range: {summary['time_info']['min']:.2f} to {summary['time_info']['max']:.2f}" + ) + print(f" Mean: {summary['time_info']['mean']:.2f}") + print(f" Median: {summary['time_info']['median']:.2f}") + + # Data quality + print("\nDATA QUALITY:") + quality_issues = ( + summary["data_quality"]["missing_time"] + + summary["data_quality"]["missing_status"] + + summary["data_quality"]["negative_time"] + + summary["data_quality"]["invalid_status"] + ) + + if quality_issues == 0: + print(" βœ“ No issues detected") + else: + print(" βœ— Issues detected:") + if summary["data_quality"]["missing_time"] > 0: + print( + f" - Missing time values: {summary['data_quality']['missing_time']}" + ) + if summary["data_quality"]["missing_status"] > 0: + print( + f" - Missing status values: {summary['data_quality']['missing_status']}" + ) + if summary["data_quality"]["negative_time"] > 0: + print( + f" - Negative time values: {summary['data_quality']['negative_time']}" + ) + if summary["data_quality"]["invalid_status"] > 0: + print( + f" - Invalid status values: {summary['data_quality']['invalid_status']}" + ) + + # Covariates + print("\nCOVARIATES:") + if not covariate_cols: + print(" No covariates found") + else: + for col, stats in summary["covariates"].items(): + print(f" {col}:") + if stats["type"] == "numeric": + print(" Type: Numeric") + print(f" Range: {stats['min']:.2f} to {stats['max']:.2f}") + print(f" Mean: {stats['mean']:.2f}") + print(f" Missing: {stats['missing']}") + else: + print(" Type: Categorical") + print(f" Categories: {stats['n_categories']}") + print(f" Missing: {stats['missing']}") + + print("\n" + "=" * 60) + + +def compare_survival_datasets( + datasets: dict[str, pd.DataFrame], + time_col: str = "time", + status_col: str = "status", + covariate_cols: list[str] | None = None, +) -> pd.DataFrame: + """ + Compare multiple survival datasets and summarize their differences. + + Parameters + ---------- + datasets : dict[str, pd.DataFrame] + Dictionary mapping dataset names to DataFrames. + time_col : str, default="time" + Name of the time column in each dataset. + status_col : str, default="status" + Name of the status column in each dataset. + covariate_cols : List[str], optional + List of covariate columns to compare. If None, compares all common columns. 
+ + Returns + ------- + pd.DataFrame + Comparison table with datasets as columns and metrics as rows. + + Examples + -------- + >>> from gen_surv import generate + >>> from gen_surv.summary import compare_survival_datasets + >>> + >>> # Generate datasets with different parameters + >>> datasets = { + ... "CPHM": generate(model="cphm", n=100, model_cens="uniform", + ... cens_par=1.0, beta=0.5, covariate_range=2.0), + ... "Weibull AFT": generate(model="aft_weibull", n=100, beta=[0.5], + ... shape=1.5, scale=1.0, model_cens="uniform", cens_par=1.0) + ... } + >>> + >>> # Compare datasets + >>> comparison = compare_survival_datasets(datasets) + >>> print(comparison) + """ + if not datasets: + raise ParameterError("datasets", datasets, "at least one dataset is required") + + # Find common columns if covariate_cols not specified + if covariate_cols is None: + all_columns = [set(df.columns) for df in datasets.values()] + common_columns = set.intersection(*all_columns) + common_columns -= {time_col, status_col} # Remove time and status + covariate_cols = sorted(list(common_columns)) + + # Calculate summaries for each dataset + summaries = {} + for name, data in datasets.items(): + summaries[name] = summarize_survival_dataset( + data, time_col, status_col, covariate_cols=covariate_cols, verbose=False + ) + + # Construct the comparison DataFrame + comparison_data = {} + + # Dataset info + comparison_data["n_subjects"] = { + name: summary["dataset_info"]["n_subjects"] + for name, summary in summaries.items() + } + comparison_data["n_events"] = { + name: summary["event_info"]["n_events"] for name, summary in summaries.items() + } + comparison_data["event_rate"] = { + name: summary["event_info"]["event_rate"] for name, summary in summaries.items() + } + + # Time info + comparison_data["time_min"] = { + name: summary["time_info"]["min"] for name, summary in summaries.items() + } + comparison_data["time_max"] = { + name: summary["time_info"]["max"] for name, summary in summaries.items() + } + comparison_data["time_mean"] = { + name: summary["time_info"]["mean"] for name, summary in summaries.items() + } + comparison_data["time_median"] = { + name: summary["time_info"]["median"] for name, summary in summaries.items() + } + + # Covariate info (means for numeric) + for col in covariate_cols: + for name, summary in summaries.items(): + if col in summary["covariates"]: + col_stats = summary["covariates"][col] + if col_stats["type"] == "numeric": + if f"{col}_mean" not in comparison_data: + comparison_data[f"{col}_mean"] = {} + comparison_data[f"{col}_mean"][name] = col_stats["mean"] + + # Create the DataFrame + comparison_df = pd.DataFrame(comparison_data).T + + return comparison_df diff --git a/gen_surv/tdcm.py b/gen_surv/tdcm.py index b349137..ec54098 100644 --- a/gen_surv/tdcm.py +++ b/gen_surv/tdcm.py @@ -1,76 +1,131 @@ +from typing import Sequence + import numpy as np import pandas as pd -from gen_surv.validate import validate_gen_tdcm_inputs +from numpy.typing import NDArray + from gen_surv.bivariate import sample_bivariate_distribution -from gen_surv.censoring import runifcens, rexpocens +from gen_surv.censoring import CensoringFunc, rexpocens, runifcens +from gen_surv.validation import validate_gen_tdcm_inputs -def generate_censored_observations(n, dist_par, model_cens, cens_par, beta, lam, b): - """ - Generate censored TDCM observations. 
- - Parameters: - - n (int): Number of individuals - - dist_par (list): Not directly used here (kept for API compatibility) - - model_cens (str): "uniform" or "exponential" - - cens_par (float): Parameter for the censoring model - - beta (list): Length-2 list of regression coefficients - - lam (float): Rate parameter - - b (np.ndarray): Covariate matrix with 2 columns [., z1] - - Returns: - - np.ndarray: Shape (n, 6) with columns: - [id, start, stop, status, covariate1 (z1), covariate2 (z2)] + +def generate_censored_observations( + n: int, + dist_par: Sequence[float], + model_cens: str, + cens_par: float, + beta: Sequence[float], + lam: float, + b: NDArray[np.float64], +) -> NDArray[np.float64]: + """Generate censored TDCM observations. + + Parameters + ---------- + n : int + Number of individuals. + dist_par : Sequence[float] + Not directly used here (kept for API compatibility). + model_cens : {"uniform", "exponential"} + Censoring model. + cens_par : float + Parameter for the censoring model. + beta : Sequence[float] + Length-2 list of regression coefficients. + lam : float + Rate parameter. + b : NDArray[np.float64] + Covariate matrix with two columns ``[., z1]``. + + Returns + ------- + NDArray[np.float64] + Array of shape ``(n, 6)`` with columns + ``[id, start, stop, status, covariate1 (z1), covariate2 (z2)]``. """ - rfunc = runifcens if model_cens == "uniform" else rexpocens - observations = np.zeros((n, 6)) + rfunc: CensoringFunc = runifcens if model_cens == "uniform" else rexpocens - for k in range(n): - z1 = b[k, 1] - c = rfunc(1, cens_par)[0] - u = np.random.uniform() + z1 = b[:, 1] + x = lam * b[:, 0] * np.exp(beta[0] * z1) + u = np.random.uniform(size=n) + c = rfunc(n, cens_par) - # Determine path based on u threshold - threshold = 1 - np.exp(-lam * b[k, 0] * np.exp(beta[0] * z1)) - if u < threshold: - t = -np.log(1 - u) / (lam * np.exp(beta[0] * z1)) - z2 = 0 - else: - t = ( - -np.log(1 - u) - + lam * b[k, 0] * np.exp(beta[0] * z1) * (1 - np.exp(beta[1])) - ) / (lam * np.exp(beta[0] * z1 + beta[1])) - z2 = 1 + threshold = 1 - np.exp(-x) + exp_b0_z1 = np.exp(beta[0] * z1) + log_term = -np.log(1 - u) + t1 = log_term / (lam * exp_b0_z1) + t2 = (log_term + x * (1 - np.exp(beta[1]))) / (lam * np.exp(beta[0] * z1 + beta[1])) + mask = u < threshold + t = np.where(mask, t1, t2) + z2 = (~mask).astype(float) - time = min(t, c) - status = int(t <= c) + time = np.minimum(t, c) + status = (t <= c).astype(float) - observations[k] = [k + 1, 0, time, status, z1, z2] + ids = np.arange(1, n + 1, dtype=float) + zeros = np.zeros(n, dtype=float) + return np.column_stack((ids, zeros, time, status, z1, z2)) - return observations +def gen_tdcm( + n: int, + dist: str, + corr: float, + dist_par: Sequence[float], + model_cens: str, + cens_par: float, + beta: Sequence[float], + lam: float, +) -> pd.DataFrame: + """Generate TDCM (Time-Dependent Covariate Model) survival data. -def gen_tdcm(n, dist, corr, dist_par, model_cens, cens_par, beta, lam): - """ - Generate TDCM (Time-Dependent Covariate Model) survival data. - - Parameters: - - n (int): Number of individuals. - - dist (str): "weibull" or "exponential". - - corr (float): Correlation coefficient. - - dist_par (list): Distribution parameters. - - model_cens (str): "uniform" or "exponential". - - cens_par (float): Censoring parameter. - - beta (list): Length-2 regression coefficients. - - lam (float): Lambda rate parameter. 
-
-    Returns:
-    - pd.DataFrame: Columns are ["id", "start", "stop", "status", "covariate", "tdcov"]
+
+    Parameters
+    ----------
+    n : int
+        Number of individuals.
+    dist : {"weibull", "exponential"}
+        Type of marginal distributions.
+    corr : float
+        Correlation coefficient between covariates.
+    dist_par : Sequence[float]
+        Distribution parameters.
+    model_cens : {"uniform", "exponential"}
+        Censoring model.
+    cens_par : float
+        Censoring parameter.
+    beta : Sequence[float]
+        Regression coefficients. Validation expects length 3, although only
+        the first two coefficients are used by the generator.
+    lam : float
+        Lambda rate parameter.
+
+    Returns
+    -------
+    pd.DataFrame
+        Columns are ``["id", "start", "stop", "status", "covariate", "tdcov"]``.
+
+    Examples
+    --------
+    >>> from gen_surv.tdcm import gen_tdcm
+    >>> df = gen_tdcm(
+    ...     n=5,
+    ...     dist="exponential",
+    ...     corr=0.3,
+    ...     dist_par=[0.5, 1.0],
+    ...     model_cens="uniform",
+    ...     cens_par=2.0,
+    ...     beta=[0.1, 0.2, 0.3],
+    ...     lam=0.5,
+    ... )
     """
     validate_gen_tdcm_inputs(n, dist, corr, dist_par, model_cens, cens_par, beta, lam)
 
     # Generate covariate matrix from bivariate distribution
     b = sample_bivariate_distribution(n, dist, corr, dist_par)
-    data = generate_censored_observations(n, dist_par, model_cens, cens_par, beta, lam, b)
+    data = generate_censored_observations(
+        n, dist_par, model_cens, cens_par, beta, lam, b
+    )
 
-    return pd.DataFrame(data, columns=["id", "start", "stop", "status", "covariate", "tdcov"])
+    return pd.DataFrame(
+        data, columns=["id", "start", "stop", "status", "covariate", "tdcov"]
+    )
diff --git a/gen_surv/thmm.py b/gen_surv/thmm.py
index 22feebf..4b317f8 100644
--- a/gen_surv/thmm.py
+++ b/gen_surv/thmm.py
@@ -1,9 +1,26 @@
+from typing import Sequence, TypedDict
+
 import numpy as np
 import pandas as pd
-from gen_surv.validate import validate_gen_thmm_inputs
-from gen_surv.censoring import runifcens, rexpocens
 
-def calculate_transitions(z1: float, cens_par: float, beta: list, rate: list, rfunc) -> dict:
+from gen_surv.censoring import CensoringFunc, rexpocens, runifcens
+from gen_surv.validation import validate_gen_thmm_inputs
+
+
+class TransitionTimes(TypedDict):
+    c: float
+    t12: float
+    t13: float
+    t23: float
+
+
+def calculate_transitions(
+    z1: float,
+    cens_par: float,
+    beta: Sequence[float],
+    rate: Sequence[float],
+    rfunc: CensoringFunc,
+) -> TransitionTimes:
     """
     Calculate transition and censoring times for THMM.
@@ -29,27 +46,54 @@ def calculate_transitions(z1: float, cens_par: float, beta: list, rate: list, rf
     return {"c": c, "t12": t12, "t13": t13, "t23": t23}
 
-def gen_thmm(n, model_cens, cens_par, beta, covar, rate):
-    """
-    Generate THMM (Time-Homogeneous Markov Model) survival data.
+def gen_thmm(
+    n: int,
+    model_cens: str,
+    cens_par: float,
+    beta: Sequence[float],
+    covariate_range: float,
+    rate: Sequence[float],
+) -> pd.DataFrame:
+    """Generate THMM (Time-Homogeneous Markov Model) survival data.
 
-    Parameters:
-    - n (int): Number of individuals.
-    - model_cens (str): "uniform" or "exponential".
-    - cens_par (float): Censoring parameter.
-    - beta (list): Length-3 regression coefficients.
-    - covar (float): Covariate upper bound.
-    - rate (list): Length-3 transition rates.
+    Parameters
+    ----------
+    n : int
+        Number of individuals.
+    model_cens : {"uniform", "exponential"}
+        Censoring model.
+    cens_par : float
+        Censoring parameter.
+    beta : Sequence[float]
+        Length-3 regression coefficients.
+    covariate_range : float
+        Upper bound for the covariate values.
+    rate : Sequence[float]
+        Length-3 transition rates.
- Returns: - - pd.DataFrame: Columns = ["id", "time", "state", "covariate"] + Returns + ------- + pd.DataFrame + Columns = ``["id", "time", "state", "X0"]``. + + Examples + -------- + >>> from gen_surv.thmm import gen_thmm + >>> df = gen_thmm( + ... n=3, + ... model_cens="uniform", + ... cens_par=5.0, + ... beta=[0.1, 0.2, 0.3], + ... covariate_range=1.0, + ... rate=[0.1, 0.1, 0.2], + ... ) """ - validate_gen_thmm_inputs(n, model_cens, cens_par, beta, covar, rate) - rfunc = runifcens if model_cens == "uniform" else rexpocens + validate_gen_thmm_inputs(n, model_cens, cens_par, beta, covariate_range, rate) + rfunc: CensoringFunc = runifcens if model_cens == "uniform" else rexpocens records = [] for k in range(n): - z1 = np.random.uniform(0, covar) + z1 = np.random.uniform(0, covariate_range) trans = calculate_transitions(z1, cens_par, beta, rate, rfunc) t12, t13, c = trans["t12"], trans["t13"], trans["c"] @@ -63,4 +107,4 @@ def gen_thmm(n, model_cens, cens_par, beta, covar, rate): records.append([k + 1, time, state, z1]) - return pd.DataFrame(records, columns=["id", "time", "state", "covariate"]) + return pd.DataFrame(records, columns=["id", "time", "state", "X0"]) diff --git a/gen_surv/validate.py b/gen_surv/validate.py index 427f5c3..f15f291 100644 --- a/gen_surv/validate.py +++ b/gen_surv/validate.py @@ -1,193 +1,6 @@ -def validate_gen_cphm_inputs(n: int, model_cens: str, cens_par: float, covar: float): - """ - Validates input parameters for CPHM data generation. +"""Compatibility wrapper for validation utilities. - Parameters: - - n (int): Number of data points to generate. - - model_cens (str): Censoring model, must be "uniform" or "exponential". - - cens_par (float): Parameter for the censoring model, must be > 0. - - covar (float): Covariate value, must be > 0. +This module re-exports symbols from :mod:`gen_surv.validation`. +""" - Raises: - - ValueError: If any input is invalid. - """ - if n <= 0: - raise ValueError("Argument 'n' must be greater than 0") - if model_cens not in {"uniform", "exponential"}: - raise ValueError( - "Argument 'model_cens' must be one of 'uniform' or 'exponential'") - if cens_par <= 0: - raise ValueError("Argument 'cens_par' must be greater than 0") - if covar <= 0: - raise ValueError("Argument 'covar' must be greater than 0") - - -def validate_gen_cmm_inputs(n: int, model_cens: str, cens_par: float, beta: list, covar: float, rate: list): - """ - Validate inputs for generating CMM (Continuous-Time Markov Model) data. - - Parameters: - - n (int): Number of individuals. - - model_cens (str): Censoring model, must be "uniform" or "exponential". - - cens_par (float): Parameter for censoring distribution, must be > 0. - - beta (list): Regression coefficients, must have length 3. - - covar (float): Covariate value, must be > 0. - - rate (list): Transition rates, must have length 6. - - Raises: - - ValueError: If any parameter is invalid. 
- """ - if n <= 0: - raise ValueError("Argument 'n' must be greater than 0") - if model_cens not in {"uniform", "exponential"}: - raise ValueError( - "Argument 'model_cens' must be one of 'uniform' or 'exponential'") - if cens_par <= 0: - raise ValueError("Argument 'cens_par' must be greater than 0") - if len(beta) != 3: - raise ValueError("Argument 'beta' must be a list of length 3") - if covar <= 0: - raise ValueError("Argument 'covar' must be greater than 0") - if len(rate) != 6: - raise ValueError("Argument 'rate' must be a list of length 6") - - -def validate_gen_tdcm_inputs(n: int, dist: str, corr: float, dist_par: list, - model_cens: str, cens_par: float, beta: list, lam: float): - """ - Validate inputs for generating TDCM (Time-Dependent Covariate Model) data. - - Parameters: - - n (int): Number of observations. - - dist (str): "weibull" or "exponential". - - corr (float): Correlation coefficient. - - dist_par (list): Distribution parameters. - - model_cens (str): "uniform" or "exponential". - - cens_par (float): Censoring parameter. - - beta (list): Length-2 list of regression coefficients. - - lam (float): Lambda parameter, must be > 0. - - Raises: - - ValueError: For any invalid input. - """ - if n <= 0: - raise ValueError("Argument 'n' must be greater than 0") - - if dist not in {"weibull", "exponential"}: - raise ValueError( - "Argument 'dist' must be one of 'weibull' or 'exponential'") - - if dist == "weibull": - if not (0 < corr <= 1): - raise ValueError("With dist='weibull', 'corr' must be in (0,1]") - if len(dist_par) != 4 or any(p <= 0 for p in dist_par): - raise ValueError( - "With dist='weibull', 'dist_par' must be a positive list of length 4") - - if dist == "exponential": - if not (-1 <= corr <= 1): - raise ValueError( - "With dist='exponential', 'corr' must be in [-1,1]") - if len(dist_par) != 2 or any(p <= 0 for p in dist_par): - raise ValueError( - "With dist='exponential', 'dist_par' must be a positive list of length 2") - - if model_cens not in {"uniform", "exponential"}: - raise ValueError( - "Argument 'model_cens' must be one of 'uniform' or 'exponential'") - - if cens_par <= 0: - raise ValueError("Argument 'cens_par' must be greater than 0") - - if not isinstance(beta, list) or len(beta) != 3: - raise ValueError("Argument 'beta' must be a list of length 3") - - if lam <= 0: - raise ValueError("Argument 'lambda' must be greater than 0") - - -def validate_gen_thmm_inputs(n: int, model_cens: str, cens_par: float, beta: list, covar: float, rate: list): - """ - Validate inputs for generating THMM (Time-Homogeneous Markov Model) data. - - Parameters: - - n (int): Number of samples, must be > 0. - - model_cens (str): Must be "uniform" or "exponential". - - cens_par (float): Must be > 0. - - beta (list): List of length 3 (regression coefficients). - - covar (float): Positive covariate value. - - rate (list): List of length 3 (transition rates). - - Raises: - - ValueError if any input is invalid. 
- """ - if not isinstance(n, int) or n <= 0: - raise ValueError("Argument 'n' must be a positive integer.") - - if model_cens not in {"uniform", "exponential"}: - raise ValueError( - "Argument 'model_cens' must be one of 'uniform' or 'exponential'") - - if not isinstance(cens_par, (int, float)) or cens_par <= 0: - raise ValueError("Argument 'cens_par' must be a positive number.") - - if not isinstance(beta, list) or len(beta) != 3: - raise ValueError("Argument 'beta' must be a list of length 3.") - - if not isinstance(covar, (int, float)) or covar <= 0: - raise ValueError("Argument 'covar' must be greater than 0.") - - if not isinstance(rate, list) or len(rate) != 3: - raise ValueError("Argument 'rate' must be a list of length 3.") - - -def validate_dg_biv_inputs(n: int, dist: str, corr: float, dist_par: list): - """ - Validate inputs for the sample_bivariate_distribution function. - - Parameters: - - n (int): Number of samples to generate. - - dist (str): Must be "weibull" or "exponential". - - corr (float): Must be between -1 and 1. - - dist_par (list): Must contain positive values, and correct length for the distribution. - - Raises: - - ValueError if any input is invalid. - """ - if not isinstance(n, int) or n <= 0: - raise ValueError("Argument 'n' must be a positive integer.") - - if dist not in {"weibull", "exponential"}: - raise ValueError("Argument 'dist' must be one of 'weibull' or 'exponential'.") - - if not isinstance(corr, (int, float)) or not (-1 < corr < 1): - raise ValueError("Argument 'corr' must be a numeric value between -1 and 1.") - - if not isinstance(dist_par, list) or len(dist_par) == 0: - raise ValueError("Argument 'dist_par' must be a non-empty list of positive values.") - - if any(p <= 0 for p in dist_par): - raise ValueError("All elements in 'dist_par' must be greater than 0.") - - if dist == "exponential" and len(dist_par) != 2: - raise ValueError("Exponential distribution requires exactly 2 positive parameters.") - - if dist == "weibull" and len(dist_par) != 4: - raise ValueError("Weibull distribution requires exactly 4 positive parameters.") - - -def validate_gen_aft_log_normal_inputs(n, beta, sigma, model_cens, cens_par): - if not isinstance(n, int) or n <= 0: - raise ValueError("n must be a positive integer") - - if not isinstance(beta, (list, tuple)) or not all(isinstance(b, (int, float)) for b in beta): - raise ValueError("beta must be a list of numbers") - - if not isinstance(sigma, (int, float)) or sigma <= 0: - raise ValueError("sigma must be a positive number") - - if model_cens not in ("uniform", "exponential"): - raise ValueError("model_cens must be 'uniform' or 'exponential'") - - if not isinstance(cens_par, (int, float)) or cens_par <= 0: - raise ValueError("cens_par must be a positive number") +from .validation import * # noqa: F401,F403 diff --git a/gen_surv/validation.py b/gen_surv/validation.py new file mode 100644 index 0000000..11e07ef --- /dev/null +++ b/gen_surv/validation.py @@ -0,0 +1,513 @@ +"""Input validation utilities. + +This module unifies the low-level validation helpers and the higher-level +checks used by the data generators. 
+""" + +from __future__ import annotations + +from collections.abc import Sequence +from numbers import Integral, Real +from typing import Any, Iterable + +import numpy as np +from numpy.typing import NDArray + + +class ValidationError(ValueError): + """Base class for input validation errors.""" + + +class PositiveIntegerError(ValidationError): + """Raised when a value expected to be a positive integer is invalid.""" + + def __init__(self, name: str, value: Any) -> None: + super().__init__( + f"Argument '{name}' must be a positive integer; got {value!r} of type {type(value).__name__}. " + "Please provide a whole number greater than 0." + ) + + +class PositiveValueError(ValidationError): + """Raised when a value expected to be positive is invalid.""" + + def __init__(self, name: str, value: Any) -> None: + super().__init__( + f"Argument '{name}' must be greater than 0; got {value!r} of type {type(value).__name__}. " + "Try a positive number such as 1.0." + ) + + +class ChoiceError(ValidationError): + """Raised when a value is not among an allowed set of choices.""" + + def __init__(self, name: str, value: Any, choices: Iterable[str]) -> None: + choices_str = "', '".join(sorted(choices)) + super().__init__( + f"Argument '{name}' must be one of '{choices_str}'; got {value!r} of type {type(value).__name__}. " + "Choose a valid option." + ) + + +class LengthError(ValidationError): + """Raised when a sequence does not have the expected length.""" + + def __init__(self, name: str, actual: int, expected: int) -> None: + super().__init__( + f"Argument '{name}' must be a sequence of length {expected}; got length {actual}. " + "Adjust the number of elements." + ) + + +class NumericSequenceError(ValidationError): + """Raised when a sequence contains non-numeric elements.""" + + def __init__(self, name: str, value: Any, index: int | None = None) -> None: + if index is None: + super().__init__( + f"All elements in '{name}' must be numeric; got {value!r}. " + "Convert or remove non-numeric values." + ) + else: + super().__init__( + f"All elements in '{name}' must be numeric; found {value!r} of type {type(value).__name__} at index {index}. " + "Replace or remove this entry." + ) + + +class PositiveSequenceError(ValidationError): + """Raised when a sequence contains non-positive elements.""" + + def __init__(self, name: str, value: Any, index: int) -> None: + super().__init__( + f"All elements in '{name}' must be greater than 0; found {value!r} at index {index}. " + "Use positive numbers only." + ) + + +class ListOfListsError(ValidationError): + """Raised when a value is not a list of lists.""" + + def __init__(self, name: str, value: Any) -> None: + super().__init__( + f"Argument '{name}' must be a list of lists; got {value!r} of type {type(value).__name__}. " + "Wrap items in a list." + ) + + +class ParameterError(ValidationError): + """Raised when a parameter falls outside its allowed range.""" + + def __init__(self, name: str, value: Any, constraint: str) -> None: + super().__init__( + f"Invalid value for '{name}': {value!r} (type {type(value).__name__}). {constraint}. " + "Check and adjust this parameter." 
+ ) + + +_ALLOWED_CENSORING = {"uniform", "exponential"} + + +def ensure_positive_int(value: int, name: str) -> None: + """Ensure ``value`` is a positive integer.""" + if not isinstance(value, Integral) or isinstance(value, bool) or value <= 0: + raise PositiveIntegerError(name, value) + + +def ensure_positive(value: float | int, name: str) -> None: + """Ensure ``value`` is a positive number.""" + if not isinstance(value, Real) or isinstance(value, bool) or value <= 0: + raise PositiveValueError(name, value) + + +def ensure_probability(value: float | int, name: str) -> None: + """Ensure ``value`` lies in the closed interval [0, 1].""" + if ( + not isinstance(value, Real) + or isinstance(value, bool) + or not (0 <= float(value) <= 1) + ): + raise ParameterError(name, value, "must be between 0 and 1") + + +def ensure_in_choices(value: str, name: str, choices: Iterable[str]) -> None: + """Ensure ``value`` is one of the allowed options. + + Parameters + ---------- + value: + Value provided by the user. + name: + Name of the argument being validated. Used in error messages. + choices: + Iterable of valid string options. + + Raises + ------ + ChoiceError + If ``value`` is not present in ``choices``. + """ + if value not in choices: + raise ChoiceError(name, value, choices) + + +def ensure_sequence_length(seq: Sequence[Any], length: int, name: str) -> None: + """Ensure a sequence has an expected number of elements. + + Parameters + ---------- + seq: + Sequence-like object (e.g., ``list`` or ``tuple``). + length: + Required number of elements in ``seq``. + name: + Parameter name for error reporting. + + Raises + ------ + LengthError + If ``seq`` does not contain exactly ``length`` elements. + """ + if len(seq) != length: + raise LengthError(name, len(seq), length) + + +def _to_float_array(seq: Sequence[Any], name: str) -> NDArray[np.float64]: + """Convert ``seq`` to a NumPy float64 array or raise an error.""" + try: + arr = np.asarray(seq, dtype=float) + except (TypeError, ValueError) as exc: + for idx, val in enumerate(seq): + if isinstance(val, (bool, np.bool_)) or not isinstance(val, (int, float)): + raise NumericSequenceError(name, val, idx) from exc + raise NumericSequenceError(name, seq) from exc + + for idx, val in enumerate(seq): + if isinstance(val, (bool, np.bool_)): + raise NumericSequenceError(name, val, idx) + + return arr + + +def ensure_numeric_sequence(seq: Sequence[Any], name: str) -> None: + """Validate that a sequence consists solely of numbers. + + Parameters + ---------- + seq: + Sequence whose elements should all be ``int`` or ``float``. + name: + Parameter name for error reporting. + + Raises + ------ + NumericSequenceError + If any element cannot be interpreted as a numeric value. + """ + _to_float_array(seq, name) + + +def ensure_positive_sequence(seq: Sequence[float], name: str) -> None: + """Validate that a sequence contains only positive numbers. + + Parameters + ---------- + seq: + Sequence of numeric values. + name: + Parameter name for error reporting. + + Raises + ------ + PositiveSequenceError + If any element is less than or equal to zero. The offending value and + its index are reported in the error message. + """ + arr = _to_float_array(seq, name) + nonpos = np.where((arr <= 0) | ~np.isfinite(arr))[0] + if nonpos.size: + idx = int(nonpos[0]) + raise PositiveSequenceError(name, seq[idx], idx) + + +def ensure_censoring_model(model_cens: str) -> None: + """Validate that the censoring model is supported. 
+ + Parameters + ---------- + model_cens: + Censoring model name provided by the user. + + Raises + ------ + ChoiceError + If ``model_cens`` is not one of ``"uniform"`` or ``"exponential"``. + """ + ensure_in_choices(model_cens, "model_cens", _ALLOWED_CENSORING) + + +# Generator-specific validation helpers + +_BETA_LEN = 3 +_CMM_RATE_LEN = 6 +_THMM_RATE_LEN = 3 +_WEIBULL_DIST_PAR_LEN = 4 +_EXP_DIST_PAR_LEN = 2 + + +def _validate_base(n: int, model_cens: str, cens_par: float) -> None: + """Common checks for sample size and censoring model.""" + ensure_positive_int(n, "n") + ensure_censoring_model(model_cens) + ensure_positive(cens_par, "cens_par") + + +def _validate_beta(beta: Sequence[float]) -> None: + """Ensure beta is a numeric sequence of length three.""" + ensure_sequence_length(beta, _BETA_LEN, "beta") + ensure_numeric_sequence(beta, "beta") + + +def _validate_aft_common( + n: int, beta: Sequence[float], model_cens: str, cens_par: float +) -> None: + """Shared validation logic for AFT generators.""" + _validate_base(n, model_cens, cens_par) + ensure_numeric_sequence(beta, "beta") + + +def _validate_covariate_inputs( + n: int, + n_covariates: int | None, + model_cens: str, + cens_par: float, + covariate_dist: str, + max_time: float | None = None, +) -> None: + """Common checks for generators with covariates. + + Parameters + ---------- + n: + Number of samples to generate. + n_covariates: + Expected number of covariates or ``None`` to skip the check. + model_cens: + Name of the censoring model. + cens_par: + Parameter for the censoring model. + covariate_dist: + Name of the covariate distribution. + max_time: + Optional maximum follow-up time. If provided, must be positive. + """ + _validate_base(n, model_cens, cens_par) + if n_covariates is not None: + ensure_positive_int(n_covariates, "n_covariates") + if max_time is not None: + ensure_positive(max_time, "max_time") + ensure_in_choices(covariate_dist, "covariate_dist", {"normal", "uniform", "binary"}) + + +def validate_gen_cphm_inputs( + n: int, model_cens: str, cens_par: float, covariate_range: float +) -> None: + """Validate input parameters for CPHM data generation.""" + _validate_base(n, model_cens, cens_par) + ensure_positive(covariate_range, "covariate_range") + + +def validate_gen_cmm_inputs( + n: int, + model_cens: str, + cens_par: float, + beta: Sequence[float], + covariate_range: float, + rate: Sequence[float], +) -> None: + """Validate inputs for generating CMM (Continuous-Time Markov Model) data.""" + _validate_base(n, model_cens, cens_par) + _validate_beta(beta) + ensure_positive(covariate_range, "covariate_range") + ensure_sequence_length(rate, _CMM_RATE_LEN, "rate") + + +def validate_gen_tdcm_inputs( + n: int, + dist: str, + corr: float, + dist_par: Sequence[float], + model_cens: str, + cens_par: float, + beta: Sequence[float], + lam: float, +) -> None: + """Validate inputs for generating TDCM (Time-Dependent Covariate Model) data.""" + _validate_base(n, model_cens, cens_par) + ensure_in_choices(dist, "dist", {"weibull", "exponential"}) + + if dist == "weibull": + if not (0 < corr <= 1): + raise ParameterError("corr", corr, "with dist='weibull' must be in (0,1]") + ensure_sequence_length(dist_par, _WEIBULL_DIST_PAR_LEN, "dist_par") + ensure_positive_sequence(dist_par, "dist_par") + + if dist == "exponential": + if not (-1 <= corr <= 1): + raise ParameterError( + "corr", corr, "with dist='exponential' must be in [-1,1]" + ) + ensure_sequence_length(dist_par, _EXP_DIST_PAR_LEN, "dist_par") + 
ensure_positive_sequence(dist_par, "dist_par") + + _validate_beta(beta) + ensure_positive(lam, "lambda") + + +def validate_gen_thmm_inputs( + n: int, + model_cens: str, + cens_par: float, + beta: Sequence[float], + covariate_range: float, + rate: Sequence[float], +) -> None: + """Validate inputs for generating THMM (Time-Homogeneous Markov Model) data.""" + _validate_base(n, model_cens, cens_par) + _validate_beta(beta) + ensure_positive(covariate_range, "covariate_range") + ensure_sequence_length(rate, _THMM_RATE_LEN, "rate") + + +def validate_dg_biv_inputs( + n: int, dist: str, corr: float, dist_par: Sequence[float] +) -> None: + """Validate inputs for the :func:`sample_bivariate_distribution` helper.""" + ensure_positive_int(n, "n") + ensure_in_choices(dist, "dist", {"weibull", "exponential"}) + + if not isinstance(corr, (int, float)) or not (-1 < corr < 1): + raise ParameterError("corr", corr, "must be a numeric value between -1 and 1") + + ensure_positive_sequence(dist_par, "dist_par") + if dist == "exponential": + ensure_sequence_length(dist_par, _EXP_DIST_PAR_LEN, "dist_par") + if dist == "weibull": + ensure_sequence_length(dist_par, _WEIBULL_DIST_PAR_LEN, "dist_par") + + +def validate_gen_aft_log_normal_inputs( + n: int, + beta: Sequence[float], + sigma: float, + model_cens: str, + cens_par: float, +) -> None: + """Validate parameters for the log-normal AFT generator.""" + _validate_aft_common(n, beta, model_cens, cens_par) + ensure_positive(sigma, "sigma") + + +def validate_gen_aft_weibull_inputs( + n: int, + beta: Sequence[float], + shape: float, + scale: float, + model_cens: str, + cens_par: float, +) -> None: + """Validate parameters for the Weibull AFT generator.""" + _validate_aft_common(n, beta, model_cens, cens_par) + ensure_positive(shape, "shape") + ensure_positive(scale, "scale") + + +def validate_gen_aft_log_logistic_inputs( + n: int, + beta: Sequence[float], + shape: float, + scale: float, + model_cens: str, + cens_par: float, +) -> None: + """Validate parameters for the log-logistic AFT generator.""" + _validate_aft_common(n, beta, model_cens, cens_par) + ensure_positive(shape, "shape") + ensure_positive(scale, "scale") + + +def validate_competing_risks_inputs( + n: int, + n_risks: int, + baseline_hazards: Sequence[float] | None, + betas: Sequence[Sequence[float]] | None, + covariate_dist: str, + max_time: float | None, + model_cens: str, + cens_par: float, +) -> None: + """Validate parameters for competing risks data generation.""" + _validate_covariate_inputs(n, None, model_cens, cens_par, covariate_dist, max_time) + ensure_positive_int(n_risks, "n_risks") + + if baseline_hazards is not None: + ensure_sequence_length(baseline_hazards, n_risks, "baseline_hazards") + ensure_positive_sequence(baseline_hazards, "baseline_hazards") + + if betas is not None: + if not isinstance(betas, list) or any(not isinstance(b, list) for b in betas): + raise ListOfListsError("betas", betas) + for b in betas: + ensure_numeric_sequence(b, "betas") + + +def validate_piecewise_params( + breakpoints: Sequence[float], hazard_rates: Sequence[float] +) -> None: + """Validate breakpoint and hazard rate sequences.""" + ensure_sequence_length(hazard_rates, len(breakpoints) + 1, "hazard_rates") + ensure_positive_sequence(breakpoints, "breakpoints") + ensure_positive_sequence(hazard_rates, "hazard_rates") + if np.any(np.diff(breakpoints) <= 0): + raise ParameterError( + "breakpoints", + breakpoints, + "must be a strictly increasing sequence. 
Sort the list and remove duplicates.", + ) + + +def validate_gen_piecewise_inputs( + n: int, + breakpoints: Sequence[float], + hazard_rates: Sequence[float], + n_covariates: int, + model_cens: str, + cens_par: float, + covariate_dist: str, +) -> None: + """Validate parameters for :func:`gen_piecewise_exponential`.""" + _validate_covariate_inputs(n, n_covariates, model_cens, cens_par, covariate_dist) + validate_piecewise_params(breakpoints, hazard_rates) + + +def validate_gen_mixture_inputs( + n: int, + cure_fraction: float, + baseline_hazard: float, + n_covariates: int, + model_cens: str, + cens_par: float, + max_time: float | None, + covariate_dist: str, +) -> None: + """Validate parameters for :func:`gen_mixture_cure`.""" + _validate_covariate_inputs( + n, n_covariates, model_cens, cens_par, covariate_dist, max_time + ) + ensure_positive(baseline_hazard, "baseline_hazard") + if not 0 < cure_fraction < 1: + raise ParameterError( + "cure_fraction", + cure_fraction, + "must be between 0 and 1 (exclusive). Try a value like 0.5", + ) diff --git a/gen_surv/visualization.py b/gen_surv/visualization.py new file mode 100644 index 0000000..3a80c11 --- /dev/null +++ b/gen_surv/visualization.py @@ -0,0 +1,375 @@ +""" +Visualization utilities for survival data. + +This module provides functions to visualize survival data generated by +gen_surv, +including Kaplan-Meier survival curves and other commonly used plots in +survival analysis. +""" + +import matplotlib.pyplot as plt +import pandas as pd +from matplotlib.axes import Axes +from matplotlib.figure import Figure + + +def plot_survival_curve( + data: pd.DataFrame, + time_col: str = "time", + status_col: str = "status", + group_col: str | None = None, + confidence_intervals: bool = True, + title: str = "Kaplan-Meier Survival Curve", + figsize: tuple[float, float] = (10, 6), + ci_alpha: float = 0.2, +) -> tuple[Figure, Axes]: + """ + Plot Kaplan-Meier survival curves from simulated data. + + Parameters + ---------- + data : pd.DataFrame + DataFrame containing the survival data. + time_col : str, default="time" + Name of the column containing event/censoring times. + status_col : str, default="status" + Name of the column containing event indicators (1=event, 0=censored). + group_col : str, optional + Name of the column to use for stratification (creates separate curves). + confidence_intervals : bool, default=True + Whether to display confidence intervals around the survival curves. + title : str, default="Kaplan-Meier Survival Curve" + Plot title. + figsize : tuple, default=(10, 6) + Figure size (width, height) in inches. + ci_alpha : float, default=0.2 + Transparency level for confidence interval bands. + + Returns + ------- + fig : Figure + Matplotlib figure object. + ax : Axes + Matplotlib axes object. 
+
+    Examples
+    --------
+    >>> import matplotlib.pyplot as plt
+    >>> import pandas as pd
+    >>> from gen_surv import generate
+    >>> from gen_surv.visualization import plot_survival_curve
+    >>>
+    >>> # Generate data
+    >>> df = generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0)
+    >>>
+    >>> # Create a categorical group based on covariate
+    >>> df["group"] = pd.cut(df["covariate"], bins=2, labels=["Low", "High"])
+    >>>
+    >>> # Plot survival curves by group
+    >>> fig, ax = plot_survival_curve(df, group_col="group")
+    >>> plt.show()
+    """
+    # Import lifelines here to avoid making it a hard dependency
+    try:
+        from lifelines import KaplanMeierFitter
+        from lifelines.plotting import add_at_risk_counts
+    except ImportError as exc:
+        raise ImportError(
+            "This function requires the lifelines package. "
+            "Install it with: pip install lifelines"
+        ) from exc
+
+    fig, ax = plt.subplots(figsize=figsize)
+
+    # Create separate KM curves for each group (if specified)
+    if group_col is not None:
+        groups = data[group_col].unique()
+        cmap = plt.get_cmap("tab10")
+        colors = [cmap(i) for i in range(len(groups))]
+
+        fitters = []
+        for i, group in enumerate(groups):
+            mask = data[group_col] == group
+            group_data = data[mask]
+
+            kmf = KaplanMeierFitter()
+            kmf.fit(
+                group_data[time_col],
+                group_data[status_col],
+                label=f"{group_col}={group}",
+            )
+
+            kmf.plot_survival_function(
+                ax=ax, ci_show=confidence_intervals, color=colors[i], ci_alpha=ci_alpha
+            )
+            fitters.append(kmf)
+
+        # Add at-risk counts for all groups below the plot in a single call;
+        # calling add_at_risk_counts once per group would re-layout the
+        # figure on every iteration
+        add_at_risk_counts(*fitters, ax=ax)
+    else:
+        # Single KM curve for all data
+        kmf = KaplanMeierFitter()
+        kmf.fit(data[time_col], data[status_col])
+
+        kmf.plot_survival_function(
+            ax=ax, ci_show=confidence_intervals, ci_alpha=ci_alpha
+        )
+
+        # Add at-risk counts below the plot
+        add_at_risk_counts(kmf, ax=ax)
+
+    # Customize plot appearance
+    ax.set_title(title)
+    ax.set_xlabel("Time")
+    ax.set_ylabel("Survival Probability")
+    ax.grid(alpha=0.3)
+    ax.set_ylim(0, 1.05)
+
+    plt.tight_layout()
+    return fig, ax
+
+
+def plot_hazard_comparison(
+    models: dict[str, pd.DataFrame],
+    time_col: str = "time",
+    status_col: str = "status",
+    title: str = "Hazard Function Comparison",
+    figsize: tuple[float, float] = (10, 6),
+    bandwidth: float = 0.5,
+) -> tuple[Figure, Axes]:
+    """
+    Compare hazard functions from multiple generated datasets.
+
+    Parameters
+    ----------
+    models : dict
+        Dictionary mapping model names to their respective DataFrames.
+    time_col : str, default="time"
+        Name of the column containing event/censoring times.
+    status_col : str, default="status"
+        Name of the column containing event indicators (1=event, 0=censored).
+    title : str, default="Hazard Function Comparison"
+        Plot title.
+    figsize : tuple, default=(10, 6)
+        Figure size (width, height) in inches.
+    bandwidth : float, default=0.5
+        Bandwidth parameter for kernel density estimation of the hazard function.
+
+    Returns
+    -------
+    fig : Figure
+        Matplotlib figure object.
+    ax : Axes
+        Matplotlib axes object.
+
+    Examples
+    --------
+    >>> import matplotlib.pyplot as plt
+    >>> from gen_surv import generate
+    >>> from gen_surv.visualization import plot_hazard_comparison
+    >>>
+    >>> # Generate data from multiple models
+    >>> models = {
+    ...     "CPHM": generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0),
+    ...     "AFT Weibull": generate(model="aft_weibull", n=100, beta=[0.5], shape=1.5, scale=2.0,
+    ...                             model_cens="uniform", cens_par=1.0)
+    ... }
+    >>>
+    >>> # Compare hazard functions
+    >>> fig, ax = plot_hazard_comparison(models)
+    >>> plt.show()
+    """
+    # Import lifelines here to avoid making it a hard dependency
+    try:
+        from lifelines import NelsonAalenFitter
+    except ImportError as exc:
+        raise ImportError(
+            "This function requires the lifelines package. "
+            "Install it with: pip install lifelines"
+        ) from exc
+
+    fig, ax = plt.subplots(figsize=figsize)
+
+    for model_name, df in models.items():
+        naf = NelsonAalenFitter()
+        naf.fit(df[time_col], df[status_col])
+
+        # Get smoothed hazard estimate
+        hazard = naf.smoothed_hazard_(bandwidth=bandwidth)
+
+        # Plot hazard function
+        ax.plot(hazard.index, hazard.values, label=model_name, alpha=0.8)
+
+    # Customize plot appearance
+    ax.set_title(title)
+    ax.set_xlabel("Time")
+    ax.set_ylabel("Hazard Rate")
+    ax.grid(alpha=0.3)
+    ax.legend()
+
+    plt.tight_layout()
+    return fig, ax
+
+
+def plot_covariate_effect(
+    data: pd.DataFrame,
+    covariate_col: str,
+    time_col: str = "time",
+    status_col: str = "status",
+    n_groups: int = 3,
+    title: str = "Effect of Covariate on Survival",
+    figsize: tuple[float, float] = (10, 6),
+    ci_alpha: float = 0.2,
+) -> tuple[Figure, Axes]:
+    """
+    Visualize the effect of a continuous covariate on survival by
+    discretizing it into groups.
+
+    Parameters
+    ----------
+    data : pd.DataFrame
+        DataFrame containing the survival data.
+    covariate_col : str
+        Name of the covariate column to visualize.
+    time_col : str, default="time"
+        Name of the column containing event/censoring times.
+    status_col : str, default="status"
+        Name of the column containing event indicators (1=event, 0=censored).
+    n_groups : int, default=3
+        Number of groups to divide the covariate into (e.g., 3 for tertiles).
+    title : str, default="Effect of Covariate on Survival"
+        Plot title.
+    figsize : tuple, default=(10, 6)
+        Figure size (width, height) in inches.
+    ci_alpha : float, default=0.2
+        Transparency level for confidence interval bands.
+
+    Returns
+    -------
+    fig : Figure
+        Matplotlib figure object.
+    ax : Axes
+        Matplotlib axes object.
+ + Examples + -------- + >>> from gen_surv import generate + >>> from gen_surv.visualization import plot_covariate_effect + >>> + >>> # Generate data with a continuous covariate + >>> df = generate(model="cphm", n=200, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0) + >>> + >>> # Visualize the effect of the covariate on survival + >>> fig, ax = plot_covariate_effect(df, covariate_col="covariate", n_groups=3) + >>> plt.show() + """ + # Add a categorical version of the covariate + group_labels = [f"Q{i + 1}" for i in range(n_groups)] + data = data.copy() + data["_group"] = pd.qcut(data[covariate_col], q=n_groups, labels=group_labels) + + # Get the median value of each group for the legend + group_medians = data.groupby("_group")[covariate_col].median() + + # Create more informative labels + label_map = { + group: f"{group} ({covariate_col}β‰ˆ{median:.2f})" + for group, median in group_medians.items() + } + + data["_label"] = data["_group"].map(label_map) + + # Create the plot + fig, ax = plot_survival_curve( + data=data, + time_col=time_col, + status_col=status_col, + group_col="_label", + confidence_intervals=True, + title=title, + figsize=figsize, + ci_alpha=ci_alpha, + ) + + return fig, ax + + +def describe_survival( + data: pd.DataFrame, time_col: str = "time", status_col: str = "status" +) -> pd.DataFrame: + """ + Generate a summary of survival data including median survival time, + event counts, and other descriptive statistics. + + Parameters + ---------- + data : pd.DataFrame + DataFrame containing the survival data. + time_col : str, default="time" + Name of the column containing event/censoring times. + status_col : str, default="status" + Name of the column containing event indicators (1=event, 0=censored). + + Returns + ------- + pd.DataFrame + Summary statistics dataframe. + + Examples + -------- + >>> from gen_surv import generate + >>> from gen_surv.visualization import describe_survival + >>> + >>> # Generate data + >>> df = generate(model="cphm", n=200, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0) + >>> + >>> # Get survival summary + >>> summary = describe_survival(df) + >>> print(summary) + """ + # Import lifelines here to avoid making it a hard dependency + try: + from lifelines import KaplanMeierFitter + except ImportError as exc: + raise ImportError( + "This function requires the lifelines package. 
" + "Install it with: pip install lifelines" + ) from exc + + n_total = len(data) + n_events = data[status_col].sum() + n_censored = n_total - n_events + event_rate = n_events / n_total + + # Calculate median and other percentiles + kmf = KaplanMeierFitter() + kmf.fit(data[time_col], data[status_col]) + median = kmf.median_survival_time_ + + # Time ranges + time_min = data[time_col].min() + time_max = data[time_col].max() + time_mean = data[time_col].mean() + + # Create summary DataFrame + summary = pd.DataFrame( + { + "Metric": [ + "Total Observations", + "Number of Events", + "Number Censored", + "Event Rate", + "Median Survival Time", + "Min Time", + "Max Time", + "Mean Time", + ], + "Value": [ + n_total, + n_events, + n_censored, + f"{event_rate:.2%}", + f"{median:.4f}", + f"{time_min:.4f}", + f"{time_max:.4f}", + f"{time_mean:.4f}", + ], + } + ) + + return summary diff --git a/pyproject.toml b/pyproject.toml index 5048a06..1421ff0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "gen_surv" -version = "1.0.8" +version = "1.0.9" description = "A Python package for simulating survival data, inspired by the R package genSurv" authors = ["Diogo Ribeiro "] license = "MIT" @@ -8,27 +8,72 @@ readme = "README.md" packages = [{ include = "gen_surv" }] homepage = "https://github.com/DiogoRibeiro7/genSurvPy" repository = "https://github.com/DiogoRibeiro7/genSurvPy" -documentation = "https://gensurvpy.readthedocs.io/en/stable/" +documentation = "https://gensurvpy.readthedocs.io/en/latest/" +keywords = ["survival-analysis", "simulation", "cox-model", "markov-model", "time-dependent", "statistics"] +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Science/Research", + "Intended Audience :: Healthcare Industry", + "Topic :: Scientific/Engineering :: Medical Science Apps.", + "Topic :: Scientific/Engineering :: Mathematics", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "License :: OSI Approved :: MIT License", +] [tool.poetry.dependencies] -python = "^3.9" +python = ">=3.10,<3.13" numpy = "^1.26" pandas = "^2.2.3" -pytest-cov = "^6.1.1" -invoke = "^2.2.0" typer = "^0.12.3" -tomli = "^2.2.1" +matplotlib = "~3.8" +lifelines = "^0.30" +pyarrow = "^14" +pyreadr = "^0.5" [tool.poetry.group.dev.dependencies] pytest = "^8.3.5" +pytest-cov = "^6.1.1" python-semantic-release = "^9.21.0" mypy = "^1.15.0" invoke = "^2.2.0" hypothesis = "^6.98" tomli = "^2.2.1" +black = "^24.1.0" +isort = "^5.13.2" +flake8 = "^6.1.0" +scikit-survival = "^0.24.1" +pre-commit = "^3.8" + +[tool.poetry.extras] +dev = [ + "pytest", + "pytest-cov", + "python-semantic-release", + "mypy", + "invoke", + "hypothesis", + "tomli", + "black", + "isort", + "flake8", + "scikit-survival", + "pre-commit", +] [tool.poetry.group.docs.dependencies] -myst-parser = "<4.0.0" +sphinx = ">=6.0" +sphinx-rtd-theme = "^1.3.0" +myst-parser = ">=1.0.0,<4.0.0" +sphinx-autodoc-typehints = "^1.24.0" +sphinx-copybutton = "^0.5.2" +sphinx-design = "^0.5.0" +linkify-it-py = ">=2.0" + +[tool.poetry.scripts] +gen_surv = "gen_surv.cli:app" [tool.semantic_release] version_source = "tag" @@ -39,6 +84,48 @@ upload_to_repository = false branch = "main" build_command = "" +[tool.black] +line-length = 88 +target-version = ['py310'] +include = '\.pyi?$' + +[tool.isort] +profile = "black" +line_length = 88 + +[tool.flake8] +max-line-length = 88 +extend-ignore = ["E203", 
"W503", "E501", "W291", "W293", "W391", "E402", "E302", "E305"] + +[tool.mypy] +python_version = "3.10" +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = true +disallow_incomplete_defs = true + +[[tool.mypy.overrides]] +module = [ + "typer", + "typer.models", + "matplotlib", + "matplotlib.*", + "lifelines", + "lifelines.*", + "sklearn", + "sklearn.*", + "numpy", + "numpy.*", + "pandas", + "pandas.*", + "scipy", + "scipy.*", + "pyreadr", + "sksurv", + "sksurv.*", +] +ignore_missing_imports = true + [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" diff --git a/scripts/check_version_match.py b/scripts/check_version_match.py index 9c3b7dd..a7847c3 100755 --- a/scripts/check_version_match.py +++ b/scripts/check_version_match.py @@ -1,21 +1,34 @@ #!/usr/bin/env python3 -"""Check that pyproject version matches the latest git tag. Optionally fix it by tagging.""" -from pathlib import Path +"""Keep ``pyproject.toml`` in sync with the latest git tag. + +When run with ``--fix`` the script will create a git tag from the version +declared in ``pyproject.toml``. Supplying ``--write`` updates the +``pyproject.toml`` version to match the latest tag. Using both flags ensures +that whichever side is ahead becomes the single source of truth. +""" +from __future__ import annotations + +import argparse +import re import subprocess import sys +from pathlib import Path +from typing import Any, cast -if sys.version_info >= (3, 11): +if sys.version_info >= (3, 11): # pragma: no cover - stdlib alias import tomllib as tomli -else: +else: # pragma: no cover - python <3.11 fallback import tomli ROOT = Path(__file__).resolve().parents[1] + def pyproject_version() -> str: pyproject_path = ROOT / "pyproject.toml" with pyproject_path.open("rb") as f: - data = tomli.load(f) - return data["tool"]["poetry"]["version"] + data: Any = tomli.load(f) + return cast(str, data["tool"]["poetry"]["version"]) + def latest_tag() -> str: try: @@ -26,37 +39,72 @@ def latest_tag() -> str: except subprocess.CalledProcessError: return "" + def create_tag(version: str) -> None: print(f"Tagging repository with version: v{version}") subprocess.run(["git", "tag", f"v{version}"], cwd=ROOT, check=True) subprocess.run(["git", "push", "origin", f"v{version}"], cwd=ROOT, check=True) print(f"βœ… Git tag v{version} created and pushed.") + +def write_version(version: str) -> None: + """Update ``pyproject.toml`` with *version* and push the change.""" + pyproject_path = ROOT / "pyproject.toml" + content = pyproject_path.read_text() + updated = re.sub( + r'^version = "[^"]+"', f'version = "{version}"', content, flags=re.MULTILINE + ) + pyproject_path.write_text(updated) + subprocess.run(["git", "add", str(pyproject_path)], cwd=ROOT, check=True) + subprocess.run( + ["git", "commit", "-m", f"chore: bump version to {version}"], + cwd=ROOT, + check=True, + ) + subprocess.run(["git", "push"], cwd=ROOT, check=True) + print(f"βœ… pyproject.toml updated to {version} and pushed.") + + +def _split(v: str) -> tuple[int, ...]: + return tuple(int(part) for part in v.split(".")) + + def main() -> int: - fix = "--fix" in sys.argv + parser = argparse.ArgumentParser(description="Sync git tags and pyproject version") + parser.add_argument( + "--fix", action="store_true", help="Tag repo from pyproject version" + ) + parser.add_argument( + "--write", action="store_true", help="Update pyproject version from latest tag" + ) + args = parser.parse_args() + version = pyproject_version() tag = latest_tag() if not tag: 
print("⚠️ No git tag found.", file=sys.stderr) - if fix: + if args.fix: create_tag(version) return 0 - else: - return 1 + return 1 if version != tag: print( f"❌ Version mismatch: pyproject.toml has {version} but latest tag is {tag}", file=sys.stderr, ) - if fix: + if args.fix and _split(version) > _split(tag): create_tag(version) return 0 + if args.write and _split(tag) > _split(version): + write_version(tag) + return 0 return 1 print(f"βœ”οΈ Version matches latest tag: {version}") return 0 + if __name__ == "__main__": sys.exit(main()) diff --git a/tasks.py b/tasks.py index d37bc3a..1a93900 100644 --- a/tasks.py +++ b/tasks.py @@ -1,9 +1,6 @@ -from invoke.tasks import task -from invoke import Context, task -from typing import Any import shlex - +from invoke import Context, task @task @@ -25,13 +22,10 @@ def test(c: Context) -> None: # Build the command string. You can adjust '--cov=gen_surv' if you # need to cover a different package or add extra pytest flags. command = ( - "poetry run pytest " - "--cov=gen_surv " - "--cov-report=term " - "--cov-report=xml" + "poetry run pytest " "--cov=gen_surv " "--cov-report=term " "--cov-report=xml" ) - # Run pytest. + # Run pytest. # - warn=True: capture non-zero exit codes without aborting Invoke. # - pty=False: pytest doesn’t require an interactive TTY here. result = c.run(command, warn=True, pty=False) @@ -41,9 +35,15 @@ def test(c: Context) -> None: print("βœ”οΈ All tests passed.") else: print("❌ Some tests failed.") - exit_code = result.exited if result is not None and hasattr(result, "exited") else "Unknown" + exit_code = ( + result.exited + if result is not None and hasattr(result, "exited") + else "Unknown" + ) print(f"Exit code: {exit_code}") - stderr_output = result.stderr if result is not None and hasattr(result, "stderr") else None + stderr_output = ( + result.stderr if result is not None and hasattr(result, "stderr") else None + ) if stderr_output: print("Error output:") print(stderr_output) @@ -76,6 +76,7 @@ def checkversion(c: Context) -> None: print("❌ Version mismatch detected.") print(result.stderr) + @task def docs(c: Context) -> None: """Build the Sphinx documentation. @@ -103,9 +104,15 @@ def docs(c: Context) -> None: print("βœ”οΈ Documentation built successfully.") else: print("❌ Documentation build failed.") - exit_code = result.exited if result is not None and hasattr(result, "exited") else "Unknown" + exit_code = ( + result.exited + if result is not None and hasattr(result, "exited") + else "Unknown" + ) print(f"Exit code: {exit_code}") - stderr_output = result.stderr if result is not None and hasattr(result, "stderr") else None + stderr_output = ( + result.stderr if result is not None and hasattr(result, "stderr") else None + ) if stderr_output: print("Error output:") print(stderr_output) @@ -138,9 +145,15 @@ def stubs(c: Context) -> None: print("βœ”οΈ Type stubs generated successfully in 'stubs/'.") else: print("❌ Stub generation failed.") - exit_code = result.exited if result is not None and hasattr(result, "exited") else "Unknown" + exit_code = ( + result.exited + if result is not None and hasattr(result, "exited") + else "Unknown" + ) print(f"Exit code: {exit_code}") - stderr_output = result.stderr if result is not None and hasattr(result, "stderr") else None + stderr_output = ( + result.stderr if result is not None and hasattr(result, "stderr") else None + ) if stderr_output: print("Error output:") print(stderr_output) @@ -170,16 +183,25 @@ def build(c: Context) -> None: # Report the result of the build process. 
diff --git a/tests/test_aft.py b/tests/test_aft.py
index 2688144..1484432 100644
--- a/tests/test_aft.py
+++ b/tests/test_aft.py
@@ -1,14 +1,157 @@
+"""
+Tests for Accelerated Failure Time (AFT) models.
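+
+Covers the log-logistic, log-normal, and Weibull generators: happy-path
+runs, parameter validation, property-based checks, and seed reproducibility.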
+""" + import pandas as pd -from gen_surv.aft import gen_aft_log_normal +import pytest +from hypothesis import given +from hypothesis import strategies as st + +from gen_surv.aft import gen_aft_log_logistic, gen_aft_log_normal, gen_aft_weibull +from gen_surv.validation import ChoiceError, PositiveValueError + + +def test_gen_aft_log_logistic_runs(): + """Test that the Log-Logistic AFT generator runs without errors.""" + df = gen_aft_log_logistic( + n=10, + beta=[0.5, -0.2], + shape=1.5, + scale=2.0, + model_cens="uniform", + cens_par=5.0, + seed=42, + ) + assert isinstance(df, pd.DataFrame) + assert not df.empty + assert "time" in df.columns + assert "status" in df.columns + assert "X0" in df.columns + assert "X1" in df.columns + assert set(df["status"].unique()).issubset({0, 1}) + + +def test_gen_aft_log_logistic_invalid_shape(): + """Test that the Log-Logistic AFT generator raises error + for invalid shape.""" + with pytest.raises(PositiveValueError): + gen_aft_log_logistic( + n=10, + beta=[0.5, -0.2], + shape=-1.0, # Invalid negative shape + scale=2.0, + model_cens="uniform", + cens_par=5.0, + ) + + +def test_gen_aft_log_logistic_invalid_scale(): + """Test that the Log-Logistic AFT generator raises error + for invalid scale.""" + with pytest.raises(PositiveValueError): + gen_aft_log_logistic( + n=10, + beta=[0.5, -0.2], + shape=1.5, + scale=0.0, # Invalid zero scale + model_cens="uniform", + cens_par=5.0, + ) + + +@given( + n=st.integers(min_value=1, max_value=20), + shape=st.floats( + min_value=0.1, max_value=5.0, allow_nan=False, allow_infinity=False + ), + scale=st.floats( + min_value=0.1, max_value=5.0, allow_nan=False, allow_infinity=False + ), + cens_par=st.floats( + min_value=0.1, max_value=10.0, allow_nan=False, allow_infinity=False + ), + seed=st.integers(min_value=0, max_value=1000), +) +def test_gen_aft_log_logistic_properties(n, shape, scale, cens_par, seed): + """Property-based test for the Log-Logistic AFT generator.""" + df = gen_aft_log_logistic( + n=n, + beta=[0.5, -0.2], + shape=shape, + scale=scale, + model_cens="uniform", + cens_par=cens_par, + seed=seed, + ) + assert df.shape[0] == n + assert set(df["status"].unique()).issubset({0, 1}) + assert (df["time"] >= 0).all() + assert df.filter(regex="^X[0-9]+$").shape[1] == 2 + + +def test_gen_aft_log_logistic_reproducibility(): + """Test that the Log-Logistic AFT generator is reproducible + with the same seed.""" + df1 = gen_aft_log_logistic( + n=10, + beta=[0.5, -0.2], + shape=1.5, + scale=2.0, + model_cens="uniform", + cens_par=5.0, + seed=42, + ) + + df2 = gen_aft_log_logistic( + n=10, + beta=[0.5, -0.2], + shape=1.5, + scale=2.0, + model_cens="uniform", + cens_par=5.0, + seed=42, + ) + + pd.testing.assert_frame_equal(df1, df2) + + df3 = gen_aft_log_logistic( + n=10, + beta=[0.5, -0.2], + shape=1.5, + scale=2.0, + model_cens="uniform", + cens_par=5.0, + seed=43, # Different seed + ) + + with pytest.raises(AssertionError): + pd.testing.assert_frame_equal(df1, df3) + def test_gen_aft_log_normal_runs(): + """Test that the log-normal AFT generator runs without errors.""" df = gen_aft_log_normal( + n=10, beta=[0.5, -0.2], sigma=1.0, model_cens="uniform", cens_par=5.0, seed=42 + ) + assert isinstance(df, pd.DataFrame) + assert not df.empty + assert "time" in df.columns + assert "status" in df.columns + assert "X0" in df.columns + assert "X1" in df.columns + assert set(df["status"].unique()).issubset({0, 1}) + + +def test_gen_aft_weibull_runs(): + """Test that the Weibull AFT generator runs without errors.""" + df = 
gen_aft_weibull( n=10, beta=[0.5, -0.2], - sigma=1.0, + shape=1.5, + scale=2.0, model_cens="uniform", cens_par=5.0, - seed=42 + seed=42, ) assert isinstance(df, pd.DataFrame) assert not df.empty @@ -16,4 +159,111 @@ def test_gen_aft_log_normal_runs(): assert "status" in df.columns assert "X0" in df.columns assert "X1" in df.columns - assert set(df["status"].unique()).issubset({0, 1}) \ No newline at end of file + assert set(df["status"].unique()).issubset({0, 1}) + + +def test_gen_aft_weibull_invalid_shape(): + """Test that the Weibull AFT generator raises error for invalid shape.""" + with pytest.raises(PositiveValueError): + gen_aft_weibull( + n=10, + beta=[0.5, -0.2], + shape=-1.0, # Invalid negative shape + scale=2.0, + model_cens="uniform", + cens_par=5.0, + ) + + +def test_gen_aft_weibull_invalid_scale(): + """Test that the Weibull AFT generator raises error for invalid scale.""" + with pytest.raises(PositiveValueError): + gen_aft_weibull( + n=10, + beta=[0.5, -0.2], + shape=1.5, + scale=0.0, # Invalid zero scale + model_cens="uniform", + cens_par=5.0, + ) + + +def test_gen_aft_weibull_invalid_cens_model(): + """Test that the Weibull AFT generator raises error for invalid censoring model.""" + with pytest.raises(ChoiceError): + gen_aft_weibull( + n=10, + beta=[0.5, -0.2], + shape=1.5, + scale=2.0, + model_cens="invalid", # Invalid censoring model + cens_par=5.0, + ) + + +@given( + n=st.integers(min_value=1, max_value=20), + shape=st.floats( + min_value=0.1, max_value=5.0, allow_nan=False, allow_infinity=False + ), + scale=st.floats( + min_value=0.1, max_value=5.0, allow_nan=False, allow_infinity=False + ), + cens_par=st.floats( + min_value=0.1, max_value=10.0, allow_nan=False, allow_infinity=False + ), + seed=st.integers(min_value=0, max_value=1000), +) +def test_gen_aft_weibull_properties(n, shape, scale, cens_par, seed): + """Property-based test for the Weibull AFT generator.""" + df = gen_aft_weibull( + n=n, + beta=[0.5, -0.2], + shape=shape, + scale=scale, + model_cens="uniform", + cens_par=cens_par, + seed=seed, + ) + assert df.shape[0] == n + assert set(df["status"].unique()).issubset({0, 1}) + assert (df["time"] >= 0).all() + assert df.filter(regex="^X[0-9]+$").shape[1] == 2 + + +def test_gen_aft_weibull_reproducibility(): + """Test that the Weibull AFT generator is reproducible with the same seed.""" + df1 = gen_aft_weibull( + n=10, + beta=[0.5, -0.2], + shape=1.5, + scale=2.0, + model_cens="uniform", + cens_par=5.0, + seed=42, + ) + + df2 = gen_aft_weibull( + n=10, + beta=[0.5, -0.2], + shape=1.5, + scale=2.0, + model_cens="uniform", + cens_par=5.0, + seed=42, + ) + + pd.testing.assert_frame_equal(df1, df2) + + df3 = gen_aft_weibull( + n=10, + beta=[0.5, -0.2], + shape=1.5, + scale=2.0, + model_cens="uniform", + cens_par=5.0, + seed=43, # Different seed + ) + + with pytest.raises(AssertionError): + pd.testing.assert_frame_equal(df1, df3) diff --git a/tests/test_aft_property.py b/tests/test_aft_property.py index 18a6412..73a4061 100644 --- a/tests/test_aft_property.py +++ b/tests/test_aft_property.py @@ -1,11 +1,18 @@ -from hypothesis import given, strategies as st +from hypothesis import given +from hypothesis import strategies as st + from gen_surv.aft import gen_aft_log_normal + @given( n=st.integers(min_value=1, max_value=20), - sigma=st.floats(min_value=0.1, max_value=2.0, allow_nan=False, allow_infinity=False), - cens_par=st.floats(min_value=0.1, max_value=10.0, allow_nan=False, allow_infinity=False), - seed=st.integers(min_value=0, max_value=1000) + sigma=st.floats( + 
min_value=0.1, max_value=2.0, allow_nan=False, allow_infinity=False + ), + cens_par=st.floats( + min_value=0.1, max_value=10.0, allow_nan=False, allow_infinity=False + ), + seed=st.integers(min_value=0, max_value=1000), ) def test_gen_aft_log_normal_properties(n, sigma, cens_par, seed): df = gen_aft_log_normal( @@ -14,7 +21,7 @@ def test_gen_aft_log_normal_properties(n, sigma, cens_par, seed): sigma=sigma, model_cens="uniform", cens_par=cens_par, - seed=seed + seed=seed, ) assert df.shape[0] == n assert set(df["status"].unique()).issubset({0, 1}) diff --git a/tests/test_bivariate.py b/tests/test_bivariate.py index b011130..df0bd5d 100644 --- a/tests/test_bivariate.py +++ b/tests/test_bivariate.py @@ -1,7 +1,9 @@ import numpy as np -from gen_surv.bivariate import sample_bivariate_distribution import pytest +from gen_surv.bivariate import sample_bivariate_distribution +from gen_surv.validation import ChoiceError, LengthError + def test_sample_bivariate_exponential_shape(): """Exponential distribution should return an array of shape (n, 2).""" @@ -11,16 +13,18 @@ def test_sample_bivariate_exponential_shape(): def test_sample_bivariate_invalid_dist(): - """Unsupported distributions should raise ValueError.""" - with pytest.raises(ValueError): + """Unsupported distributions should raise ChoiceError.""" + with pytest.raises(ChoiceError): sample_bivariate_distribution(10, "invalid", 0.0, [1, 1]) + def test_sample_bivariate_exponential_param_length_error(): - """Exponential distribution with wrong param length should raise ValueError.""" - with pytest.raises(ValueError): + """Exponential distribution with wrong param length should raise LengthError.""" + with pytest.raises(LengthError): sample_bivariate_distribution(5, "exponential", 0.0, [1.0]) + def test_sample_bivariate_weibull_param_length_error(): - """Weibull distribution with wrong param length should raise ValueError.""" - with pytest.raises(ValueError): + """Weibull distribution with wrong param length should raise LengthError.""" + with pytest.raises(LengthError): sample_bivariate_distribution(5, "weibull", 0.0, [1.0, 1.0]) diff --git a/tests/test_censoring.py b/tests/test_censoring.py new file mode 100644 index 0000000..33ce3d5 --- /dev/null +++ b/tests/test_censoring.py @@ -0,0 +1,72 @@ +import numpy as np + +from gen_surv.censoring import ( + GammaCensoring, + LogNormalCensoring, + WeibullCensoring, + rexpocens, + rgammacens, + rlognormcens, + runifcens, + rweibcens, +) + + +def test_runifcens_range(): + times = runifcens(5, 2.0) + assert isinstance(times, np.ndarray) + assert len(times) == 5 + assert np.all(times >= 0) + assert np.all(times <= 2.0) + + +def test_rexpocens_nonnegative(): + times = rexpocens(5, 2.0) + assert isinstance(times, np.ndarray) + assert len(times) == 5 + assert np.all(times >= 0) + + +def test_rweibcens_nonnegative(): + times = rweibcens(5, 1.0, 1.5) + assert isinstance(times, np.ndarray) + assert len(times) == 5 + assert np.all(times >= 0) + + +def test_rlognormcens_positive(): + times = rlognormcens(5, 0.0, 1.0) + assert isinstance(times, np.ndarray) + assert len(times) == 5 + assert np.all(times > 0) + + +def test_rgammacens_positive(): + times = rgammacens(5, 2.0, 1.0) + assert isinstance(times, np.ndarray) + assert len(times) == 5 + assert np.all(times > 0) + + +def test_weibull_censoring_class(): + model = WeibullCensoring(scale=1.0, shape=1.5) + times = model(5) + assert isinstance(times, np.ndarray) + assert len(times) == 5 + assert np.all(times >= 0) + + +def test_lognormal_censoring_class(): + model 
= LogNormalCensoring(mean=0.0, sigma=1.0) + times = model(5) + assert isinstance(times, np.ndarray) + assert len(times) == 5 + assert np.all(times > 0) + + +def test_gamma_censoring_class(): + model = GammaCensoring(shape=2.0, scale=1.0) + times = model(5) + assert isinstance(times, np.ndarray) + assert len(times) == 5 + assert np.all(times > 0) diff --git a/tests/test_cli.py b/tests/test_cli.py index d5fd8d7..6121a55 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,12 +1,23 @@ -import pandas as pd -from gen_surv.cli import dataset import runpy +import sys +from typing import Any + +import pandas as pd +import pytest +import typer + +from gen_surv.cli import dataset, visualize +from gen_surv.validation import ValidationError def test_cli_dataset_stdout(monkeypatch, capsys): - """Dataset command prints CSV to stdout when no output file is given.""" + """ + Test that the 'dataset' CLI command prints the generated CSV data to stdout when no output file is specified. + This test patches the 'generate' function to return a simple DataFrame, invokes the CLI command directly, + and asserts that the expected CSV header appears in the captured standard output. + """ - def fake_generate(model: str, n: int): + def fake_generate(**_: Any): return pd.DataFrame({"time": [1.0], "status": [1], "X0": [0.1], "X1": [0.2]}) # Patch the generate function used in the CLI to avoid heavy computation. @@ -27,14 +38,15 @@ def fake_app(): # Patch the CLI app before the module is executed monkeypatch.setattr("gen_surv.cli.app", fake_app) - monkeypatch.setattr("sys.argv", ["gen_surv", "dataset", "cphm"]) + monkeypatch.setattr(sys, "argv", ["gen_surv", "dataset", "cphm"]) runpy.run_module("gen_surv.__main__", run_name="__main__") assert called + def test_cli_dataset_file_output(monkeypatch, tmp_path): """Dataset command writes CSV to file when output path is provided.""" - def fake_generate(model: str, n: int): + def fake_generate(**_: Any): return pd.DataFrame({"time": [1.0], "status": [1], "X0": [0.1], "X1": [0.2]}) monkeypatch.setattr("gen_surv.cli.generate", fake_generate) @@ -43,3 +55,200 @@ def fake_generate(model: str, n: int): assert out_file.exists() content = out_file.read_text() assert "time,status,X0,X1" in content + + +def test_dataset_weibull_parameters(monkeypatch): + """Parameters for aft_weibull model are forwarded correctly.""" + captured = {} + + def fake_generate(**kwargs): + captured.update(kwargs) + return pd.DataFrame({"time": [1], "status": [0]}) + + monkeypatch.setattr("gen_surv.cli.generate", fake_generate) + dataset( + model="aft_weibull", n=3, beta=[0.1, 0.2], shape=1.1, scale=2.2, output=None + ) + assert captured["model"] == "aft_weibull" + assert captured["beta"] == [0.1, 0.2] + assert captured["shape"] == 1.1 + assert captured["scale"] == 2.2 + + +def test_dataset_aft_ln(monkeypatch): + """aft_ln model should forward beta list and sigma.""" + captured = {} + + def fake_generate(**kwargs): + captured.update(kwargs) + return pd.DataFrame({"time": [1], "status": [1]}) + + monkeypatch.setattr("gen_surv.cli.generate", fake_generate) + dataset(model="aft_ln", n=1, beta=[0.3, 0.4], sigma=1.2, output=None) + assert captured["beta"] == [0.3, 0.4] + assert captured["sigma"] == 1.2 + + +def test_dataset_competing_risks(monkeypatch): + """competing_risks expands betas and passes hazards.""" + captured = {} + + def fake_generate(**kwargs): + captured.update(kwargs) + return pd.DataFrame({"time": [1], "status": [1]}) + + monkeypatch.setattr("gen_surv.cli.generate", fake_generate) + 
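+    # Call the command function directly; the scalar --beta value should be
+    # broadcast into one coefficient per competing risk (asserted below).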
dataset( + model="competing_risks", + n=1, + n_risks=2, + baseline_hazards=[0.1, 0.2], + beta=0.5, + output=None, + ) + assert captured["n_risks"] == 2 + assert captured["baseline_hazards"] == [0.1, 0.2] + assert captured["betas"] == [0.5, 0.5] + + +def test_dataset_mixture_cure(monkeypatch): + """mixture_cure passes cure and baseline parameters.""" + captured = {} + + def fake_generate(**kwargs): + captured.update(kwargs) + return pd.DataFrame({"time": [1], "status": [1]}) + + monkeypatch.setattr("gen_surv.cli.generate", fake_generate) + dataset( + model="mixture_cure", + n=1, + cure_fraction=0.2, + baseline_hazard=0.1, + beta=[0.4], + output=None, + ) + assert captured["cure_fraction"] == 0.2 + assert captured["baseline_hazard"] == 0.1 + assert captured["betas_survival"] == [0.4] + assert captured["betas_cure"] == [0.4] + + +def test_dataset_invalid_model(monkeypatch): + def fake_generate(**kwargs): + raise ValueError("bad model") + + monkeypatch.setattr("gen_surv.cli.generate", fake_generate) + with pytest.raises(ValueError): + dataset(model="nope", n=1, output=None) + + +def test_cli_visualize_basic(monkeypatch, tmp_path): + csv = tmp_path / "data.csv" + pd.DataFrame({"time": [1, 2], "status": [1, 0]}).to_csv(csv, index=False) + + def fake_plot_survival_curve(**kwargs): + import matplotlib.pyplot as plt + + fig, ax = plt.subplots() + ax.plot([0, 1], [1, 0]) + return fig, ax + + monkeypatch.setattr( + "gen_surv.visualization.plot_survival_curve", fake_plot_survival_curve + ) + + saved = [] + + def fake_savefig(path, *args, **kwargs): + saved.append(path) + + monkeypatch.setattr("matplotlib.pyplot.savefig", fake_savefig) + + visualize( + str(csv), + time_col="time", + status_col="status", + group_col=None, + output=str(tmp_path / "plot.png"), + ) + assert saved and saved[0].endswith("plot.png") + + +def test_dataset_aft_log_logistic(monkeypatch): + captured = {} + + def fake_generate(**kwargs): + captured.update(kwargs) + return pd.DataFrame({"time": [1], "status": [1]}) + + monkeypatch.setattr("gen_surv.cli.generate", fake_generate) + dataset( + model="aft_log_logistic", + n=1, + beta=[0.1], + shape=1.2, + scale=2.3, + output=None, + ) + assert captured["model"] == "aft_log_logistic" + assert captured["beta"] == [0.1] + assert captured["shape"] == 1.2 + assert captured["scale"] == 2.3 + + +def test_dataset_competing_risks_weibull(monkeypatch): + captured = {} + + def fake_generate(**kwargs): + captured.update(kwargs) + return pd.DataFrame({"time": [1], "status": [1]}) + + monkeypatch.setattr("gen_surv.cli.generate", fake_generate) + dataset( + model="competing_risks_weibull", + n=1, + n_risks=2, + shape_params=[0.7, 1.2], + scale_params=[2.0, 2.0], + beta=0.3, + output=None, + ) + assert captured["n_risks"] == 2 + assert captured["shape_params"] == [0.7, 1.2] + assert captured["scale_params"] == [2.0, 2.0] + assert captured["betas"] == [0.3, 0.3] + + +def test_dataset_piecewise(monkeypatch): + captured = {} + + def fake_generate(**kwargs): + captured.update(kwargs) + return pd.DataFrame({"time": [1], "status": [1]}) + + monkeypatch.setattr("gen_surv.cli.generate", fake_generate) + dataset( + model="piecewise_exponential", + n=1, + breakpoints=[1.0], + hazard_rates=[0.2, 0.3], + beta=[0.4], + output=None, + ) + assert captured["breakpoints"] == [1.0] + assert captured["hazard_rates"] == [0.2, 0.3] + assert captured["betas"] == [0.4] + + +def test_dataset_validation_error(monkeypatch, capsys): + def bad_generate(**kwargs: Any): + raise ValidationError("invalid n") + + 
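+    # Route the CLI's generate call through the failing stub so the
+    # ValidationError handling (typer.Exit plus a printed message) is hit.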
monkeypatch.setattr("gen_surv.cli.generate", bad_generate) + + with pytest.raises(typer.Exit): + dataset(model="cphm", n=1, output=None) + + captured = capsys.readouterr() + assert "Input error: invalid n" in captured.out diff --git a/tests/test_cli_integration.py b/tests/test_cli_integration.py new file mode 100644 index 0000000..5012f1e --- /dev/null +++ b/tests/test_cli_integration.py @@ -0,0 +1,30 @@ +import pandas as pd +from typer.testing import CliRunner + +from gen_surv.cli import app + + +def test_dataset_cli_integration(tmp_path): + """Run dataset command end-to-end and verify CSV output.""" + runner = CliRunner() + out_file = tmp_path / "data.csv" + result = runner.invoke( + app, + [ + "dataset", + "cphm", + "--n", + "3", + "--beta", + "0.5", + "--covariate-range", + "1.0", + "-o", + str(out_file), + ], + ) + assert result.exit_code == 0 + assert out_file.exists() + df = pd.read_csv(out_file) + assert len(df) == 3 + assert {"time", "status"}.issubset(df.columns) diff --git a/tests/test_cmm.py b/tests/test_cmm.py index 48f2467..a34a68b 100644 --- a/tests/test_cmm.py +++ b/tests/test_cmm.py @@ -1,10 +1,87 @@ -import sys -import os -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) -from gen_surv.cmm import gen_cmm - -def test_gen_cmm_shape(): - df = gen_cmm(n=50, model_cens="uniform", cens_par=1.0, beta=[0.1, 0.2, 0.3], - covar=2.0, rate=[0.1, 1.0, 0.2, 1.0, 0.3, 1.0]) - assert df.shape[1] == 6 - assert "transition" in df.columns +import numpy as np +import pandas as pd + +from gen_surv.cmm import gen_cmm, generate_event_times + + +def test_generate_event_times_reproducible(): + rng = np.random.default_rng(0) + result = generate_event_times( + z1=1.0, + beta=[0.1, 0.2, 0.3], + rate=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + rng=rng, + ) + assert np.isclose(result["t12"], 0.9168237140025525) + assert np.isclose(result["t13"], 0.2574241891031173) + assert np.isclose(result["t23"], 0.030993312969869156) + + +def test_gen_cmm_uniform_reproducible(): + df = gen_cmm( + n=5, + model_cens="uniform", + cens_par=1.0, + beta=[0.1, 0.2, 0.3], + covariate_range=2.0, + rate=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + seed=42, + ) + expected = pd.DataFrame( + { + "id": [1, 2, 3, 4, 5], + "start": [0.0] * 5, + "stop": [ + 0.18915094163423693, + 0.6785349983450479, + 0.046776460564183294, + 0.12811363267554587, + 0.45038631001973155, + ], + "status": [1, 1, 1, 0, 0], + "X0": [ + 1.5479119272347037, + 0.8777564989945617, + 1.7171958398225217, + 1.3947360581187287, + 0.1883555828087116, + ], + "transition": [2.0, 2.0, 2.0, float("nan"), float("nan")], + } + ) + pd.testing.assert_frame_equal(df, expected) + + +def test_gen_cmm_exponential_reproducible(): + df = gen_cmm( + n=5, + model_cens="exponential", + cens_par=1.0, + beta=[0.1, 0.2, 0.3], + covariate_range=2.0, + rate=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + seed=42, + ) + expected = pd.DataFrame( + { + "id": [1, 2, 3, 4, 5], + "start": [0.0] * 5, + "stop": [ + 0.18915094163423693, + 0.6785349983450479, + 0.046776460564183294, + 0.07929383504134148, + 0.5750008479681584, + ], + "status": [1, 1, 1, 0, 1], + "X0": [ + 1.5479119272347037, + 0.8777564989945617, + 1.7171958398225217, + 1.3947360581187287, + 0.1883555828087116, + ], + "transition": [2.0, 2.0, 2.0, float("nan"), 1.0], + } + ) + pd.testing.assert_frame_equal(df, expected) diff --git a/tests/test_competing_risks.py b/tests/test_competing_risks.py new file mode 100644 index 0000000..d2eb326 --- /dev/null +++ b/tests/test_competing_risks.py @@ -0,0 +1,330 @@ +"""Tests for Competing 
Risks models.""" + +import os + +import numpy as np +import pandas as pd +import pytest +from hypothesis import given +from hypothesis import strategies as st + +import gen_surv.competing_risks as cr +from gen_surv.competing_risks import ( + cause_specific_cumulative_incidence, + gen_competing_risks, + gen_competing_risks_weibull, +) +from gen_surv.validation import ( + ChoiceError, + LengthError, + NumericSequenceError, + ParameterError, + PositiveSequenceError, + PositiveValueError, +) + +os.environ.setdefault("MPLBACKEND", "Agg") + + +def test_gen_competing_risks_basic(): + """Test that the competing risks generator runs without errors.""" + df = gen_competing_risks( + n=10, + n_risks=2, + baseline_hazards=[0.5, 0.3], + betas=[[0.8, -0.5], [0.2, 0.7]], + model_cens="uniform", + cens_par=2.0, + seed=42, + ) + assert isinstance(df, pd.DataFrame) + assert not df.empty + assert "time" in df.columns + assert "status" in df.columns + assert "X0" in df.columns + assert "X1" in df.columns + assert set(df["status"].unique()).issubset({0, 1, 2}) + + +def test_gen_competing_risks_weibull_basic(): + """Test that the Weibull competing risks generator runs without errors.""" + df = gen_competing_risks_weibull( + n=10, + n_risks=2, + shape_params=[0.8, 1.5], + scale_params=[2.0, 3.0], + betas=[[0.8, -0.5], [0.2, 0.7]], + model_cens="uniform", + cens_par=2.0, + seed=42, + ) + assert isinstance(df, pd.DataFrame) + assert not df.empty + assert "time" in df.columns + assert "status" in df.columns + assert "X0" in df.columns + assert "X1" in df.columns + assert set(df["status"].unique()).issubset({0, 1, 2}) + + +def test_competing_risks_parameters(): + """Test parameter validation in competing risks model.""" + # Test with invalid number of baseline hazards + with pytest.raises(LengthError): + gen_competing_risks( + n=10, + n_risks=3, + baseline_hazards=[0.5, 0.3], # Only 2 provided, but 3 risks + seed=42, + ) + + # Test with invalid number of beta coefficient sets + with pytest.raises(LengthError): + gen_competing_risks( + n=10, + n_risks=2, + betas=[[0.8, -0.5]], # Only 1 set provided, but 2 risks + seed=42, + ) + + # Test with invalid censoring model + with pytest.raises(ChoiceError): + gen_competing_risks(n=10, n_risks=2, model_cens="invalid", seed=42) + + +def test_invalid_covariate_dist(): + with pytest.raises(ChoiceError): + gen_competing_risks(n=5, n_risks=2, covariate_dist="unknown", seed=1) + with pytest.raises(ChoiceError): + gen_competing_risks_weibull(n=5, n_risks=2, covariate_dist="unknown", seed=1) + + +def test_competing_risks_positive_params(): + with pytest.raises(PositiveValueError): + gen_competing_risks(n=5, n_risks=2, cens_par=0.0, seed=0) + with pytest.raises(PositiveValueError): + gen_competing_risks(n=5, n_risks=2, max_time=-1.0, seed=0) + + +def test_competing_risks_invalid_covariate_params(): + with pytest.raises(ParameterError): + gen_competing_risks( + n=5, + n_risks=2, + covariate_dist="normal", + covariate_params={"mean": 0.0}, + seed=1, + ) + with pytest.raises(PositiveValueError): + gen_competing_risks( + n=5, + n_risks=2, + covariate_dist="normal", + covariate_params={"mean": 0.0, "std": -1.0}, + seed=1, + ) + with pytest.raises(ParameterError): + gen_competing_risks_weibull( + n=5, + n_risks=2, + covariate_dist="binary", + covariate_params={"p": 1.5}, + seed=1, + ) + + +def test_competing_risks_invalid_beta_values(): + with pytest.raises(NumericSequenceError): + gen_competing_risks(n=5, n_risks=2, betas=[[0.1, "x"], [0.2, 0.3]], seed=0) + with 
pytest.raises(NumericSequenceError): + gen_competing_risks_weibull( + n=5, n_risks=2, betas=[[0.1, np.nan], [0.2, 0.3]], seed=0 + ) + + +def test_competing_risks_weibull_parameters(): + """Test parameter validation in Weibull competing risks model.""" + # Test with invalid number of shape parameters + with pytest.raises(LengthError): + gen_competing_risks_weibull( + n=10, + n_risks=3, + shape_params=[0.8, 1.5], # Only 2 provided, but 3 risks + seed=42, + ) + + # Test with invalid number of scale parameters + with pytest.raises(LengthError): + gen_competing_risks_weibull( + n=10, + n_risks=3, + scale_params=[2.0, 3.0], # Only 2 provided, but 3 risks + seed=42, + ) + + +def test_competing_risks_weibull_positive_params(): + with pytest.raises(PositiveSequenceError): + gen_competing_risks_weibull(n=5, n_risks=2, shape_params=[1.0, -1.0], seed=0) + with pytest.raises(PositiveSequenceError): + gen_competing_risks_weibull(n=5, n_risks=2, scale_params=[2.0, 0.0], seed=0) + with pytest.raises(PositiveValueError): + gen_competing_risks_weibull(n=5, n_risks=2, cens_par=-1.0, seed=0) + with pytest.raises(PositiveValueError): + gen_competing_risks_weibull(n=5, n_risks=2, max_time=0.0, seed=0) + + +def test_cause_specific_cumulative_incidence(): + """Test the cause-specific cumulative incidence function.""" + # Generate some data + df = gen_competing_risks(n=50, n_risks=2, baseline_hazards=[0.5, 0.3], seed=42) + + # Calculate CIF for cause 1 + time_points = np.linspace(0, 5, 10) + cif = cause_specific_cumulative_incidence(df, time_points, cause=1) + + assert isinstance(cif, pd.DataFrame) + assert len(cif) == len(time_points) + assert "time" in cif.columns + assert "incidence" in cif.columns + assert (cif["incidence"] >= 0).all() + assert (cif["incidence"] <= 1).all() + assert cif["incidence"].is_monotonic_increasing + + # Test with invalid cause + with pytest.raises(ParameterError): + cause_specific_cumulative_incidence(df, time_points, cause=3) + + +def test_cause_specific_cumulative_incidence_handles_ties(): + df = pd.DataFrame( + { + "id": [0, 1, 2, 3], + "time": [1.0, 1.0, 2.0, 2.0], + "status": [1, 2, 1, 0], + } + ) + cif = cause_specific_cumulative_incidence(df, [1.0, 2.0], cause=1) + assert np.allclose(cif["incidence"].to_numpy(), [0.25, 0.5]) + + +def test_cause_specific_cumulative_incidence_bounds(): + df = gen_competing_risks(n=30, n_risks=2, seed=5) + max_time = df["time"].max() + time_points = [-1.0, 0.0, max_time + 1] + cif = cause_specific_cumulative_incidence(df, time_points, cause=1) + assert cif.iloc[0]["incidence"] == 0.0 + expected = cause_specific_cumulative_incidence(df, [max_time], cause=1).iloc[0][ + "incidence" + ] + assert cif.iloc[-1]["incidence"] == expected + + +@given( + n=st.integers(min_value=5, max_value=50), + n_risks=st.integers(min_value=2, max_value=4), + seed=st.integers(min_value=0, max_value=1000), +) +def test_competing_risks_properties(n, n_risks, seed): + """Property-based tests for the competing risks model.""" + df = gen_competing_risks(n=n, n_risks=n_risks, seed=seed) + + # Check basic properties + assert df.shape[0] == n + assert all(col in df.columns for col in ["id", "time", "status"]) + assert (df["time"] >= 0).all() + assert df["status"].isin(list(range(n_risks + 1))).all() # 0 to n_risks + + # Count of each status + status_counts = df["status"].value_counts() + # There should be at least one of each status (including censoring) + # This might occasionally fail due to randomness, so we'll just check that + # we have at least 2 different status values + 
assert len(status_counts) >= 2 + + +@given( + n=st.integers(min_value=5, max_value=50), + n_risks=st.integers(min_value=2, max_value=4), + seed=st.integers(min_value=0, max_value=1000), +) +def test_competing_risks_weibull_properties(n, n_risks, seed): + """Property-based tests for the Weibull competing risks model.""" + df = gen_competing_risks_weibull(n=n, n_risks=n_risks, seed=seed) + + # Check basic properties + assert df.shape[0] == n + assert all(col in df.columns for col in ["id", "time", "status"]) + assert (df["time"] >= 0).all() + assert df["status"].isin(list(range(n_risks + 1))).all() # 0 to n_risks + + # Count of each status + status_counts = df["status"].value_counts() + # There should be at least 2 different status values + assert len(status_counts) >= 2 + + +def test_gen_competing_risks_forces_event_types(): + df = gen_competing_risks( + n=2, + n_risks=2, + baseline_hazards=[1e-9, 1e-9], + model_cens="uniform", + cens_par=0.1, + seed=0, + ) + assert set(df["status"]) == {1, 2} + + +def test_gen_competing_risks_weibull_forces_event_types(): + df = gen_competing_risks_weibull( + n=2, + n_risks=2, + shape_params=[1, 1], + scale_params=[1e9, 1e9], + model_cens="uniform", + cens_par=0.1, + seed=0, + ) + assert set(df["status"]) == {1, 2} + + +def test_reproducibility(): + """Test that results are reproducible with the same seed.""" + df1 = gen_competing_risks(n=20, n_risks=2, seed=42) + + df2 = gen_competing_risks(n=20, n_risks=2, seed=42) + + pd.testing.assert_frame_equal(df1, df2) + + # Different seeds should produce different results + df3 = gen_competing_risks(n=20, n_risks=2, seed=43) + + with pytest.raises(AssertionError): + pd.testing.assert_frame_equal(df1, df3) + + +def test_competing_risks_summary_basic(): + df = gen_competing_risks(n=10, n_risks=2, seed=1) + summary = cr.competing_risks_summary(df) + assert summary["n_subjects"] == 10 + assert summary["n_causes"] == 2 + assert set(summary["events_by_cause"]) <= {1, 2} + assert "time_stats" in summary + + +def test_competing_risks_summary_with_categorical(): + df = gen_competing_risks(n=8, n_risks=2, seed=2) + df["group"] = ["A", "B"] * 4 + summary = cr.competing_risks_summary(df, covariate_cols=["X0", "group"]) + assert summary["covariate_stats"]["group"]["categories"] == 2 + assert "distribution" in summary["covariate_stats"]["group"] + + +def test_plot_cause_specific_hazards_runs(): + plt = pytest.importorskip("matplotlib.pyplot") + df = gen_competing_risks(n=30, n_risks=2, seed=3) + fig, ax = cr.plot_cause_specific_hazards(df, time_points=np.linspace(0, 5, 5)) + assert hasattr(fig, "savefig") + assert len(ax.get_lines()) >= 1 + plt.close(fig) diff --git a/tests/test_cphm.py b/tests/test_cphm.py index 05cc652..f71d04c 100644 --- a/tests/test_cphm.py +++ b/tests/test_cphm.py @@ -1,13 +1,66 @@ -import sys -import os -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +""" +Tests for the Cox Proportional Hazards Model (CPHM) generator. 
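+
+Covers output shape and column names, binary status coding, positive times,
+covariate bounds, and seed reproducibility.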
+""" + +import pandas as pd +import pytest + from gen_surv.cphm import gen_cphm + def test_gen_cphm_output_shape(): - df = gen_cphm(n=50, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0) + """Test that the output DataFrame has the expected shape and columns.""" + df = gen_cphm( + n=50, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0 + ) assert df.shape == (50, 3) - assert list(df.columns) == ["time", "status", "covariate"] + assert list(df.columns) == ["time", "status", "X0"] + def test_gen_cphm_status_range(): - df = gen_cphm(n=100, model_cens="exponential", cens_par=0.8, beta=0.3, covar=1.5) + """Test that status values are binary (0 or 1).""" + df = gen_cphm( + n=100, model_cens="exponential", cens_par=0.8, beta=0.3, covariate_range=1.5 + ) assert df["status"].isin([0, 1]).all() + + +def test_gen_cphm_time_positive(): + """Test that all time values are positive.""" + df = gen_cphm( + n=50, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0 + ) + assert (df["time"] > 0).all() + + +def test_gen_cphm_covariate_range(): + """Test that covariate values are within the specified range.""" + covar_max = 2.5 + df = gen_cphm( + n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=covar_max + ) + assert (df["X0"] >= 0).all() + assert (df["X0"] <= covar_max).all() + + +def test_gen_cphm_seed_reproducibility(): + """Test that setting the same seed produces identical results.""" + df1 = gen_cphm( + n=10, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0, seed=42 + ) + df2 = gen_cphm( + n=10, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0, seed=42 + ) + pd.testing.assert_frame_equal(df1, df2) + + +def test_gen_cphm_different_seeds(): + """Test that different seeds produce different results.""" + df1 = gen_cphm( + n=10, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0, seed=42 + ) + df2 = gen_cphm( + n=10, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0, seed=43 + ) + with pytest.raises(AssertionError): + pd.testing.assert_frame_equal(df1, df2) diff --git a/tests/test_export.py b/tests/test_export.py new file mode 100644 index 0000000..6028c77 --- /dev/null +++ b/tests/test_export.py @@ -0,0 +1,80 @@ +import pandas as pd +import pyreadr +import pytest + +from gen_surv.export import export_dataset +from gen_surv.validation import ChoiceError + + +@pytest.mark.parametrize( + "fmt, reader", + [ + ("csv", pd.read_csv), + ("feather", pd.read_feather), + ("ft", pd.read_feather), + ], +) +def test_export_dataset_formats(fmt, reader, tmp_path): + df = pd.DataFrame({"time": [1.0, 2.0], "status": [1, 0]}) + out = tmp_path / f"data.{fmt}" + export_dataset(df, out) + assert out.exists() + result = reader(out).astype(df.dtypes.to_dict()) + pd.testing.assert_frame_equal(result.reset_index(drop=True), df) + + +def test_export_dataset_json(monkeypatch, tmp_path): + df = pd.DataFrame({"time": [1.0, 2.0], "status": [1, 0]}) + out = tmp_path / "data.json" + + called = {} + + def fake_to_json(self, path, orient="table"): + called["args"] = (path, orient) + with open(path, "w", encoding="utf-8") as f: + f.write("{}") + + monkeypatch.setattr(pd.DataFrame, "to_json", fake_to_json) + export_dataset(df, out) + assert called["args"] == (out, "table") + assert out.exists() + + +def test_export_dataset_rds(monkeypatch, tmp_path): + df = pd.DataFrame({"time": [1.0, 2.0], "status": [1, 0]}) + out = tmp_path / "data.rds" + + captured = {} + + def fake_write_rds(path, data): + captured["path"] = path + 
captured["data"] = data + open(path, "wb").close() + + monkeypatch.setattr(pyreadr, "write_rds", fake_write_rds) + export_dataset(df, out) + assert out.exists() + pd.testing.assert_frame_equal(captured["data"], df.reset_index(drop=True)) + + +def test_export_dataset_explicit_fmt(monkeypatch, tmp_path): + df = pd.DataFrame({"time": [1.0, 2.0], "status": [1, 0]}) + out = tmp_path / "data.bin" + + called = {} + + def fake_to_json(self, path, orient="table"): + called["args"] = (path, orient) + with open(path, "w", encoding="utf-8") as f: + f.write("{}") + + monkeypatch.setattr(pd.DataFrame, "to_json", fake_to_json) + export_dataset(df, out, fmt="json") + assert called["args"] == (out, "table") + assert out.exists() + + +def test_export_dataset_invalid_format(tmp_path): + df = pd.DataFrame({"time": [1.0, 2.0], "status": [1, 0]}) + with pytest.raises(ChoiceError): + export_dataset(df, tmp_path / "data.xxx", fmt="txt") diff --git a/tests/test_integration_sksurv.py b/tests/test_integration_sksurv.py new file mode 100644 index 0000000..84697d7 --- /dev/null +++ b/tests/test_integration_sksurv.py @@ -0,0 +1,13 @@ +import pandas as pd +import pytest + +from gen_surv.integration import to_sksurv + + +def test_to_sksurv(): + # Optional integration test; skipped when scikit-survival is not installed. + pytest.importorskip("sksurv.util") + df = pd.DataFrame({"time": [1.0, 2.0], "status": [1, 0]}) + arr = to_sksurv(df) + assert arr.dtype.names == ("status", "time") + assert arr.shape[0] == 2 diff --git a/tests/test_interface.py b/tests/test_interface.py index ad2dc3c..ece56c7 100644 --- a/tests/test_interface.py +++ b/tests/test_interface.py @@ -1,6 +1,8 @@ -from gen_surv import generate import pytest +from gen_surv import generate +from gen_surv.validation import ValidationError + def test_generate_tdcm_runs(): df = generate( @@ -20,3 +22,16 @@ def test_generate_tdcm_runs(): def test_generate_invalid_model(): with pytest.raises(ValueError): generate(model="unknown") + + +def test_generate_error_message_includes_model(): + with pytest.raises(ValidationError) as exc: + generate( + model="cphm", + n=0, + model_cens="uniform", + cens_par=1.0, + beta=0.5, + covariate_range=2.0, + ) + assert "model 'cphm'" in str(exc.value) diff --git a/tests/test_mixture.py b/tests/test_mixture.py new file mode 100644 index 0000000..2b1a4f9 --- /dev/null +++ b/tests/test_mixture.py @@ -0,0 +1,83 @@ +import pandas as pd +import pytest + +from gen_surv.mixture import cure_fraction_estimate, gen_mixture_cure + + +def test_gen_mixture_cure_runs(): + df = gen_mixture_cure(n=10, cure_fraction=0.3, seed=42) + assert isinstance(df, pd.DataFrame) + assert len(df) == 10 + assert {"time", "status", "cured"}.issubset(df.columns) + + +def test_cure_fraction_estimate_range(): + df = gen_mixture_cure(n=50, cure_fraction=0.3, seed=0) + est = cure_fraction_estimate(df) + assert 0 <= est <= 1 + + +def test_cure_fraction_estimate_empty_returns_zero(): + df = pd.DataFrame(columns=["time", "status"]) + est = cure_fraction_estimate(df) + assert est == 0.0 + + +def test_gen_mixture_cure_invalid_inputs(): + with pytest.raises(ValueError): + gen_mixture_cure(n=5, cure_fraction=1.5) + with pytest.raises(ValueError): + gen_mixture_cure(n=5, cure_fraction=0.2, baseline_hazard=0) + with pytest.raises(ValueError): + gen_mixture_cure(n=5, cure_fraction=0.2, covariate_dist="bad") + with pytest.raises(ValueError): + gen_mixture_cure(n=5, cure_fraction=0.2, model_cens="bad") + with pytest.raises(ValueError): + gen_mixture_cure( + n=5, + cure_fraction=0.2, + 
betas_survival=[0.1, 0.2], + betas_cure=[0.1], + ) + + +def test_gen_mixture_cure_max_time_cap(): + df = gen_mixture_cure( + n=50, + cure_fraction=0.3, + max_time=5.0, + model_cens="exponential", + cens_par=1.0, + seed=123, + ) + assert (df["time"] <= 5.0).all() + + +def test_cure_fraction_estimate_close_to_true(): + df = gen_mixture_cure(n=200, cure_fraction=0.4, seed=1) + est = cure_fraction_estimate(df) + assert pytest.approx(0.4, abs=0.15) == est + + +def test_gen_mixture_cure_covariate_distributions(): + for dist in ["uniform", "binary"]: + df = gen_mixture_cure(n=20, cure_fraction=0.3, covariate_dist=dist, seed=2) + assert {"time", "status", "cured", "X0", "X1"}.issubset(df.columns) + + +def test_gen_mixture_cure_no_max_time_allows_long_times(): + df = gen_mixture_cure( + n=100, + cure_fraction=0.5, + max_time=None, + model_cens="uniform", + cens_par=20.0, + seed=3, + ) + assert df["time"].max() > 10 + + +def test_cure_fraction_estimate_small_sample(): + df = gen_mixture_cure(n=3, cure_fraction=0.2, seed=4) + est = cure_fraction_estimate(df) + assert 0 <= est <= 1 diff --git a/tests/test_piecewise.py b/tests/test_piecewise.py new file mode 100644 index 0000000..61b75d3 --- /dev/null +++ b/tests/test_piecewise.py @@ -0,0 +1,104 @@ +import pandas as pd +import pytest + +from gen_surv.piecewise import gen_piecewise_exponential + + +def test_gen_piecewise_exponential_runs(): + df = gen_piecewise_exponential( + n=10, breakpoints=[1.0], hazard_rates=[0.5, 1.0], seed=42 + ) + assert isinstance(df, pd.DataFrame) + assert len(df) == 10 + assert {"time", "status"}.issubset(df.columns) + + +def test_piecewise_invalid_lengths(): + with pytest.raises(ValueError): + gen_piecewise_exponential( + n=5, breakpoints=[1.0, 2.0], hazard_rates=[0.5], seed=42 + ) + + +def test_piecewise_invalid_hazard_and_breakpoints(): + with pytest.raises(ValueError): + gen_piecewise_exponential( + n=5, + breakpoints=[2.0, 1.0], + hazard_rates=[0.5, 1.0, 1.5], + seed=42, + ) + with pytest.raises(ValueError): + gen_piecewise_exponential( + n=5, + breakpoints=[1.0], + hazard_rates=[0.5, -1.0], + seed=42, + ) + + +def test_piecewise_covariate_distributions(): + for dist, params in [ + ("uniform", {"low": 0.0, "high": 1.0}), + ("binary", {"p": 0.7}), + ]: + df = gen_piecewise_exponential( + n=5, + breakpoints=[1.0], + hazard_rates=[0.2, 0.4], + covariate_dist=dist, + covariate_params=params, + seed=1, + ) + assert len(df) == 5 + assert {"X0", "X1"}.issubset(df.columns) + + +def test_piecewise_custom_betas_reproducible(): + df1 = gen_piecewise_exponential( + n=5, + breakpoints=[1.0], + hazard_rates=[0.1, 0.2], + betas=[0.5, -0.2], + seed=2, + ) + df2 = gen_piecewise_exponential( + n=5, + breakpoints=[1.0], + hazard_rates=[0.1, 0.2], + betas=[0.5, -0.2], + seed=2, + ) + pd.testing.assert_frame_equal(df1, df2) + + +def test_piecewise_invalid_covariate_dist(): + with pytest.raises(ValueError): + gen_piecewise_exponential( + n=5, + breakpoints=[1.0], + hazard_rates=[0.5, 1.0], + covariate_dist="unknown", + seed=1, + ) + + +def test_piecewise_invalid_censoring_model(): + with pytest.raises(ValueError): + gen_piecewise_exponential( + n=5, + breakpoints=[1.0], + hazard_rates=[0.5, 1.0], + model_cens="bad", + seed=1, + ) + + +def test_piecewise_negative_breakpoint(): + with pytest.raises(ValueError): + gen_piecewise_exponential( + n=5, + breakpoints=[-1.0], + hazard_rates=[0.5, 1.0], + seed=1, + ) diff --git a/tests/test_piecewise_functions.py b/tests/test_piecewise_functions.py new file mode 100644 index 0000000..b1de35a --- 
/dev/null +++ b/tests/test_piecewise_functions.py @@ -0,0 +1,52 @@ +import numpy as np + +from gen_surv.piecewise import piecewise_hazard_function, piecewise_survival_function + + +def test_piecewise_hazard_function_scalar_and_array(): + breakpoints = [1.0, 2.0] + hazard_rates = [0.5, 1.0, 1.5] + # Scalar values + assert piecewise_hazard_function(0.5, breakpoints, hazard_rates) == 0.5 + assert piecewise_hazard_function(1.5, breakpoints, hazard_rates) == 1.0 + assert piecewise_hazard_function(3.0, breakpoints, hazard_rates) == 1.5 + # Array values + arr = np.array([0.5, 1.5, 3.0]) + np.testing.assert_allclose( + piecewise_hazard_function(arr, breakpoints, hazard_rates), + np.array([0.5, 1.0, 1.5]), + ) + + +def test_piecewise_hazard_function_negative_time(): + """Hazard should be zero for negative times.""" + breakpoints = [1.0, 2.0] + hazard_rates = [0.5, 1.0, 1.5] + assert piecewise_hazard_function(-1.0, breakpoints, hazard_rates) == 0 + np.testing.assert_array_equal( + piecewise_hazard_function(np.array([-0.5, -2.0]), breakpoints, hazard_rates), + np.array([0.0, 0.0]), + ) + + +def test_piecewise_survival_function(): + breakpoints = [1.0, 2.0] + hazard_rates = [0.5, 1.0, 1.5] + # Known survival probabilities + expected = np.exp(-np.array([0.0, 0.25, 1.0, 3.0])) + times = np.array([0.0, 0.5, 1.5, 3.0]) + np.testing.assert_allclose( + piecewise_survival_function(times, breakpoints, hazard_rates), + expected, + ) + + +def test_piecewise_survival_function_scalar_and_negative(): + breakpoints = [1.0, 2.0] + hazard_rates = [0.5, 1.0, 1.5] + # Scalar output should be a float + val = piecewise_survival_function(1.5, breakpoints, hazard_rates) + assert isinstance(val, float) + assert np.isclose(val, np.exp(-1.0)) + # Negative times return survival of 1 + assert piecewise_survival_function(-2.0, breakpoints, hazard_rates) == 1 diff --git a/tests/test_sklearn_adapter.py b/tests/test_sklearn_adapter.py new file mode 100644 index 0000000..f4b8c4a --- /dev/null +++ b/tests/test_sklearn_adapter.py @@ -0,0 +1,47 @@ +import pytest + +from gen_surv.sklearn_adapter import GenSurvDataGenerator +from gen_surv.validation import ChoiceError + + +def test_sklearn_generator_dataframe(): + gen = GenSurvDataGenerator( + "cphm", + n=4, + beta=0.2, + covariate_range=1.0, + model_cens="uniform", + cens_par=1.0, + ) + df = gen.fit_transform() + assert len(df) == 4 + assert {"time", "status"}.issubset(df.columns) + + +def test_sklearn_generator_dict(): + gen = GenSurvDataGenerator( + "cphm", + return_type="dict", + n=3, + beta=0.5, + covariate_range=1.0, + model_cens="uniform", + cens_par=1.0, + ) + data = gen.transform() + assert isinstance(data, dict) + assert set(data.keys()) >= {"time", "status"} + assert len(data["time"]) == 3 + + +def test_sklearn_generator_invalid_return_type(): + with pytest.raises(ChoiceError): + GenSurvDataGenerator( + "cphm", + return_type="bad", + n=1, + beta=0.5, + covariate_range=1.0, + model_cens="uniform", + cens_par=1.0, + ) diff --git a/tests/test_summary.py b/tests/test_summary.py new file mode 100644 index 0000000..cf63caf --- /dev/null +++ b/tests/test_summary.py @@ -0,0 +1,16 @@ +from gen_surv import generate +from gen_surv.summary import summarize_survival_dataset + + +def test_summarize_survival_dataset_basic(): + df = generate( + model="cphm", + n=20, + model_cens="uniform", + cens_par=1.0, + beta=0.5, + covariate_range=2.0, + ) + summary = summarize_survival_dataset(df, verbose=False) + assert isinstance(summary, dict) + assert "dataset_info" in summary diff --git 
a/tests/test_summary_extra.py b/tests/test_summary_extra.py new file mode 100644 index 0000000..380b818 --- /dev/null +++ b/tests/test_summary_extra.py @@ -0,0 +1,104 @@ +import pandas as pd +import pytest + +from gen_surv import generate +from gen_surv.summary import ( + _print_summary, + check_survival_data_quality, + compare_survival_datasets, +) + + +def test_check_survival_data_quality_fix_issues(): + df = pd.DataFrame( + { + "time": [1.0, -0.5, None, 1.0], + "status": [1, 2, 0, 1], + "id": [1, 2, 3, 1], + } + ) + fixed, issues = check_survival_data_quality( + df, + id_col="id", + max_time=2.0, + fix_issues=True, + ) + assert issues["modifications"]["rows_dropped"] == 2 + assert issues["modifications"]["values_fixed"] == 1 + assert len(fixed) == 2 + + +def test_check_survival_data_quality_no_fix(): + """Issues should be reported but data left unchanged when fix_issues=False.""" + df = pd.DataFrame({"time": [-1.0, 2.0], "status": [3, 1]}) + checked, issues = check_survival_data_quality(df, max_time=1.0, fix_issues=False) + # Data is returned unmodified + pd.testing.assert_frame_equal(df, checked) + assert issues["invalid_values"]["negative_time"] == 1 + assert issues["invalid_values"]["excessive_time"] == 1 + assert issues["invalid_values"]["invalid_status"] == 1 + + +def test_compare_survival_datasets_basic(): + ds1 = generate( + model="cphm", + n=5, + model_cens="uniform", + cens_par=1.0, + beta=0.5, + covariate_range=1.0, + ) + ds2 = generate( + model="cphm", + n=5, + model_cens="uniform", + cens_par=1.0, + beta=1.0, + covariate_range=1.0, + ) + comparison = compare_survival_datasets({"A": ds1, "B": ds2}) + assert set(["A", "B"]).issubset(comparison.columns) + assert "n_subjects" in comparison.index + + +def test_compare_survival_datasets_with_covariates_and_empty_error(): + ds = generate( + model="cphm", + n=3, + model_cens="uniform", + cens_par=1.0, + beta=0.5, + covariate_range=1.0, + ) + comparison = compare_survival_datasets({"only": ds}, covariate_cols=["X0"]) + assert "only" in comparison.columns + assert "X0_mean" in comparison.index + with pytest.raises(ValueError): + compare_survival_datasets({}) + + +def test_check_survival_data_quality_min_and_max(): + df = pd.DataFrame({"time": [-1.0, 3.0], "status": [1, 1]}) + fixed, issues = check_survival_data_quality( + df, min_time=0.0, max_time=2.0, fix_issues=True + ) + assert (fixed["time"] <= 2.0).all() + assert issues["modifications"]["values_fixed"] > 0 + + +def test_print_summary_with_issues(capsys): + summary = { + "dataset_info": {"n_subjects": 2, "n_unique_ids": 2, "n_covariates": 0}, + "event_info": {"n_events": 1, "n_censored": 1, "event_rate": 0.5}, + "time_info": {"min": 0.0, "max": 2.0, "mean": 1.0, "median": 1.0}, + "data_quality": { + "missing_time": 0, + "missing_status": 0, + "negative_time": 1, + "invalid_status": 0, + }, + "covariates": {}, + } + _print_summary(summary, "time", "status", None, []) + out = capsys.readouterr().out + assert "Issues detected" in out diff --git a/tests/test_summary_more.py b/tests/test_summary_more.py new file mode 100644 index 0000000..8894237 --- /dev/null +++ b/tests/test_summary_more.py @@ -0,0 +1,60 @@ +import pandas as pd +import pytest + +from gen_surv.summary import ( + _print_summary, + check_survival_data_quality, + summarize_survival_dataset, +) +from gen_surv.validation import ParameterError + + +def test_summarize_survival_dataset_errors(): + df = pd.DataFrame({"time": [1, 2], "status": [1, 0]}) + # Missing time column + with pytest.raises(ParameterError): + 
diff --git a/tests/test_summary_more.py b/tests/test_summary_more.py
new file mode 100644
index 0000000..8894237
--- /dev/null
+++ b/tests/test_summary_more.py
@@ -0,0 +1,60 @@
+import pandas as pd
+import pytest
+
+from gen_surv.summary import (
+    _print_summary,
+    check_survival_data_quality,
+    summarize_survival_dataset,
+)
+from gen_surv.validation import ParameterError
+
+
+def test_summarize_survival_dataset_errors():
+    df = pd.DataFrame({"time": [1, 2], "status": [1, 0]})
+    # Missing time column
+    with pytest.raises(ParameterError):
+        summarize_survival_dataset(df.drop(columns=["time"]))
+    # Missing ID column when specified
+    with pytest.raises(ParameterError):
+        summarize_survival_dataset(df, id_col="id")
+    # Missing covariate columns
+    with pytest.raises(ParameterError):
+        summarize_survival_dataset(df, covariate_cols=["bad"])
+
+
+def test_summarize_survival_dataset_verbose_output(capsys):
+    df = pd.DataFrame(
+        {
+            "time": [1.0, 2.0, 3.0],
+            "status": [1, 0, 1],
+            "id": [1, 2, 3],
+            "age": [30, 40, 50],
+            "group": ["A", "B", "A"],
+        }
+    )
+    summary = summarize_survival_dataset(
+        df, id_col="id", covariate_cols=["age", "group"]
+    )
+    _print_summary(summary, "time", "status", "id", ["age", "group"])
+    captured = capsys.readouterr().out
+    assert "SURVIVAL DATASET SUMMARY" in captured
+    assert "age:" in captured
+    assert "Categorical" in captured
+
+
+def test_check_survival_data_quality_duplicates_and_fix():
+    df = pd.DataFrame(
+        {
+            "time": [1.0, -1.0, 2.0, 1.0],
+            "status": [1, 1, 0, 1],
+            "id": [1, 1, 2, 1],
+        }
+    )
+    checked, issues = check_survival_data_quality(df, id_col="id", fix_issues=False)
+    assert issues["duplicates"]["duplicate_rows"] == 1
+    assert issues["duplicates"]["duplicate_ids"] == 2
+    fixed, issues_fixed = check_survival_data_quality(
+        df, id_col="id", max_time=2.0, fix_issues=True
+    )
+    assert len(fixed) < len(df)
+    assert issues_fixed["modifications"]["rows_dropped"] > 0
diff --git a/tests/test_tdcm.py b/tests/test_tdcm.py
index 507b51f..b393a3b 100644
--- a/tests/test_tdcm.py
+++ b/tests/test_tdcm.py
@@ -1,10 +1,16 @@
-import sys
-import os
-sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
 from gen_surv.tdcm import gen_tdcm
 
+
 def test_gen_tdcm_shape():
-    df = gen_tdcm(n=50, dist="weibull", corr=0.5, dist_par=[1, 2, 1, 2],
-                  model_cens="uniform", cens_par=1.0, beta=[0.1, 0.2, 0.3], lam=1.0)
+    df = gen_tdcm(
+        n=50,
+        dist="weibull",
+        corr=0.5,
+        dist_par=[1, 2, 1, 2],
+        model_cens="uniform",
+        cens_par=1.0,
+        beta=[0.1, 0.2, 0.3],
+        lam=1.0,
+    )
     assert df.shape[1] == 6
     assert "tdcov" in df.columns
diff --git a/tests/test_thmm.py b/tests/test_thmm.py
index b53b197..e91d2a1 100644
--- a/tests/test_thmm.py
+++ b/tests/test_thmm.py
@@ -1,11 +1,14 @@
-import sys
-import os
-sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
-
 from gen_surv.thmm import gen_thmm
 
+
 def test_gen_thmm_shape():
-    df = gen_thmm(n=50, model_cens="uniform", cens_par=1.0,
-                  beta=[0.1, 0.2, 0.3], covar=2.0, rate=[0.5, 0.6, 0.7])
+    df = gen_thmm(
+        n=50,
+        model_cens="uniform",
+        cens_par=1.0,
+        beta=[0.1, 0.2, 0.3],
+        covariate_range=2.0,
+        rate=[0.5, 0.6, 0.7],
+    )
    assert df.shape[1] == 4
    assert set(df["state"].unique()).issubset({1, 2, 3})
diff --git a/tests/test_validate.py b/tests/test_validate.py
index ffbbad0..4151d76 100644
--- a/tests/test_validate.py
+++ b/tests/test_validate.py
@@ -1,5 +1,16 @@
+import numpy as np
 import pytest
-import gen_surv.validate as v
+
+import gen_surv.validation as v
+from gen_surv.validation import (
+    ChoiceError,
+    ParameterError,
+    PositiveIntegerError,
+    PositiveValueError,
+    ensure_censoring_model,
+    ensure_positive,
+    ensure_positive_int,
+)
 
 
 def test_validate_gen_cphm_inputs_valid():
@@ -8,7 +19,7 @@
 @pytest.mark.parametrize(
-    "n, model_cens, cens_par, covar",
+    "n, model_cens, cens_par, covariate_range",
     [
         (0, "uniform", 0.5, 1.0),
         (1, "bad", 0.5, 1.0),
@@ -16,10 +27,10 @@
         (1, "uniform", 0.5, -1.0),
     ],
 )
-def test_validate_gen_cphm_inputs_invalid(n, model_cens, cens_par, covar):
+def test_validate_gen_cphm_inputs_invalid(n, model_cens, cens_par, covariate_range):
     """Invalid parameter combinations should raise ValueError."""
     with pytest.raises(ValueError):
-        v.validate_gen_cphm_inputs(n, model_cens, cens_par, covar)
+        v.validate_gen_cphm_inputs(n, model_cens, cens_par, covariate_range)
 
 
 def test_validate_dg_biv_inputs_invalid():
@@ -36,11 +47,36 @@ def test_validate_gen_cmm_inputs_invalid_beta_length():
         "uniform",
         0.5,
         [0.1, 0.2],
-        covar=1.0,
+        covariate_range=1.0,
         rate=[0.1] * 6,
     )
 
 
+@pytest.mark.parametrize(
+    "n, model_cens, cens_par, cov_range, rate",
+    [
+        (0, "uniform", 0.5, 1.0, [0.1] * 6),
+        (1, "bad", 0.5, 1.0, [0.1] * 6),
+        (1, "uniform", 0.0, 1.0, [0.1] * 6),
+        (1, "uniform", 0.5, 0.0, [0.1] * 6),
+        (1, "uniform", 0.5, 1.0, [0.1] * 3),
+    ],
+)
+def test_validate_gen_cmm_inputs_other_invalid(
+    n, model_cens, cens_par, cov_range, rate
+):
+    with pytest.raises(ValueError):
+        v.validate_gen_cmm_inputs(
+            n, model_cens, cens_par, [0.1, 0.2, 0.3], cov_range, rate
+        )
+
+
+def test_validate_gen_cmm_inputs_valid():
+    v.validate_gen_cmm_inputs(
+        1, "uniform", 1.0, [0.1, 0.2, 0.3], covariate_range=1.0, rate=[0.1] * 6
+    )
+
+
 def test_validate_gen_tdcm_inputs_invalid_lambda():
     """Lambda <= 0 should raise a ValueError."""
     with pytest.raises(ValueError):
@@ -56,6 +92,44 @@ def test_validate_gen_tdcm_inputs_invalid_lambda():
     )
 
 
+@pytest.mark.parametrize(
+    "dist,corr,dist_par",
+    [
+        ("bad", 0.5, [1, 2]),
+        ("weibull", 0.0, [1, 2, 3, 4]),
+        ("weibull", 0.5, [1, 2, -1, 2]),
+        ("weibull", 0.5, [1, 2, 3]),
+        ("exponential", 2.0, [1, 1]),
+        ("exponential", 0.5, [1]),
+    ],
+)
+def test_validate_gen_tdcm_inputs_invalid_dist(dist, corr, dist_par):
+    with pytest.raises(ValueError):
+        v.validate_gen_tdcm_inputs(
+            1,
+            dist,
+            corr,
+            dist_par,
+            "uniform",
+            1.0,
+            beta=[0.1, 0.2, 0.3],
+            lam=1.0,
+        )
+
+
+def test_validate_gen_tdcm_inputs_valid():
+    v.validate_gen_tdcm_inputs(
+        1,
+        "weibull",
+        0.5,
+        [1, 1, 1, 1],
+        "uniform",
+        1.0,
+        beta=[0.1, 0.2, 0.3],
+        lam=1.0,
+    )
+
+
 def test_validate_gen_aft_log_normal_inputs_valid():
     """Valid parameters should not raise an error for AFT log-normal."""
     v.validate_gen_aft_log_normal_inputs(
@@ -67,6 +141,177 @@ def test_validate_gen_aft_log_normal_inputs_valid():
     )
 
 
+@pytest.mark.parametrize(
+    "n,beta,sigma,model_cens,cens_par",
+    [
+        (0, [0.1], 1.0, "uniform", 1.0),
+        (1, "bad", 1.0, "uniform", 1.0),
+        (1, [0.1], 0.0, "uniform", 1.0),
+        (1, [0.1], 1.0, "bad", 1.0),
+        (1, [0.1], 1.0, "uniform", 0.0),
+    ],
+)
+def test_validate_gen_aft_log_normal_inputs_invalid(
+    n, beta, sigma, model_cens, cens_par
+):
+    with pytest.raises(ValueError):
+        v.validate_gen_aft_log_normal_inputs(n, beta, sigma, model_cens, cens_par)
+
+
 def test_validate_dg_biv_inputs_valid_weibull():
     """Valid parameters for a Weibull distribution should pass."""
     v.validate_dg_biv_inputs(5, "weibull", 0.1, [1.0, 1.0, 1.0, 1.0])
+
+
+def test_validate_dg_biv_inputs_invalid_corr_and_params():
+    with pytest.raises(ValueError):
+        v.validate_dg_biv_inputs(1, "exponential", -2.0, [1.0, 1.0])
+    with pytest.raises(ValueError):
+        v.validate_dg_biv_inputs(1, "exponential", 0.5, [1.0])
+    with pytest.raises(ValueError):
+        v.validate_dg_biv_inputs(1, "weibull", 0.5, [1.0, 1.0])
+
+
+def test_validate_gen_aft_weibull_inputs_and_log_logistic():
+    with pytest.raises(ValueError):
+        v.validate_gen_aft_weibull_inputs(0, [0.1], 1.0, 1.0, "uniform", 1.0)
+    with pytest.raises(ValueError):
+        v.validate_gen_aft_log_logistic_inputs(1, [0.1], -1.0, 1.0, "uniform", 1.0)
+
+
+@pytest.mark.parametrize(
+    "shape,scale",
+    [(-1.0, 1.0), (1.0, -1.0)],
+)
+def test_validate_gen_aft_weibull_invalid_params(shape, scale):
+    with pytest.raises(ValueError):
+        v.validate_gen_aft_weibull_inputs(1, [0.1], shape, scale, "uniform", 1.0)
+
+
+def test_validate_gen_aft_weibull_valid():
+    v.validate_gen_aft_weibull_inputs(1, [0.1], 1.0, 1.0, "uniform", 1.0)
+
+
+def test_validate_gen_aft_log_logistic_valid():
+    v.validate_gen_aft_log_logistic_inputs(1, [0.1], 1.0, 1.0, "uniform", 1.0)
+
+
+def test_positive_sequence_nan_inf():
+    with pytest.raises(v.PositiveSequenceError):
+        v.ensure_positive_sequence([1.0, float("nan")], "x")
+    with pytest.raises(v.PositiveSequenceError):
+        v.ensure_positive_sequence([1.0, float("inf")], "x")
+
+
+def test_numeric_sequence_rejects_bool():
+    with pytest.raises(v.NumericSequenceError):
+        v.ensure_numeric_sequence([1, True], "x")
+
+
+def test_validate_competing_risks_inputs():
+    with pytest.raises(ValueError):
+        v.validate_competing_risks_inputs(
+            1, 2, [0.1], None, "normal", None, "uniform", 1.0
+        )
+    with pytest.raises(v.ChoiceError):
+        v.validate_competing_risks_inputs(
+            1, 1, [0.5], [[0.1]], "gaussian", None, "uniform", 0.5
+        )
+    v.validate_competing_risks_inputs(
+        1, 1, [0.5], [[0.1]], "normal", 10.0, "uniform", 0.5
+    )
+
+
+@pytest.mark.parametrize(
+    "n,model_cens,cens_par,beta,cov_range,rate",
+    [
+        (0, "uniform", 1.0, [0.1, 0.2, 0.3], 1.0, [0.1, 0.2, 0.3]),
+        (1, "bad", 1.0, [0.1, 0.2, 0.3], 1.0, [0.1, 0.2, 0.3]),
+        (1, "uniform", 0.0, [0.1, 0.2, 0.3], 1.0, [0.1, 0.2, 0.3]),
+        (1, "uniform", 1.0, [0.1, 0.2], 1.0, [0.1, 0.2, 0.3]),
+        (1, "uniform", 1.0, [0.1, 0.2, 0.3], 0.0, [0.1, 0.2, 0.3]),
+        (1, "uniform", 1.0, [0.1, 0.2, 0.3], 1.0, [0.1]),
+    ],
+)
+def test_validate_gen_thmm_inputs_invalid(
+    n, model_cens, cens_par, beta, cov_range, rate
+):
+    with pytest.raises(ValueError):
+        v.validate_gen_thmm_inputs(n, model_cens, cens_par, beta, cov_range, rate)
+
+
+def test_validate_gen_thmm_inputs_valid():
+    v.validate_gen_thmm_inputs(1, "uniform", 1.0, [0.1, 0.2, 0.3], 1.0, [0.1, 0.2, 0.3])
+
+
+def test_positive_integer_error():
+    with pytest.raises(PositiveIntegerError):
+        ensure_positive_int(-1, "n")
+
+
+def test_ensure_positive_int_accepts_numpy_and_rejects_bool():
+    ensure_positive_int(np.int64(5), "n")
+    with pytest.raises(PositiveIntegerError):
+        ensure_positive_int(True, "n")
+
+
+def test_ensure_positive_accepts_numpy_and_rejects_bool():
+    ensure_positive(np.float64(0.1), "val")
+    with pytest.raises(PositiveValueError):
+        ensure_positive(True, "val")
+
+
+def test_censoring_model_choice_error():
+    with pytest.raises(ChoiceError):
+        ensure_censoring_model("bad")
+
+
+def test_parameter_error_from_validator():
+    with pytest.raises(ParameterError):
+        v.validate_gen_tdcm_inputs(
+            1,
+            "weibull",
+            0.0,
+            [1, 2, 3, 4],
+            "uniform",
+            1.0,
+            beta=[0.1, 0.2, 0.3],
+            lam=1.0,
+        )
+
+
+def test_validate_gen_piecewise_inputs_invalid():
+    with pytest.raises(ValueError):
+        v.validate_gen_piecewise_inputs(
+            0,
+            [1.0],
+            [0.2],
+            2,
+            "uniform",
+            1.0,
+            "normal",
+        )
+
+
+def test_validate_gen_mixture_inputs_valid_and_invalid():
+    v.validate_gen_mixture_inputs(
+        10,
+        0.3,
+        0.5,
+        2,
+        "uniform",
+        5.0,
+        10.0,
+        "normal",
+    )
+    with pytest.raises(ValueError):
+        v.validate_gen_mixture_inputs(
+            10,
+            1.5,
+            0.5,
+            2,
+            "uniform",
+            5.0,
+            10.0,
+            "normal",
+        )
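A detail worth calling out in the validator tests above: bool is a subclass of int in Python, so a naive isinstance(value, int) check would accept True as a sample size while rejecting numpy integer scalars; the tests pin the opposite behaviour. A sketch of a check with those semantics (the real helpers live in gen_surv.validation; only the error-class name comes from the imports above):

import numpy as np


class PositiveIntegerError(ValueError):
    """Stand-in for gen_surv.validation.PositiveIntegerError."""


def ensure_positive_int_sketch(value, name):
    # Reject bools explicitly before the integer check, since bool subclasses int
    if isinstance(value, bool) or not isinstance(value, (int, np.integer)):
        raise PositiveIntegerError(f"{name} must be an integer, got {value!r}")
    if value <= 0:
        raise PositiveIntegerError(f"{name} must be positive, got {value}")
    return int(value)


ensure_positive_int_sketch(np.int64(5), "n")  # accepted
# ensure_positive_int_sketch(True, "n")       # would raise PositiveIntegerError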
diff --git a/tests/test_version.py b/tests/test_version.py
index 6fb4a02..aff52a9 100644
--- a/tests/test_version.py
+++ b/tests/test_version.py
@@ -1,4 +1,5 @@
 from importlib.metadata import version
+
 from gen_surv import __version__
 
 
diff --git a/tests/test_visualization.py b/tests/test_visualization.py
new file mode 100644
index 0000000..b3ac616
--- /dev/null
+++ b/tests/test_visualization.py
@@ -0,0 +1,173 @@
+import pandas as pd
+import pytest
+import typer
+
+from gen_surv import generate
+from gen_surv.cli import visualize
+from gen_surv.visualization import (
+    describe_survival,
+    plot_covariate_effect,
+    plot_hazard_comparison,
+    plot_survival_curve,
+)
+
+
+def test_plot_survival_curve_runs():
+    df = generate(
+        model="cphm",
+        n=10,
+        model_cens="uniform",
+        cens_par=1.0,
+        beta=0.5,
+        covariate_range=2.0,
+    )
+    fig, ax = plot_survival_curve(df)
+    assert fig is not None
+    assert ax is not None
+
+
+def test_plot_hazard_comparison_runs():
+    df1 = generate(
+        model="cphm",
+        n=5,
+        model_cens="uniform",
+        cens_par=1.0,
+        beta=0.5,
+        covariate_range=1.0,
+    )
+    df2 = generate(
+        model="aft_weibull",
+        n=5,
+        beta=[0.5],
+        shape=1.5,
+        scale=2.0,
+        model_cens="uniform",
+        cens_par=1.0,
+    )
+    models = {"cphm": df1, "aft_weibull": df2}
+    fig, ax = plot_hazard_comparison(models)
+    assert fig is not None
+    assert ax is not None
+
+
+def test_plot_covariate_effect_runs():
+    df = generate(
+        model="cphm",
+        n=10,
+        model_cens="uniform",
+        cens_par=1.0,
+        beta=0.5,
+        covariate_range=2.0,
+    )
+    fig, ax = plot_covariate_effect(df, covariate_col="X0", n_groups=2)
+    assert fig is not None
+    assert ax is not None
+
+
+def test_describe_survival_summary():
+    df = generate(
+        model="cphm",
+        n=10,
+        model_cens="uniform",
+        cens_par=1.0,
+        beta=0.5,
+        covariate_range=2.0,
+    )
+    summary = describe_survival(df)
+    expected_metrics = [
+        "Total Observations",
+        "Number of Events",
+        "Number Censored",
+        "Event Rate",
+        "Median Survival Time",
+        "Min Time",
+        "Max Time",
+        "Mean Time",
+    ]
+    assert list(summary["Metric"]) == expected_metrics
+    assert summary.shape[0] == len(expected_metrics)
+
+
+def test_cli_visualize(tmp_path, capsys):
+    df = pd.DataFrame({"time": [1, 2, 3], "status": [1, 0, 1]})
+    csv_path = tmp_path / "d.csv"
+    df.to_csv(csv_path, index=False)
+    out_file = tmp_path / "out.png"
+    visualize(
+        str(csv_path),
+        time_col="time",
+        status_col="status",
+        group_col=None,
+        output=str(out_file),
+    )
+    assert out_file.exists()
+    captured = capsys.readouterr()
+    assert "Plot saved to" in captured.out
+
+
+def test_cli_visualize_missing_column(tmp_path, capsys):
+    df = pd.DataFrame({"time": [1, 2], "event": [1, 0]})
+    csv_path = tmp_path / "bad.csv"
+    df.to_csv(csv_path, index=False)
+    with pytest.raises(typer.Exit):
+        visualize(
+            str(csv_path),
+            time_col="time",
+            status_col="status",
+            group_col=None,
+        )
+    captured = capsys.readouterr()
+    assert "Status column 'status' not found in data" in captured.out
+
+
+def test_cli_visualize_missing_time(tmp_path, capsys):
+    df = pd.DataFrame({"t": [1, 2], "status": [1, 0]})
+    path = tmp_path / "d.csv"
+    df.to_csv(path, index=False)
+    with pytest.raises(typer.Exit):
+        visualize(str(path), time_col="time", status_col="status")
+    captured = capsys.readouterr()
+    assert "Time column 'time' not found in data" in captured.out
+
+
+def test_cli_visualize_missing_group(tmp_path, capsys):
+    df = pd.DataFrame({"time": [1], "status": [1], "x": [0]})
+    path = tmp_path / "d2.csv"
+    df.to_csv(path, index=False)
+    with pytest.raises(typer.Exit):
+        visualize(str(path), time_col="time", status_col="status", group_col="group")
+    captured = capsys.readouterr()
+    assert "Group column 'group' not found in data" in captured.out
+
+
+def test_cli_visualize_import_error(monkeypatch, tmp_path, capsys):
+    """visualize exits when matplotlib is missing."""
+    import builtins
+
+    real_import = builtins.__import__
+
+    def fake_import(name, *args, **kwargs):
+        if name.startswith("matplotlib"):  # simulate missing dependency
+            raise ImportError("no matplot")
+        return real_import(name, *args, **kwargs)
+
+    monkeypatch.setattr(builtins, "__import__", fake_import)
+    csv_path = tmp_path / "d.csv"
+    pd.DataFrame({"time": [1], "status": [1]}).to_csv(csv_path, index=False)
+    with pytest.raises(typer.Exit):
+        visualize(str(csv_path))
+    captured = capsys.readouterr()
+    assert "Visualization requires matplotlib" in captured.out
+
+
+def test_cli_visualize_read_error(monkeypatch, tmp_path, capsys):
+    """visualize handles CSV read failures gracefully."""
+    monkeypatch.setattr(
+        "pandas.read_csv", lambda *a, **k: (_ for _ in ()).throw(Exception("boom"))
+    )
+    csv_path = tmp_path / "x.csv"
+    csv_path.write_text("time,status\n1,1\n")
+    with pytest.raises(typer.Exit):
+        visualize(str(csv_path))
+    captured = capsys.readouterr()
+    assert "Error loading CSV file" in captured.out
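Because the plotting helpers above return (fig, ax) pairs, they run headless on CI when a non-interactive backend is selected before pyplot is imported. A usage sketch; the savefig step is illustrative rather than part of the tested API:

import matplotlib

matplotlib.use("Agg")  # non-interactive backend, safe without a display

from gen_surv import generate
from gen_surv.visualization import plot_survival_curve

df = generate(
    model="cphm",
    n=50,
    model_cens="uniform",
    cens_par=1.0,
    beta=0.5,
    covariate_range=2.0,
)
fig, ax = plot_survival_curve(df)
fig.savefig("survival_curve.png")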
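The CLI tests all lean on the same typer error-handling pattern: validation failures echo a message to stdout (which is what capsys asserts on) and then raise typer.Exit instead of a bare exception. A sketch of the pattern implied by the asserted messages (the real command lives in gen_surv.cli; this body is an assumption, not its actual implementation):

import pandas as pd
import typer


def visualize_sketch(csv_file: str, time_col: str = "time", status_col: str = "status") -> None:
    try:
        df = pd.read_csv(csv_file)
    except Exception as exc:  # surface any read failure to the user
        typer.echo(f"Error loading CSV file: {exc}")
        raise typer.Exit(code=1)
    for label, col in (("Time", time_col), ("Status", status_col)):
        if col not in df.columns:
            typer.echo(f"{label} column '{col}' not found in data")
            raise typer.Exit(code=1)
    # plotting would follow here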