From 5442dedce9ee16292890e72219aa8d626574a7cf Mon Sep 17 00:00:00 2001 From: Diogo Ribeiro Date: Fri, 8 Aug 2025 14:50:03 +0100 Subject: [PATCH] Revert "Feat/remove ignore errors in mypy settings (#78)" This reverts commit 3c799d0f279e8f659d5f38a289ddf6b604718389. --- .flake8 | 3 - .github/workflows/bump-version.yml | 53 +- .github/workflows/ci.yml | 96 ---- .github/workflows/docs.yml | 34 -- .github/workflows/publish.yml | 27 - .github/workflows/test.yml | 38 ++ .gitignore | 1 - .pre-commit-config.yaml | 19 - .readthedocs.yml | 14 +- CHANGELOG.md | 35 -- CHECKLIST.md | 102 ---- CITATION.cff | 5 +- CONTRIBUTING.md | 11 +- LICENSE => LICENCE | 0 README.md | 221 ++++---- TODO.md | 132 +++-- benchmarks/README.md | 13 - benchmarks/test_tdcm_benchmark.py | 19 - benchmarks/test_validation_benchmark.py | 11 - binder/requirements.txt | 2 - docs/requirements.txt | 10 +- docs/source/_static/custom.css | 56 -- docs/source/algorithms.md | 48 -- docs/source/api/index.md | 89 --- docs/source/bibliography.md | 59 -- docs/source/changelog.md | 9 - docs/source/conf.py | 104 +--- docs/source/contributing.md | 9 - docs/source/examples/cmm.md | 28 - docs/source/examples/index.md | 25 - docs/source/examples/tdcm.md | 30 - docs/source/examples/thmm.md | 28 - docs/source/getting_started.md | 88 --- docs/source/index.md | 179 +++--- docs/source/modules.md | 17 +- docs/source/rtd.md | 20 - docs/source/theory.md | 46 +- docs/source/tutorials/basic_usage.md | 110 ---- docs/source/tutorials/index.md | 13 - docs/source/usage.md | 60 +- examples/notebooks/cmm.ipynb | 45 -- examples/notebooks/tdcm.ipynb | 47 -- examples/notebooks/thmm.ipynb | 45 -- examples/run_aft.py | 7 +- examples/run_aft_weibull.py | 94 ---- examples/run_cmm.py | 7 +- examples/run_competing_risks.py | 143 ----- examples/run_cphm.py | 9 +- examples/run_tdcm.py | 7 +- examples/run_thmm.py | 7 +- gen_surv/__init__.py | 86 +-- gen_surv/_covariates.py | 77 --- gen_surv/aft.py | 216 +------- gen_surv/bivariate.py | 59 +- gen_surv/censoring.py | 105 +--- gen_surv/cli.py | 217 +------- gen_surv/cmm.py | 95 ++-- gen_surv/competing_risks.py | 697 ------------------------ gen_surv/cphm.py | 123 ++--- gen_surv/export.py | 48 -- gen_surv/integration.py | 38 -- gen_surv/interface.py | 75 +-- gen_surv/mixture.py | 305 ----------- gen_surv/piecewise.py | 315 ----------- gen_surv/sklearn_adapter.py | 68 --- gen_surv/summary.py | 508 ----------------- gen_surv/tdcm.py | 97 ++-- gen_surv/thmm.py | 44 +- gen_surv/validate.py | 195 ++++++- gen_surv/validation.py | 415 -------------- gen_surv/visualization.py | 375 ------------- pyproject.toml | 99 +--- scripts/check_version_match.py | 12 +- tasks.py | 86 +-- tests/test_aft.py | 258 +-------- tests/test_aft_property.py | 17 +- tests/test_bivariate.py | 18 +- tests/test_censoring.py | 72 --- tests/test_cli.py | 206 +------ tests/test_cli_integration.py | 30 - tests/test_cmm.py | 96 +--- tests/test_competing_risks.py | 330 ----------- tests/test_cphm.py | 65 +-- tests/test_export.py | 80 --- tests/test_integration_sksurv.py | 13 - tests/test_interface.py | 3 +- tests/test_mixture.py | 83 --- tests/test_piecewise.py | 104 ---- tests/test_piecewise_functions.py | 52 -- tests/test_sklearn_adapter.py | 47 -- tests/test_summary.py | 16 - tests/test_summary_extra.py | 104 ---- tests/test_summary_more.py | 60 -- tests/test_tdcm.py | 16 +- tests/test_thmm.py | 15 +- tests/test_validate.py | 210 +------ tests/test_version.py | 1 - tests/test_visualization.py | 173 ------ 98 files changed, 888 insertions(+), 7811 deletions(-) delete 
mode 100644 .flake8 delete mode 100644 .github/workflows/ci.yml delete mode 100644 .github/workflows/docs.yml delete mode 100644 .github/workflows/publish.yml create mode 100644 .github/workflows/test.yml delete mode 100644 .pre-commit-config.yaml delete mode 100644 CHECKLIST.md rename LICENSE => LICENCE (100%) delete mode 100644 benchmarks/README.md delete mode 100644 benchmarks/test_tdcm_benchmark.py delete mode 100644 benchmarks/test_validation_benchmark.py delete mode 100644 binder/requirements.txt delete mode 100644 docs/source/_static/custom.css delete mode 100644 docs/source/algorithms.md delete mode 100644 docs/source/api/index.md delete mode 100644 docs/source/bibliography.md delete mode 100644 docs/source/changelog.md delete mode 100644 docs/source/contributing.md delete mode 100644 docs/source/examples/cmm.md delete mode 100644 docs/source/examples/index.md delete mode 100644 docs/source/examples/tdcm.md delete mode 100644 docs/source/examples/thmm.md delete mode 100644 docs/source/getting_started.md delete mode 100644 docs/source/rtd.md delete mode 100644 docs/source/tutorials/basic_usage.md delete mode 100644 docs/source/tutorials/index.md delete mode 100644 examples/notebooks/cmm.ipynb delete mode 100644 examples/notebooks/tdcm.ipynb delete mode 100644 examples/notebooks/thmm.ipynb delete mode 100644 examples/run_aft_weibull.py delete mode 100644 examples/run_competing_risks.py delete mode 100644 gen_surv/_covariates.py delete mode 100644 gen_surv/competing_risks.py delete mode 100644 gen_surv/export.py delete mode 100644 gen_surv/integration.py delete mode 100644 gen_surv/mixture.py delete mode 100644 gen_surv/piecewise.py delete mode 100644 gen_surv/sklearn_adapter.py delete mode 100644 gen_surv/summary.py delete mode 100644 gen_surv/validation.py delete mode 100644 gen_surv/visualization.py delete mode 100644 tests/test_censoring.py delete mode 100644 tests/test_cli_integration.py delete mode 100644 tests/test_competing_risks.py delete mode 100644 tests/test_export.py delete mode 100644 tests/test_integration_sksurv.py delete mode 100644 tests/test_mixture.py delete mode 100644 tests/test_piecewise.py delete mode 100644 tests/test_piecewise_functions.py delete mode 100644 tests/test_sklearn_adapter.py delete mode 100644 tests/test_summary.py delete mode 100644 tests/test_summary_extra.py delete mode 100644 tests/test_summary_more.py delete mode 100644 tests/test_visualization.py diff --git a/.flake8 b/.flake8 deleted file mode 100644 index c14e2f1..0000000 --- a/.flake8 +++ /dev/null @@ -1,3 +0,0 @@ -[flake8] -max-line-length = 120 -extend-ignore = E501,W291,W293,W391,E402,E302,E305 diff --git a/.github/workflows/bump-version.yml b/.github/workflows/bump-version.yml index f91dd32..5b41bd7 100644 --- a/.github/workflows/bump-version.yml +++ b/.github/workflows/bump-version.yml @@ -1,26 +1,67 @@ # .github/workflows/bump-version.yml -name: Tag Version on Merge to Main - +name: Bump Version on Merge to Main + on: push: branches: - main jobs: - tag-version: + bump-version: runs-on: ubuntu-latest permissions: contents: write + id-token: write steps: - uses: actions/checkout@v4 with: - fetch-depth: 0 + fetch-depth: 0 # Fetch all history for all branches and tags - uses: actions/setup-python@v5 with: python-version: "3.11" - - name: Tag repository from pyproject - run: python scripts/check_version_match.py --fix + - run: pip install python-semantic-release + + - name: Configure Git + run: | + git config user.name "github-actions[bot]" + git config user.email 
"github-actions[bot]@users.noreply.github.com" + + - name: Run Semantic Release + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: semantic-release version + - name: Push changes + run: | + git push --follow-tags + - name: "Install Poetry" + run: pip install poetry + - name: "Determine version bump type" + run: | + git fetch --tags + # This defaults to a patch type, unless a feature commit was pushed, then set type to minor + LAST_TAG=$(git describe --tags $(git rev-list --tags --max-count=1)) + LAST_COMMIT=$(git log -1 --format='%H') + echo "Last git tag: $LAST_TAG" + echo "Last git commit: $LAST_COMMIT" + echo "Commits:" + git log --no-merges --pretty=oneline $LAST_TAG...$LAST_COMMIT + git log --no-merges --pretty=format:"%s" $LAST_TAG...$LAST_COMMIT | grep -q ^feat: && BUMP_TYPE="minor" || BUMP_TYPE="patch" + echo "Version bump type: $BUMP_TYPE" + echo "BUMP_TYPE=$BUMP_TYPE" >> $GITHUB_ENV + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: "Version bump" + run: | + poetry version $BUMP_TYPE + - name: "Push new version" + run: | + git add pyproject.toml + git commit -m "Update version to $(poetry version -s)" + git pull --ff-only origin main + git push origin main --follow-tags + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} \ No newline at end of file diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml deleted file mode 100644 index 3ff5583..0000000 --- a/.github/workflows/ci.yml +++ /dev/null @@ -1,96 +0,0 @@ -name: CI - -on: - push: - branches: - - main - tags: - - 'v*' - pull_request: - branches: - - main - -jobs: - test: - name: Test with Python ${{ matrix.python-version }} - runs-on: ubuntu-latest - strategy: - matrix: - python-version: ["3.10", "3.11", "3.12"] - - steps: - - name: Checkout code - uses: actions/checkout@v4 - - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - - name: Cache Poetry dependencies - uses: actions/cache@v4 - with: - path: ~/.cache/pypoetry - key: ${{ runner.os }}-poetry-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }} - - - name: Install Poetry - run: | - curl -sSL https://install.python-poetry.org | python3 - - echo "$HOME/.local/bin" >> $GITHUB_PATH - - - name: Install dependencies - run: poetry install --with dev --no-interaction --no-root - - - name: Check runtime dependencies - run: | - poetry run python - <<'PY' - import importlib, sys - missing = [m for m in ("pandas",) if importlib.util.find_spec(m) is None] - if missing: - print("Missing dependencies:", ", ".join(missing)) - sys.exit(1) - PY - - - name: Verify version matches tag - if: startsWith(github.ref, 'refs/tags/') - run: python scripts/check_version_match.py - - - name: Run tests with coverage - run: poetry run pytest --cov=gen_surv --cov-report=xml --cov-report=term - - - name: Upload coverage to Codecov - if: matrix.python-version == '3.11' - uses: codecov/codecov-action@v5 - with: - files: coverage.xml - token: ${{ secrets.CODECOV_TOKEN }} - - lint: - name: Code Quality - runs-on: ubuntu-latest - - steps: - - name: Checkout code - uses: actions/checkout@v4 - - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: "3.11" - - - name: Cache Poetry dependencies - uses: actions/cache@v4 - with: - path: ~/.cache/pypoetry - key: ${{ runner.os }}-poetry-3.11-${{ hashFiles('**/poetry.lock') }} - - - name: Install Poetry - run: | - curl -sSL https://install.python-poetry.org | python3 - - echo "$HOME/.local/bin" >> $GITHUB_PATH - - - 
name: Install dependencies - run: poetry install --with dev --no-interaction --no-root - - - name: Run pre-commit checks - run: poetry run pre-commit run --all-files diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml deleted file mode 100644 index 9bf55bb..0000000 --- a/.github/workflows/docs.yml +++ /dev/null @@ -1,34 +0,0 @@ -name: Documentation Build - -on: - push: - branches: [main] - pull_request: - branches: [main] - -jobs: - docs: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 - with: - python-version: '3.11' - - name: Cache Poetry dependencies - uses: actions/cache@v4 - with: - path: ~/.cache/pypoetry - key: ${{ runner.os }}-poetry-3.11-${{ hashFiles('**/poetry.lock') }} - - name: Install Poetry - run: | - curl -sSL https://install.python-poetry.org | python3 - - echo "$HOME/.local/bin" >> $GITHUB_PATH - - name: Install dependencies - run: poetry install --with docs - - name: Build documentation - run: poetry run sphinx-build -W -b html docs/source docs/build - - name: Upload artifacts - uses: actions/upload-artifact@v4 - with: - name: documentation - path: docs/build/ diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml deleted file mode 100644 index a3928e0..0000000 --- a/.github/workflows/publish.yml +++ /dev/null @@ -1,27 +0,0 @@ -name: Publish to PyPI - -on: - workflow_dispatch: - -jobs: - publish: - runs-on: ubuntu-latest - steps: - - name: Checkout code - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: "3.11" - - name: Verify version matches tag - run: python scripts/check_version_match.py - - name: Install Poetry - run: | - curl -sSL https://install.python-poetry.org | python3 - - echo "$HOME/.local/bin" >> $GITHUB_PATH - - name: Publish - env: - POETRY_PYPI_TOKEN_PYPI: ${{ secrets.PYPI_TOKEN }} - run: poetry publish --build --no-interaction diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..bcd9349 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,38 @@ +name: Run Tests + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.9" + + - name: Install Poetry + run: | + curl -sSL https://install.python-poetry.org | python3 - + echo "$HOME/.local/bin" >> $GITHUB_PATH + + - name: Install dependencies + run: poetry install + + - name: Run tests + run: poetry run pytest --cov=gen_surv --cov-report=xml --cov-report=term + + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v5 + with: + files: coverage.xml + token: ${{ secrets.CODECOV_TOKEN }} # optional if public repo diff --git a/.gitignore b/.gitignore index 6144971..4518323 100644 --- a/.gitignore +++ b/.gitignore @@ -50,4 +50,3 @@ dist/ # Temporary *.log *.tmp -.hypothesis/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml deleted file mode 100644 index 9b7b387..0000000 --- a/.pre-commit-config.yaml +++ /dev/null @@ -1,19 +0,0 @@ -repos: - - repo: https://github.com/psf/black - rev: 24.1.0 - hooks: - - id: black - - repo: https://github.com/pycqa/isort - rev: 5.13.2 - hooks: - - id: isort - - repo: https://github.com/pycqa/flake8 - rev: 6.1.0 - hooks: - - id: flake8 - - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.15.0 - hooks: - - id: mypy - 
pass_filenames: false - args: [--config-file=pyproject.toml, gen_surv] diff --git a/.readthedocs.yml b/.readthedocs.yml index a1f8a6e..a811ac3 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -4,22 +4,10 @@ build: os: ubuntu-22.04 tools: python: "3.11" - jobs: - post_create_environment: - - pip install poetry - post_install: - - poetry config virtualenvs.create false - - poetry install --with docs python: install: - - method: pip - path: . + - requirements: docs/requirements.txt sphinx: configuration: docs/source/conf.py - fail_on_warning: false - -formats: - - pdf - - epub diff --git a/CHANGELOG.md b/CHANGELOG.md index a8d40fb..500ee3d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,40 +1,5 @@ # CHANGELOG -## v1.0.9 (2025-08-02) - -### Features -- export datasets to RDS files -- test workflow runs on a Python version matrix -- scikit-learn compatible data generator -- compatibility helpers for lifelines and scikit-survival - -### Documentation -- updated usage examples and tutorials -- document optional scikit-survival dependency throughout the docs - -### Continuous Integration -- auto-tag releases using the version check script - -### Misc -- README quick example uses `covariate_range` - -## v1.0.8 (2025-07-30) - -### Documentation -- ensure absolute path resolution in `conf.py` -- drop unsupported theme option -- define bibliography anchors and headings -- fix tutorial links to non-existing docs -- add additional references to the bibliography - -### Testing -- add CLI integration test -- expand piecewise generator test coverage - -### Misc -- remove fix_recommendations.md - - ## v1.0.0 (2025-06-06) diff --git a/CHECKLIST.md b/CHECKLIST.md deleted file mode 100644 index 0163f08..0000000 --- a/CHECKLIST.md +++ /dev/null @@ -1,102 +0,0 @@ -# βœ… Python Package Development Checklist - -A checklist to ensure quality, maintainability, and usability of your Python package. - ---- - -## 1. Purpose & Scope - -- [ ] Clear purpose and use cases defined -- [ ] Scoped to a specific problem/domain -- [ ] Project name is meaningful and not taken on PyPI - ---- - -## 2. Project Structure - -- [ ] Uses `src/` layout or appropriate flat structure -- [ ] All package folders contain `__init__.py` -- [ ] Main configuration handled via `pyproject.toml` -- [ ] Includes standard files: `README.md`, `LICENSE`, `.gitignore`, `CHANGELOG.md` - ---- - -## 3. Dependencies - -- [ ] All dependencies declared in `pyproject.toml` or `requirements.txt` -- [ ] Development dependencies separated from runtime dependencies -- [ ] Uses minimal, necessary dependencies only - ---- - -## 4. Code Quality - -- [ ] Follows PEP 8 formatting -- [ ] Imports sorted with `isort` or `ruff` -- [ ] No linter warnings (`ruff`, `flake8`, etc.) -- [ ] Fully typed using `typing` module -- [ ] No unresolved TODOs or FIXME comments - ---- - -## 5. Function & Module Design - -- [ ] Functions are small, pure, and single-responsibility -- [ ] Classes follow clear and simple roles -- [ ] Global state is avoided -- [ ] Public API defined explicitly (e.g. via `__all__`) - ---- - -## 6. Documentation - -- [ ] `README.md` includes overview, install, usage, contributing -- [ ] All functions/classes include docstrings (Google/NumPy style) -- [ ] API reference documentation auto-generated (e.g., Sphinx, MkDocs) -- [ ] Optional: `docs/` folder for additional documentation or site generator - ---- - -## 7. 
Testing - -- [ ] Unit and integration tests implemented -- [ ] Test coverage > 80% verified by `coverage` -- [ ] Tests are fast and deterministic -- [ ] Continuous Integration (CI) configured to run tests - ---- - -## 8. Versioning & Releases - -- [ ] Uses semantic versioning (MAJOR.MINOR.PATCH) -- [ ] Git tags created for releases -- [ ] `CHANGELOG.md` updated with each release -- [ ] Local build verified (`poetry build`, `hatch build`, or equivalent) -- [ ] Can be published to PyPI and/or TestPyPI - ---- - -## 9. CLI or Scripts (Optional) - -- [ ] CLI entrypoint works correctly (`__main__.py` or `entry_points`) -- [ ] CLI provides helpful messages (`--help`) and handles errors gracefully - ---- - -## 10. Examples / Tutorials - -- [ ] Usage examples included in `README.md` or `examples/` -- [ ] Optional: Jupyter notebooks with demonstrations -- [ ] Optional: Colab or Binder links for live usage - ---- - -## 11. Licensing & Attribution - -- [ ] LICENSE file included (MIT, Apache 2.0, GPL, etc.) -- [ ] Author and contributors credited in `README.md` -- [ ] Optional: `CITATION.cff` file for academic citation - ---- - -> You can duplicate this file for each new package or use it as a GitHub issue template for release checklists. diff --git a/CITATION.cff b/CITATION.cff index 78d9c81..fe3f8e0 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -5,15 +5,14 @@ message: "If you use this software, please cite it using the metadata below." preferred-citation: type: software title: "gen_surv" - version: "1.0.9" + version: "1.0.3" url: "https://github.com/DiogoRibeiro7/genSurvPy" authors: - family-names: Ribeiro given-names: Diogo - alias: DiogoRibeiro7 orcid: "https://orcid.org/0009-0001-2022-7072" affiliation: "ESMAD - Instituto PolitΓ©cnico do Porto" email: "dfr@esmad.ipp.pt" license: "MIT" - date-released: "2025-08-02" + date-released: "2024-01-01" diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index aba48fc..07dc4b6 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -5,14 +5,9 @@ Thank you for taking the time to contribute to **gen_surv**! This document provi ## Getting Started 1. Fork the repository and create your feature branch from `main`. -2. Install dependencies with `poetry install --with dev`. - This installs all packages needed for development, including - the optional dependency `scikit-survival`. - On Debian/Ubuntu you may need `build-essential gfortran libopenblas-dev` - to build it. -3. Run `pre-commit install` to enable style checks and execute them with `pre-commit run --all-files`. -4. Ensure the test suite passes with `poetry run pytest`. -5. If you add a feature or fix a bug, update `CHANGELOG.md` accordingly. +2. Install dependencies with `poetry install`. +3. Ensure the test suite passes with `poetry run pytest`. +4. If you add a feature or fix a bug, update `CHANGELOG.md` accordingly. 
## Version Consistency diff --git a/LICENSE b/LICENCE similarity index 100% rename from LICENSE rename to LICENCE diff --git a/README.md b/README.md index df4beab..b74802b 100644 --- a/README.md +++ b/README.md @@ -1,142 +1,157 @@ # gen_surv -[![Coverage][cov-badge]][cov-link] -[![Docs][docs-badge]][docs-link] -[![PyPI][pypi-badge]][pypi-link] -[![Tests][ci-badge]][ci-link] -[![Python][py-badge]][pypi-link] - -[cov-badge]: https://codecov.io/gh/DiogoRibeiro7/genSurvPy/branch/main/graph/badge.svg -[cov-link]: https://app.codecov.io/gh/DiogoRibeiro7/genSurvPy -[docs-badge]: https://readthedocs.org/projects/gensurvpy/badge/?version=latest -[docs-link]: https://gensurvpy.readthedocs.io/en/latest/ -[pypi-badge]: https://img.shields.io/pypi/v/gen_surv -[pypi-link]: https://pypi.org/project/gen-surv/ -[ci-badge]: https://github.com/DiogoRibeiro7/genSurvPy/actions/workflows/ci.yml/badge.svg -[ci-link]: https://github.com/DiogoRibeiro7/genSurvPy/actions/workflows/ci.yml -[py-badge]: https://img.shields.io/pypi/pyversions/gen_surv - -**gen_surv** is a Python library for simulating survival data under a wide range of statistical models. Inspired by the R package [genSurv](https://cran.r-project.org/package=genSurv), it offers a unified interface for generating realistic datasets for research, teaching and benchmarking. +![Coverage](https://codecov.io/gh/DiogoRibeiro7/genSurvPy/branch/main/graph/badge.svg) +[![Docs](https://readthedocs.org/projects/gensurvpy/badge/?version=stable)](https://gensurvpy.readthedocs.io/en/stable/) +![PyPI](https://img.shields.io/pypi/v/gen_surv) +![Tests](https://github.com/DiogoRibeiro7/genSurvPy/actions/workflows/test.yml/badge.svg) +![Python](https://img.shields.io/pypi/pyversions/gen_surv) ---- - -## Features -- Cox proportional hazards model (CPHM) -- Accelerated failure time models (log-normal, log-logistic, Weibull) -- Continuous-time multi-state Markov model (CMM) -- Time-dependent covariate model (TDCM) -- Time-homogeneous hidden Markov model (THMM) -- Mixture cure and piecewise exponential models -- Competing risks generators (constant and Weibull hazards) -- Visualization utilities for simulated datasets -- Scikit-learn compatible data generator -- Conversion helpers for scikit-survival and lifelines -- Command-line interface and export utilities +**gen_surv** is a Python package for simulating survival data under a variety of models, inspired by the R package [`genSurv`](https://cran.r-project.org/package=genSurv). It supports data generation for: -## Installation +- Cox Proportional Hazards Models (CPHM) +- Continuous-Time Markov Models (CMM) +- Time-Dependent Covariate Models (TDCM) +- Time-Homogeneous Hidden Markov Models (THMM) -Requires Python 3.10 or later. +--- -Install the latest release from PyPI: +## πŸ“¦ Installation ```bash -pip install gen-surv +poetry install ``` +## ✨ Features -To develop locally with all extras: - -```bash -git clone https://github.com/DiogoRibeiro7/genSurvPy.git -cd genSurvPy -# Install runtime and development dependencies -# (scikit-survival is optional but required for integration tests). -# On Debian/Ubuntu you may need `build-essential gfortran libopenblas-dev` to -# build scikit-survival. 
-poetry install --with dev -``` +- Consistent interface across models +- Censoring support (`uniform` or `exponential`) +- Easy integration with `pandas` and `NumPy` +- Suitable for benchmarking survival algorithms and teaching +- Accelerated Failure Time (Log-Normal) model generator +- Command-line interface powered by `Typer` -Integration tests that rely on scikit-survival are automatically skipped if the package is not installed. +## πŸ§ͺ Example -## Development Setup +```python +from gen_surv import generate -Before committing changes, install the pre-commit hooks: +# CPHM +generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0) -```bash -pre-commit install -pre-commit run --all-files -``` +# AFT Log-Normal +generate(model="aft_ln", n=100, beta=[0.5, -0.3], sigma=1.0, model_cens="exponential", cens_par=3.0) -## Usage +# CMM +generate(model="cmm", n=100, model_cens="exponential", cens_par=2.0, + qmat=[[0, 0.1], [0.05, 0]], p0=[1.0, 0.0]) -### Python API +# TDCM +generate(model="tdcm", n=100, dist="weibull", corr=0.5, + dist_par=[1, 2, 1, 2], model_cens="uniform", cens_par=1.0, + beta=[0.1, 0.2, 0.3], lam=1.0) -```python -from gen_surv import generate, export_dataset, to_sksurv -from gen_surv.visualization import plot_survival_curve - -# basic Cox proportional hazards data -sim = generate( - model="cphm", - n=100, - beta=0.5, - covariate_range=2.0, - model_cens="uniform", - cens_par=1.0, -) - -plot_survival_curve(sim) -export_dataset(sim, "survival_data.rds") - -# convert for scikit-survival -sks_dataset = to_sksurv(sim) +# THMM +generate(model="thmm", n=100, qmat=[[0, 0.2, 0], [0.1, 0, 0.1], [0, 0.3, 0]], + emission_pars={"mu": [0.0, 1.0, 2.0], "sigma": [0.5, 0.5, 0.5]}, + p0=[1.0, 0.0, 0.0], model_cens="exponential", cens_par=3.0) ``` -See the [usage guide](https://gensurvpy.readthedocs.io/en/latest/getting_started.html) for more examples. 
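The calls above return data you can inspect immediately. A minimal sketch, reusing the CPHM call from the example and assuming the generators return a pandas DataFrame with `time` and `status` columns, as the getting-started guide states:

```python
from gen_surv import generate

# Same CPHM call as in the example above (this version's `covar` keyword)
df = generate(model="cphm", n=100, model_cens="uniform",
              cens_par=1.0, beta=0.5, covar=2.0)

print(df.head())            # time, status and covariate columns
print(df["status"].mean())  # share of observed (uncensored) events
```

Because `status` is 1 for an observed event and 0 for a censored one, its mean is the empirical event rate of the simulated cohort.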
+## ⌨️ Command-Line Usage
-### Command Line
-
-Datasets can be generated without writing Python code:
+Install the package and use `python -m gen_surv` to generate datasets without
+writing Python code:
 ```bash
-python -m gen_surv dataset cphm --n 1000 -o survival.csv
+python -m gen_surv dataset aft_ln --n 100 > data.csv
 ```
-## Supported Models
+## πŸ”§ API Overview
+
+| Function | Description |
+|----------|-------------|
+| `generate()` | Unified interface that calls any generator |
+| `gen_cphm()` | Cox Proportional Hazards Model |
+| `gen_cmm()` | Continuous-Time Multi-State Markov Model |
+| `gen_tdcm()` | Time-Dependent Covariate Model |
+| `gen_thmm()` | Time-Homogeneous Hidden Markov Model |
+| `gen_aft_log_normal()` | Accelerated Failure Time Log-Normal |
+| `sample_bivariate_distribution()` | Sample correlated Weibull or exponential times |
+| `runifcens()` | Generate uniform censoring times |
+| `rexpocens()` | Generate exponential censoring times |
+
+
+```text
+genSurvPy/
+β”œβ”€β”€ gen_surv/              # Main package
+β”‚   β”œβ”€β”€ __main__.py        # CLI entry point via python -m
+β”‚   β”œβ”€β”€ cphm.py
+β”‚   β”œβ”€β”€ cmm.py
+β”‚   β”œβ”€β”€ tdcm.py
+β”‚   β”œβ”€β”€ thmm.py
+β”‚   β”œβ”€β”€ censoring.py
+β”‚   β”œβ”€β”€ bivariate.py
+β”‚   β”œβ”€β”€ validate.py
+β”‚   └── interface.py
+β”œβ”€β”€ tests/                 # Automated tests
+β”‚   β”œβ”€β”€ test_cphm.py
+β”‚   β”œβ”€β”€ test_cmm.py
+β”‚   β”œβ”€β”€ test_tdcm.py
+β”‚   └── test_thmm.py
+β”œβ”€β”€ examples/              # Usage examples
+β”‚   β”œβ”€β”€ run_aft.py
+β”‚   β”œβ”€β”€ run_cmm.py
+β”‚   β”œβ”€β”€ run_cphm.py
+β”‚   β”œβ”€β”€ run_tdcm.py
+β”‚   └── run_thmm.py
+β”œβ”€β”€ docs/                  # Sphinx documentation
+β”‚   β”œβ”€β”€ source/
+β”‚   └── ...
+β”œβ”€β”€ scripts/               # Miscellaneous utilities
+β”‚   └── check_version_match.py
+β”œβ”€β”€ tasks.py               # Automated tasks with Invoke
+β”œβ”€β”€ TODO.md                # Development roadmap
+β”œβ”€β”€ pyproject.toml         # Configured with Poetry
+β”œβ”€β”€ README.md
+β”œβ”€β”€ LICENCE
+└── .gitignore
+```
-| Model | Description |
-|-------|-------------|
-| **CPHM** | Cox proportional hazards |
-| **AFT** | Accelerated failure time (log-normal, log-logistic, Weibull) |
-| **CMM** | Continuous-time multi-state Markov |
-| **TDCM** | Time-dependent covariates |
-| **THMM** | Time-homogeneous hidden Markov |
-| **Competing Risks** | Multiple event types with cause-specific hazards |
-| **Mixture Cure** | Models long-term survivors |
-| **Piecewise Exponential** | Flexible baseline hazard via intervals |
+## 🧠 License
-More details on each algorithm are available in the [Algorithms](https://gensurvpy.readthedocs.io/en/latest/algorithms.html) page. For additional background, see the [theory guide](https://gensurvpy.readthedocs.io/en/latest/theory.html).
+MIT License. See [LICENCE](LICENCE) for details.
-## Documentation
-Full documentation is hosted on [Read the Docs](https://gensurvpy.readthedocs.io/en/latest/). It includes installation instructions, tutorials, API references and a bibliography.
+## πŸ”– Release Process
-To build the docs locally:
+This project uses Git tags to manage releases. A GitHub Actions workflow
+(`version-check.yml`) verifies that the version declared in `pyproject.toml`
+matches the latest Git tag. If they diverge, the workflow fails and prompts a
+correction before merging. Run `python scripts/check_version_match.py` locally
+before creating a tag to catch issues early.
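A sketch of the pre-tag check this section describes; the script path comes from `scripts/` above, while the tag value is a placeholder:

```bash
# Fail fast if pyproject.toml and the latest tag disagree
python scripts/check_version_match.py

# Tag only after the check passes (v1.2.3 is a placeholder)
git tag v1.2.3
git push origin v1.2.3
```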
-```bash -cd docs -make html -``` +## 🌟 Code of Conduct + +Please read our [Code of Conduct](CODE_OF_CONDUCT.md) to learn about the +expectations for participants in this project. -Open `build/html/index.html` in your browser to view the result. +## 🀝 Contributing -## License +Please read [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines on setting up your environment, running tests, and submitting pull requests. -This project is licensed under the MIT License. See [LICENSE](LICENSE) for details. +## πŸ”§ Development Tasks + +Common project commands are defined in [`tasks.py`](tasks.py) and can be executed with [Invoke](https://www.pyinvoke.org/): + +```bash +poetry run inv -l # list available tasks +poetry run inv test # run the test suite +``` -## Citation +## πŸ“‘ Citation -If you use **gen_surv** in your research, please cite the project using the metadata in [CITATION.cff](CITATION.cff). +If you use **gen_surv** in your work, please cite it using the metadata in +[`CITATION.cff`](CITATION.cff). Many reference managers can import this file +directly. ## Author diff --git a/TODO.md b/TODO.md index eb15799..91b4fe4 100644 --- a/TODO.md +++ b/TODO.md @@ -1,65 +1,103 @@ -# gen_surv Roadmap +# TODO – Roadmap for gen_surv -This document outlines the planned development priorities for future versions of gen_surv. This roadmap will be periodically updated based on user feedback, research needs, and community contributions. +This document outlines future enhancements, features, and ideas for improving the gen_surv package. -## Short-term Goals (v1.1.x) +--- -### Additional Statistical Models -- [ ] **Recurrent Events Model**: Generate data with multiple events per subject -- [ ] **Time-Varying Effects**: Support for non-proportional hazards with coefficients that change over time -- [ ] **Extended Competing Risks**: Allow for correlation between competing risks +## ✨ Priority Items -### Visualization and Analysis -- [ ] **Enhanced Visualization Toolkit**: Add more plot types and customization options -- [ ] **Interactive Visualizations**: Add options using Plotly for interactive exploration -- [ ] **Data Quality Reports**: Generate reports on statistical properties of generated datasets +- [βœ…] Add property-based tests using Hypothesis to cover edge cases +- [βœ…] Build a CLI for generating datasets from the terminal +- [ ] Expand documentation with multilingual support and more usage examples +- [ ] Implement Weibull and log-logistic AFT models and add visualization utilities +- [βœ…] Provide CITATION metadata for proper referencing +- [ ] Ensure all functions include Google-style docstrings with inline comments -### Usability Improvements -- [ ] **Dataset Catalog**: Pre-configured parameters to mimic classic survival datasets -- [ ] **Parameter Estimation**: Tools to estimate generation parameters from existing datasets -- [ ] **Extended CLI**: Add more command-line options for all models +--- -## Medium-term Goals (v1.2.x) +## πŸ“¦ 1. 
Interface and UX -### Advanced Statistical Models -- [ ] **Joint Longitudinal-Survival Models**: Generators for models that simultaneously handle longitudinal outcomes and time-to-event data -- [ ] **Frailty Models**: Support for shared and nested frailty models -- [ ] **Interval Censoring**: Support for interval-censored data generation +- [βœ…] Create a `generate(..., return_type="df" | "dict")` interface +- [βœ…] Add `__version__` using `importlib.metadata` or `poetry-dynamic-versioning` +- [βœ…] Build a CLI with `typer` or `click` +- [βœ…] Add example notebooks or scripts for each model (`examples/` folder) -### Technical Enhancements -- [ ] **Parallel Processing**: Multi-core support for faster generation of large datasets -- [ ] **Memory Optimization**: Streaming data generation for very large datasets -- [ ] **Performance Benchmarks**: Systematic benchmarking of data generation speed +--- -### Integration and Ecosystem -- [ ] **scikit-learn Extensions**: More scikit-learn compatible estimators and transformers -- [ ] **Stan/PyMC Integration**: Export data in formats suitable for Bayesian modeling -- [ ] **Dashboard**: Simple Streamlit app for data exploration and generation +## πŸ“š 2. Documentation -## Long-term Goals (v2.x) +- [βœ…] Add a "Model Comparison Guide" section (`index.md` + `theory.md`) +- [βœ…] Add "How It Works" sections for each model (`theory.md`) +- [βœ…] Include usage examples in index with real calls +- [ ] Optional: add multilingual docs using `sphinx-intl` -### Advanced Features -- [ ] **Bayesian Survival Models**: Generators for Bayesian survival analysis with various priors -- [ ] **Spatial Survival Models**: Generate survival data with spatial correlation -- [ ] **Survival Neural Networks**: Integration with deep learning approaches to survival analysis +--- -### Infrastructure and Performance -- [ ] **GPU Acceleration**: Optional GPU support for large dataset generation -- [ ] **JAX/Numba Implementation**: High-performance implementations of key algorithms -- [ ] **R Interface**: Create an R package that interfaces with gen_surv +## πŸ§ͺ 3. Testing and Quality -### Community and Documentation -- [ ] **Interactive Tutorials**: Using Jupyter Book or similar tools -- [ ] **Video Tutorials**: Short video demonstrations of key features -- [ ] **Case Studies**: Real-world examples showing how gen_surv can be used for teaching or research -- [ ] **User Showcase**: Gallery of research or teaching that uses gen_surv +- [βœ…] Add tests for each model (e.g., `test_tdcm.py`, `test_thmm.py`, `test_aft.py`) +- [βœ…] Add property-based tests with `hypothesis` +- [ ] Cover edge cases (e.g., invalid parameters, n=0, negative censoring) +- [ ] Run tests on multiple Python versions (CI matrix) -## How to Contribute +--- -We welcome contributions that help us achieve these roadmap goals! If you're interested in working on any of these features, please check the [CONTRIBUTING.md](CONTRIBUTING.md) file for guidelines and open an issue to discuss your approach before submitting a pull request. +## 🧠 4. Advanced Models -For suggesting new features or modifications to this roadmap, please open an issue with the "enhancement" tag. 
+- [ ] Add Piecewise Exponential Model support +- [ ] Add competing risks / multi-event simulation +- [βœ…] Implement parametric AFT models (log-normal) +- [ ] Implement parametric AFT models (log-logistic, weibull) +- [ ] Simulate time-varying hazards +- [ ] Add informative or covariate-dependent censoring -## Version History +--- -For a detailed history of past releases, please see our [CHANGELOG.md](CHANGELOG.md). +## πŸ“Š 5. Visualization and Analysis + +- [ ] Create `plot_survival(df, model=...)` utilities +- [ ] Create `describe_survival(df)` summary helpers +- [ ] Export data to CSV / JSON / Feather + +--- + +## 🌍 6. Ecosystem Integration + +- [ ] Add a `GenSurvDataGenerator` compatible with `sklearn` +- [ ] Enable use with `lifelines`, `scikit-survival`, `sksurv` +- [ ] Export in R-compatible formats (.csv, .rds) + +--- + +## πŸ” 7. Other Ideas + +- [ ] Add performance benchmarks for each model +- [βœ…] Improve PyPI discoverability (added tags, keywords, docs) +- [ ] Create a Streamlit or Gradio live demo + +--- + +## 🧠 8. New Survival Models to Implement + +- [βœ…] Log-Normal AFT +- [ ] Log-Logistic AFT +- [ ] Weibull AFT +- [ ] Piecewise Exponential +- [ ] Competing Risks +- [ ] Recurrent Events +- [ ] Mixture Cure Model + +--- + +## 🧬 9. Advanced Data Simulation Features + +- [ ] Recurrent events (multiple events per individual) +- [ ] Frailty models (random effects) +- [ ] Time-varying hazard functions +- [ ] Multi-line start-stop formatted data +- [ ] Competing risks with cause-specific hazards +- [ ] Simulate violations of PH assumption +- [ ] Grouped / clustered data generation +- [ ] Mixed covariates: categorical, continuous, binary +- [ ] Joint models (longitudinal + survival outcome) +- [ ] Controlled scenarios for robustness tests diff --git a/benchmarks/README.md b/benchmarks/README.md deleted file mode 100644 index 37e9c19..0000000 --- a/benchmarks/README.md +++ /dev/null @@ -1,13 +0,0 @@ -# Benchmarks - -This directory contains performance benchmarks run with `pytest-benchmark`. -Run them with: - -```bash -pytest benchmarks -q --benchmark-only -``` - -## Available benchmarks - -- validation helpers -- TDCM generation diff --git a/benchmarks/test_tdcm_benchmark.py b/benchmarks/test_tdcm_benchmark.py deleted file mode 100644 index 9509548..0000000 --- a/benchmarks/test_tdcm_benchmark.py +++ /dev/null @@ -1,19 +0,0 @@ -import pytest - -pytest.importorskip("pytest_benchmark") - -from gen_surv.tdcm import gen_tdcm - - -def test_tdcm_generation_benchmark(benchmark): - benchmark( - gen_tdcm, - n=1000, - dist="weibull", - corr=0.5, - dist_par=[1, 2, 1, 2], - model_cens="uniform", - cens_par=1.0, - beta=[0.1, 0.2, 0.3], - lam=1.0, - ) diff --git a/benchmarks/test_validation_benchmark.py b/benchmarks/test_validation_benchmark.py deleted file mode 100644 index 7cf50f9..0000000 --- a/benchmarks/test_validation_benchmark.py +++ /dev/null @@ -1,11 +0,0 @@ -import numpy as np -import pytest - -pytest.importorskip("pytest_benchmark") - -from gen_surv.validation import ensure_positive_sequence - - -def test_positive_sequence_benchmark(benchmark): - seq = np.random.rand(10000) + 1.0 - benchmark(ensure_positive_sequence, seq, "seq") diff --git a/binder/requirements.txt b/binder/requirements.txt deleted file mode 100644 index 2df98f0..0000000 --- a/binder/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ --e . 
-jupyterlab diff --git a/docs/requirements.txt b/docs/requirements.txt index a0990cf..98a3c62 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,8 +1,2 @@ -sphinx>=6.0 -myst-parser>=1.0.0,<4.0.0 -sphinx-rtd-theme -sphinx-autodoc-typehints -sphinx-copybutton -sphinx-design -linkify-it-py>=2.0 -matplotlib +sphinx +myst-parser diff --git a/docs/source/_static/custom.css b/docs/source/_static/custom.css deleted file mode 100644 index 22f503a..0000000 --- a/docs/source/_static/custom.css +++ /dev/null @@ -1,56 +0,0 @@ -/* Custom styling for gen_surv documentation */ - -/* Improve code block appearance */ -.highlight { - background: #f8f9fa; - border: 1px solid #e9ecef; - border-radius: 4px; - padding: 10px; - margin: 10px 0; -} - -/* Style admonitions */ -.admonition { - border-radius: 6px; - margin: 1em 0; - padding: 1em; -} - -.admonition.tip { - border-left: 4px solid #28a745; - background-color: #d4edda; -} - -.admonition.note { - border-left: 4px solid #007bff; - background-color: #cce5ff; -} - -.admonition.warning { - border-left: 4px solid #ffc107; - background-color: #fff3cd; -} - -/* Table styling */ -table.docutils { - border-collapse: collapse; - margin: 1em 0; - width: 100%; -} - -table.docutils th, -table.docutils td { - border: 1px solid #ddd; - padding: 8px; - text-align: left; -} - -table.docutils th { - background-color: #f8f9fa; - font-weight: bold; -} - -/* Math styling */ -.math { - font-size: 1.1em; -} diff --git a/docs/source/algorithms.md b/docs/source/algorithms.md deleted file mode 100644 index d5bcb70..0000000 --- a/docs/source/algorithms.md +++ /dev/null @@ -1,48 +0,0 @@ ---- -orphan: true ---- - -# Algorithm Overview - -This page provides a short description of each model implemented in **gen_surv**. For mathematical details see {doc}`theory`. - -## Cox Proportional Hazards Model (CPHM) -The hazard at time $t$ is proportional to a baseline hazard multiplied by the exponential of covariate effects. -It is widely used for modelling relative risks under the proportional hazards assumption. -See {ref}`Cox1972` in the {doc}`bibliography` for the seminal paper. - -## Accelerated Failure Time Models (AFT) -These parametric models directly relate covariates to survival time. -gen_surv includes log-normal, log-logistic and Weibull variants allowing different baseline distributions. -They are convenient when the effect of covariates accelerates or decelerates event times. - -## Continuous-Time Multi-State Markov Model (CMM) -Transitions between states are governed by a generator matrix. -This model is suited for illness-death and other multi-state processes where state occupancy changes continuously over time. -The mathematical formulation follows the counting-process approach of Andersen et al. {ref}`Andersen1993`. - -## Time-Dependent Covariate Model (TDCM) -Extends the Cox model to covariates that vary during follow-up. -Covariates are simulated in a piecewise fashion with optional correlation across segments. - -## Time-Homogeneous Hidden Markov Model (THMM) -Handles processes with unobserved states that emit observable values. -The latent transitions follow a homogeneous Markov chain while emissions are Gaussian. -For background on these models see Zucchini et al. {ref}`Zucchini2017`. - -## Competing Risks -Allows multiple failure types with cause-specific hazards. -gen_surv supports constant and Weibull hazards for each cause. -The subdistribution approach of Fine and Gray {ref}`FineGray1999` is commonly used for analysis. 
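As a sketch of the standard competing-risks formulation assumed here: with cause-specific hazards $\lambda_k(t \mid x)$ for causes $k = 1, \dots, K$, overall survival and the cumulative incidence of cause $k$ are

$$
S(t \mid x) = \exp\!\left(-\int_0^t \sum_{k=1}^{K} \lambda_k(s \mid x)\,ds\right),
\qquad
F_k(t \mid x) = \int_0^t \lambda_k(s \mid x)\, S(s \mid x)\, ds.
$$

A constant-hazard cause takes $\lambda_k(t \mid x) \equiv \lambda_k$ (possibly scaled by covariate effects), while the Weibull variant lets the baseline part of $\lambda_k$ depend on $t$.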
- -## Mixture Cure Model -Assumes a proportion of individuals will never experience the event. -A logistic component determines who is cured, while uncured subjects follow an exponential failure distribution. -Mixture cure models were introduced by Farewell {ref}`Farewell1982`. - -## Piecewise Exponential Model -Approximates complex hazard shapes by dividing follow-up time into intervals with constant hazard within each interval. -This yields a flexible baseline hazard while remaining computationally simple. - -For additional reading on these methods please see the {doc}`bibliography`. - diff --git a/docs/source/api/index.md b/docs/source/api/index.md deleted file mode 100644 index dc3c722..0000000 --- a/docs/source/api/index.md +++ /dev/null @@ -1,89 +0,0 @@ ---- -orphan: true ---- - -# API Reference - -Complete documentation for all gen_surv functions and classes. - -```{note} -The `to_sksurv` helper and related tests rely on the optional -dependency `scikit-survival`. Install it with `poetry install --with dev` -or `pip install scikit-survival` to enable this functionality. -``` - -## Core Interface - -::: gen_surv.interface - options: - members: true - undoc-members: true - show-inheritance: true - -## Model Generators - -### Cox Proportional Hazards Model -::: gen_surv.cphm - options: - members: true - undoc-members: true - show-inheritance: true - -### Accelerated Failure Time Models -::: gen_surv.aft - options: - members: true - undoc-members: true - show-inheritance: true - -### Continuous-Time Markov Models -::: gen_surv.cmm - options: - members: true - undoc-members: true - show-inheritance: true - -### Time-Dependent Covariate Models -::: gen_surv.tdcm - options: - members: true - undoc-members: true - show-inheritance: true - -### Time-Homogeneous Markov Models -::: gen_surv.thmm - options: - members: true - undoc-members: true - show-inheritance: true - -## Utility Functions - -### Censoring Functions -::: gen_surv.censoring - options: - members: true - undoc-members: true - show-inheritance: true - -### Bivariate Distributions -::: gen_surv.bivariate - options: - members: true - undoc-members: true - show-inheritance: true - -### Validation Functions -::: gen_surv.validation - options: - members: true - undoc-members: true - show-inheritance: true - -### Command Line Interface -::: gen_surv.cli - options: - members: true - undoc-members: true - show-inheritance: true - diff --git a/docs/source/bibliography.md b/docs/source/bibliography.md deleted file mode 100644 index 5a5285a..0000000 --- a/docs/source/bibliography.md +++ /dev/null @@ -1,59 +0,0 @@ ---- -orphan: true ---- - -# References - -Below is a selection of references covering the statistical models implemented in **gen_surv**. - -(Cox1972)= -## Cox (1972) -Cox, D. R. (1972). Regression Models and Life-Tables. *Journal of the Royal Statistical Society: Series B*, 34(2), 187-220. - -(Farewell1982)= -## Farewell (1982) -Farewell, V.T. (1982). The Use of Mixture Models for the Analysis of Survival Data with Long-Term Survivors. *Biometrics*, 38(4), 1041-1046. - -(FineGray1999)= -## Fine and Gray (1999) -Fine, J.P., & Gray, R.J. (1999). A Proportional Hazards Model for the Subdistribution of a Competing Risk. *Journal of the American Statistical Association*, 94(446), 496-509. - -(Andersen1993)= -## Andersen et al. (1993) -Andersen, P.K., Borgan, Ø., Gill, R.D., & Keiding, N. (1993). *Statistical Models Based on Counting Processes*. Springer. - -(Zucchini2017)= -## Zucchini et al. 
(2017) -Zucchini, W., MacDonald, I.L., & Langrock, R. (2017). *Hidden Markov Models for Time Series*. Chapman and Hall/CRC. - -(KleinMoeschberger2003)= -## Klein and Moeschberger (2003) -Klein, J.P., & Moeschberger, M.L. (2003). *Survival Analysis: Techniques for Censored and Truncated Data*. Springer. - -(KalbfleischPrentice2002)= -## Kalbfleisch and Prentice (2002) -Kalbfleisch, J.D., & Prentice, R.L. (2002). *The Statistical Analysis of Failure Time Data*. Wiley. - -(CookLawless2007)= -## Cook and Lawless (2007) -Cook, R.J., & Lawless, J.F. (2007). *The Statistical Analysis of Recurrent Events*. Springer. - -(KaplanMeier1958)= -## Kaplan and Meier (1958) -Kaplan, E.L., & Meier, P. (1958). Nonparametric Estimation from Incomplete Observations. *Journal of the American Statistical Association*, 53(282), 457-481. -(TherneauGrambsch2000)= -## Therneau and Grambsch (2000) -Therneau, T.M., & Grambsch, P.M. (2000). *Modeling Survival Data: Extending the Cox Model*. Springer. - -(FlemingHarrington1991)= -## Fleming and Harrington (1991) -Fleming, T.R., & Harrington, D.P. (1991). *Counting Processes and Survival Analysis*. Wiley. - -(Collett2015)= -## Collett (2015) -Collett, D. (2015). *Modelling Survival Data in Medical Research*. CRC Press. - -(KleinbaumKlein2012)= -## Kleinbaum and Klein (2012) -Kleinbaum, D.G., & Klein, M. (2012). *Survival Analysis: A Self-Learning Text*. Springer. - diff --git a/docs/source/changelog.md b/docs/source/changelog.md deleted file mode 100644 index 011713d..0000000 --- a/docs/source/changelog.md +++ /dev/null @@ -1,9 +0,0 @@ ---- -orphan: true ---- - -# Changelog - -```{include} ../../CHANGELOG.md -:relative-docs: true -``` diff --git a/docs/source/conf.py b/docs/source/conf.py index df8af80..6ac6534 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,92 +1,42 @@ +# Configuration file for the Sphinx documentation builder. 
+# +# For the full list of built-in configuration values, see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information import os -from datetime import datetime -from importlib import metadata +import sys +sys.path.insert(0, os.path.abspath('../../gen_surv')) + +project = 'gen_surv' +copyright = '2025, Diogo Ribeiro' +author = 'Diogo Ribeiro' +release = '1.0.3' -# Project information -project = "gen_surv" -copyright = f"{datetime.now().year}, Diogo Ribeiro" -author = "Diogo Ribeiro" -release = metadata.version("gen_surv") -version = release +# -- General configuration --------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration -# General configuration extensions = [ "sphinx.ext.autodoc", "sphinx.ext.napoleon", - "sphinx.ext.viewcode", - "sphinx.ext.intersphinx", - "sphinx.ext.autosummary", - "sphinx.ext.githubpages", "myst_parser", - "sphinx_copybutton", - "sphinx_design", - "sphinx_autodoc_typehints", -] - -# MyST Parser configuration -myst_enable_extensions = [ - "colon_fence", - "deflist", - "html_admonition", - "html_image", - "linkify", - "replacements", - "smartquotes", - "substitution", - "tasklist", + "sphinx.ext.viewcode", + "sphinx.ext.autosectionlabel", ] -# Autodoc configuration -autodoc_default_options = { - "members": True, - "member-order": "bysource", - "special-members": "__init__", - "undoc-members": True, - "exclude-members": "__weakref__", -} - -# Autosummary -autosummary_generate = True - -# Napoleon settings -napoleon_google_docstring = True -napoleon_numpy_docstring = True -napoleon_include_init_with_doc = False -napoleon_include_private_with_doc = False +autosectionlabel_prefix_document = True -# Intersphinx mapping -intersphinx_mapping = { - "python": ("https://docs.python.org/3/", None), - "numpy": ("https://numpy.org/doc/stable/", None), - "pandas": ("https://pandas.pydata.org/docs/", None), -} +# Point to index.md or index.rst as the root document +master_doc = "index" -# Disable fetching remote inventories when network access is unavailable -if os.environ.get("SKIP_INTERSPHINX", "1") == "1": - intersphinx_mapping = {} +templates_path = ['_templates'] +exclude_patterns = [] -# HTML theme options -html_theme = "sphinx_rtd_theme" -html_theme_options = { - "canonical_url": "https://gensurvpy.readthedocs.io/", - "analytics_id": "", - "logo_only": False, - "prev_next_buttons_location": "bottom", - "style_external_links": False, - "style_nav_header_background": "#2980B9", - "collapse_navigation": True, - "sticky_navigation": True, - "navigation_depth": 4, - "includehidden": True, - "titles_only": False, -} +# -- Options for HTML output ------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output -html_static_path = ["_static"] -html_css_files = ["custom.css"] +html_theme = 'alabaster' +html_static_path = ['_static'] -# Output file base name for HTML help builder -htmlhelp_basename = "gensurvdoc" -# Copy button configuration -copybutton_prompt_text = r">>> |\.\.\. 
|\$ |In \[\d*\]: | {2,5}\.\.\.: | {5,8}: " -copybutton_prompt_is_regexp = True diff --git a/docs/source/contributing.md b/docs/source/contributing.md deleted file mode 100644 index 82121b3..0000000 --- a/docs/source/contributing.md +++ /dev/null @@ -1,9 +0,0 @@ ---- -orphan: true ---- - -# Contributing - -```{include} ../../CONTRIBUTING.md -:relative-docs: true -``` diff --git a/docs/source/examples/cmm.md b/docs/source/examples/cmm.md deleted file mode 100644 index a389544..0000000 --- a/docs/source/examples/cmm.md +++ /dev/null @@ -1,28 +0,0 @@ -# Continuous-Time Multi-State Markov Model (CMM) - -[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/DiogoRibeiro7/genSurvPy/HEAD?urlpath=lab/tree/examples/notebooks/cmm.ipynb) - -Visualize transition times from the CMM generator: - -```python -import numpy as np -import matplotlib.pyplot as plt -from gen_surv import generate - -np.random.seed(0) - -df = generate( - model="cmm", - n=200, - model_cens="exponential", - cens_par=2.0, - beta=[0.1, 0.2, 0.3], - covariate_range=1.0, - rate=[0.1, 1.0, 0.2, 1.0, 0.1, 1.0], -) - -plt.hist(df["stop"], bins=20, color="#4C72B0") -plt.xlabel("Time") -plt.ylabel("Frequency") -plt.title("CMM Transition Times") -``` diff --git a/docs/source/examples/index.md b/docs/source/examples/index.md deleted file mode 100644 index 14003cf..0000000 --- a/docs/source/examples/index.md +++ /dev/null @@ -1,25 +0,0 @@ ---- -orphan: true ---- - -# Examples - -Real-world examples and use cases for gen_surv. - -```{note} -Some examples may require optional packages such as -`scikit-survival`. Install them with `poetry install --with dev` or -`pip install scikit-survival` before running these examples. -``` - -Each example page includes a [Binder](https://mybinder.org/) badge so you can -launch the corresponding notebook and experiment interactively in your -browser. 
- -```{toctree} -:maxdepth: 2 - -tdcm -cmm -thmm -``` diff --git a/docs/source/examples/tdcm.md b/docs/source/examples/tdcm.md deleted file mode 100644 index cc027af..0000000 --- a/docs/source/examples/tdcm.md +++ /dev/null @@ -1,30 +0,0 @@ -# Time-Dependent Covariate Model (TDCM) - -[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/DiogoRibeiro7/genSurvPy/HEAD?urlpath=lab/tree/examples/notebooks/tdcm.ipynb) - -A basic visualization of event times produced by the TDCM generator: - -```python -import numpy as np -import matplotlib.pyplot as plt -from gen_surv import generate - -np.random.seed(0) - -df = generate( - model="tdcm", - n=200, - dist="weibull", - corr=0.5, - dist_par=[1, 2, 1, 2], - model_cens="uniform", - cens_par=1.0, - beta=[0.1, 0.2, 0.3], - lam=1.0, -) - -plt.hist(df["stop"], bins=20, color="#4C72B0") -plt.xlabel("Time") -plt.ylabel("Frequency") -plt.title("TDCM Event Times") -``` diff --git a/docs/source/examples/thmm.md b/docs/source/examples/thmm.md deleted file mode 100644 index 01fb17f..0000000 --- a/docs/source/examples/thmm.md +++ /dev/null @@ -1,28 +0,0 @@ -# Time-Homogeneous Hidden Markov Model (THMM) - -[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/DiogoRibeiro7/genSurvPy/HEAD?urlpath=lab/tree/examples/notebooks/thmm.ipynb) - -An example of event times generated by the THMM: - -```python -import numpy as np -import matplotlib.pyplot as plt -from gen_surv import generate - -np.random.seed(0) - -df = generate( - model="thmm", - n=200, - model_cens="exponential", - cens_par=3.0, - beta=[0.1, 0.2, 0.3], - covariate_range=1.0, - rate=[0.2, 0.1, 0.3], -) - -plt.hist(df["time"], bins=20, color="#4C72B0") -plt.xlabel("Time") -plt.ylabel("Frequency") -plt.title("THMM Event Times") -``` diff --git a/docs/source/getting_started.md b/docs/source/getting_started.md deleted file mode 100644 index 7788093..0000000 --- a/docs/source/getting_started.md +++ /dev/null @@ -1,88 +0,0 @@ ---- -orphan: true ---- - -# Getting Started - -This guide will help you install gen_surv and generate your first survival dataset. - -## Installation - -### From PyPI (Recommended) - -```bash -pip install gen-surv -``` - -### From Source - -```bash -git clone https://github.com/DiogoRibeiro7/genSurvPy.git -cd genSurvPy -poetry install -``` - -```{note} -Some features and tests rely on optional packages such as -`scikit-survival`. Install them with `poetry install --with dev` or -`pip install scikit-survival` (additional system libraries may be -required). 
-``` - -## Basic Usage - -The main entry point is the `generate()` function: - -```python -from gen_surv import generate - -# Generate Cox proportional hazards data - df = generate( - model="cphm", # Model type - n=100, # Sample size - beta=0.5, # Covariate effect - covariate_range=2.0, # Covariate range - model_cens="uniform", # Censoring type - cens_par=3.0 # Censoring parameter - ) - -print(df.head()) -``` - -## Understanding the Output - -All models return a pandas DataFrame with at least these columns: - -- `time`: Observed event or censoring time -- `status`: Event indicator (1 = event, 0 = censored) -- Additional columns depend on the specific model - -## Command Line Usage - -Generate datasets directly from the terminal: - -```bash -# Generate CPHM data and save to CSV -python -m gen_surv dataset cphm --n 1000 -o survival_data.csv - -# Print AFT data to stdout -python -m gen_surv dataset aft_ln --n 500 -``` - -## Next Steps - -- Explore the {doc}`tutorials/index` for detailed examples -- Check the {doc}`api/index` for complete function documentation -- Read about the {doc}`theory` behind each model - -## Building the Documentation - -To preview the documentation locally run: - -```bash -cd docs -make html -``` - -More details about our Read the Docs configuration can be found in {doc}`rtd`. - diff --git a/docs/source/index.md b/docs/source/index.md index 8bd885f..5a22562 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -1,141 +1,100 @@ -# gen_surv: Survival Data Simulation in Python +# gen_surv -[![Documentation Status](https://readthedocs.org/projects/gensurvpy/badge/?version=latest)](https://gensurvpy.readthedocs.io/en/latest/?badge=latest) -[![PyPI version](https://badge.fury.io/py/gen-surv.svg)](https://badge.fury.io/py/gen-surv) -[![Python versions](https://img.shields.io/pypi/pyversions/gen-surv.svg)](https://pypi.org/project/gen-surv/) +**gen_surv** is a Python package for simulating survival data under various models, inspired by the R package `genSurv`. -**gen_surv** is a comprehensive Python package for simulating survival data under various statistical models, inspired by the R package `genSurv`. It provides a unified interface for generating synthetic survival datasets that are essential for: +It includes generators for: -- **Research**: Testing new survival analysis methods -- **Education**: Teaching survival analysis concepts -- **Benchmarking**: Comparing different survival models -- **Validation**: Testing statistical software implementations +- **Cox Proportional Hazards Models (CPHM)** +- **Continuous-Time Markov Models (CMM)** +- **Time-Dependent Covariate Models (TDCM)** +- **Time-Homogeneous Hidden Markov Models (THMM)** +- **Accelerated Failure Time (AFT) Log-Normal Models** -```{admonition} Quick Start -:class: tip +Key functions include `generate()`, `gen_cphm()`, `gen_cmm()`, `gen_tdcm()`, +`gen_thmm()`, `gen_aft_log_normal()`, `sample_bivariate_distribution()`, +`runifcens()`, and `rexpocens()`. -Install with pip: -```bash -pip install gen-surv -``` +--- -Generate your first dataset: -```python -from gen_surv import generate -df = generate(model="cphm", n=100, beta=0.5, covariate_range=2.0) -``` -``` +See the [Getting Started](usage) guide for installation instructions. -```{note} -The `to_sksurv` helper and related tests require the optional -dependency `scikit-survival`. Install it with `poetry install --with dev` -or `pip install scikit-survival` if you need this functionality. 
+## πŸ“š Modules + +```{toctree} +:maxdepth: 2 +:caption: Contents + +usage +modules +theory ``` -## Supported Models -| Model | Description | Use Case | -|-------|-------------|----------| -| **CPHM** | Cox Proportional Hazards | Standard survival regression | -| **AFT** | Accelerated Failure Time | Non-proportional hazards | -| **CMM** | Continuous-Time Markov | Multi-state processes | -| **TDCM** | Time-Dependent Covariates | Dynamic risk factors | -| **THMM** | Time-Homogeneous Markov | Hidden state processes | -| **Competing Risks** | Multiple event types | Cause-specific hazards | -| **Mixture Cure** | Long-term survivors | Logistic cure fraction | -| **Piecewise Exponential** | Piecewise constant hazard | Flexible baseline | +# πŸš€ Usage Example + +```python +from gen_surv import generate -## Algorithm Descriptions +# CPHM +generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0) -For a brief summary of each statistical model see {doc}`algorithms`. Mathematical -details and notation are provided on the {doc}`theory` page. +# AFT Log-Normal +generate(model="aft_ln", n=100, beta=[0.5, -0.3], sigma=1.0, model_cens="exponential", cens_par=3.0) -## Documentation Contents +# CMM +generate(model="cmm", n=100, model_cens="exponential", cens_par=2.0, + qmat=[[0, 0.1], [0.05, 0]], p0=[1.0, 0.0]) -```{toctree} -:maxdepth: 2 +# TDCM +generate(model="tdcm", n=100, dist="weibull", corr=0.5, + dist_par=[1, 2, 1, 2], model_cens="uniform", cens_par=1.0, + beta=[0.1, 0.2, 0.3], lam=1.0) -getting_started -tutorials/index -api/index -theory -algorithms -examples/index -rtd -contributing -changelog -bibliography +# THMM +generate(model="thmm", n=100, qmat=[[0, 0.2, 0], [0.1, 0, 0.1], [0, 0.3, 0]], + emission_pars={"mu": [0.0, 1.0, 2.0], "sigma": [0.5, 0.5, 0.5]}, + p0=[1.0, 0.0, 0.0], model_cens="exponential", cens_par=3.0) ``` -## Quick Examples +## ⌨️ Command-Line Usage -### Cox Proportional Hazards Model -```python -import gen_surv as gs - -# Basic CPHM with uniform censoring -df = gs.generate( - model="cphm", - n=500, - beta=0.5, - covariate_range=2.0, - model_cens="uniform", - cens_par=3.0 -) -``` +Generate datasets directly from the terminal: -### Accelerated Failure Time Model -```python -# AFT with log-normal distribution -df = gs.generate( - model="aft_ln", - n=200, - beta=[0.5, -0.3, 0.2], - sigma=1.0, - model_cens="exponential", - cens_par=2.0 -) +```bash +python -m gen_surv dataset aft_ln --n 100 > data.csv ``` -### Multi-State Markov Model -```python -# Three-state illness-death model -df = gs.generate( - model="cmm", - n=300, - qmat=[[0, 0.1], [0.05, 0]], - p0=[1.0, 0.0], - model_cens="uniform", - cens_par=5.0 -) +## Repository Layout + +```text +genSurvPy/ +β”œβ”€β”€ gen_surv/ +β”‚ └── ... 
+β”œβ”€β”€ tests/ +β”œβ”€β”€ examples/ +β”œβ”€β”€ docs/ +β”œβ”€β”€ scripts/ +β”œβ”€β”€ tasks.py +└── TODO.md ``` -## Key Features +## πŸ”— Project Links + +- [Source Code](https://github.com/DiogoRibeiro7/genSurvPy) +- [License](https://github.com/DiogoRibeiro7/genSurvPy/blob/main/LICENCE) +- [Code of Conduct](https://github.com/DiogoRibeiro7/genSurvPy/blob/main/CODE_OF_CONDUCT.md) -- **Unified Interface**: Single `generate()` function for all models -- **Flexible Censoring**: Support for uniform and exponential censoring -- **Rich Parameterization**: Extensive customization options -- **Command-Line Interface**: Generate datasets from terminal -- **Comprehensive Validation**: Input parameter checking -- **Educational Focus**: Clear mathematical documentation ## Citation -If you use gen_surv in your research, please cite: +If you use **gen_surv** in your work, please cite it using the metadata in +[CITATION.cff](../../CITATION.cff). -```bibtex -@software{ribeiro2025gensurvpy, - title = {gen_surv: Survival Data Simulation in Python}, - author = {Diogo Ribeiro}, - year = {2025}, - url = {https://github.com/DiogoRibeiro7/genSurvPy}, - version = {1.0.9} -} -``` +## Author -## License +**Diogo Ribeiro** β€” [ESMAD - Instituto PolitΓ©cnico do Porto](https://esmad.ipp.pt) -MIT License - see [LICENSE](https://github.com/DiogoRibeiro7/genSurvPy/blob/main/LICENSE) for details. +- ORCID: +- Professional email: +- Personal email: -For foundational papers related to these models see the {doc}`bibliography`. -Information on building the docs is provided in the {doc}`rtd` page. diff --git a/docs/source/modules.md b/docs/source/modules.md index f90396b..114a344 100644 --- a/docs/source/modules.md +++ b/docs/source/modules.md @@ -1,66 +1,51 @@ ---- -orphan: true ---- - # API Reference ::: gen_surv.cphm options: members: true undoc-members: true - show-inheritance: true ::: gen_surv.cmm options: members: true undoc-members: true - show-inheritance: true ::: gen_surv.tdcm options: members: true undoc-members: true - show-inheritance: true ::: gen_surv.thmm options: members: true undoc-members: true - show-inheritance: true ::: gen_surv.interface options: members: true undoc-members: true - show-inheritance: true ::: gen_surv.aft options: members: true undoc-members: true - show-inheritance: true ::: gen_surv.bivariate options: members: true undoc-members: true - show-inheritance: true ::: gen_surv.censoring options: members: true undoc-members: true - show-inheritance: true ::: gen_surv.cli options: members: true undoc-members: true - show-inheritance: true -::: gen_surv.validation +::: gen_surv.validate options: members: true undoc-members: true - show-inheritance: true - diff --git a/docs/source/rtd.md b/docs/source/rtd.md deleted file mode 100644 index 7eac430..0000000 --- a/docs/source/rtd.md +++ /dev/null @@ -1,20 +0,0 @@ ---- -orphan: true ---- - -# Read the Docs - -This project uses [Read the Docs](https://readthedocs.org/) to host its documentation. The site is automatically rebuilt whenever changes are pushed to the repository. - -Our build configuration is defined in `.readthedocs.yml`. It installs the package with the `docs` dependency group and builds the Sphinx docs using Python 3.11. - -## Building Locally - -To preview the documentation on your machine, run: - -```bash -cd docs -make html -``` - -Open `build/html/index.html` in your browser to view the result. 
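The `generate(model="cphm", ...)` call in the usage example above reduces to inverse-transform sampling from the CPHM hazard described on the theory page. A minimal sketch, assuming a constant baseline hazard; `lam` and the uniform censoring choice are illustrative here, and the packaged `gen_cphm` generator may be parameterized differently:

```python
# Sketch of CPHM-style simulation via inverse-transform sampling.
# `lam` (constant baseline hazard) is an assumption for illustration,
# not necessarily how the packaged generator is parameterized.
import numpy as np
import pandas as pd

rng = np.random.default_rng(42)
n, beta, covar, lam, cens_par = 100, 0.5, 2.0, 1.0, 1.0

z = rng.uniform(0, covar, size=n)           # covariate sampled on [0, covar]
u = rng.uniform(size=n)
t = -np.log(u) / (lam * np.exp(beta * z))   # solves S(t|z) = exp(-lam*exp(beta*z)*t) = u
c = rng.uniform(0, cens_par, size=n)        # uniform censoring times

df = pd.DataFrame({
    "time": np.minimum(t, c),               # observed follow-up time
    "status": (t <= c).astype(int),         # 1 = event, 0 = censored
    "covariate": z,
})
print(df.head())
```

With a constant baseline hazard the survival function inverts in closed form, which is why no numerical root-finding is needed.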
diff --git a/docs/source/theory.md b/docs/source/theory.md index 1101034..957dce1 100644 --- a/docs/source/theory.md +++ b/docs/source/theory.md @@ -1,7 +1,3 @@ ---- -orphan: true ---- - # πŸ“˜ Mathematical Foundations of `gen_surv` This page presents the mathematical formulation behind the survival models implemented in the `gen_surv` package. @@ -10,9 +6,7 @@ This page presents the mathematical formulation behind the survival models imple ## 1. Cox Proportional Hazards Model (CPHM) -This semi-parametric approach models the hazard as a baseline component -multiplied by an exponential term involving the covariates. The hazard -function conditioned on covariates is: +The hazard function conditioned on covariates is: $$ h(t \mid X) = h_0(t) \exp(X \\beta) @@ -45,8 +39,7 @@ $$ ## 2. Time-Dependent Covariate Model (TDCM) -This extension of the Cox model allows covariate values to vary during -follow-up, accommodating exposures or treatments that change over time: +A generalization of CPHM where covariates change over time: $$ h(t \mid Z(t)) = h_0(t) \\exp(Z(t) \\beta) @@ -58,8 +51,7 @@ In this package, piecewise covariate values are simulated with dependence across ## 3. Continuous-Time Multi-State Markov Model (CMM) -This framework captures transitions between a finite set of states where -waiting times are exponentially distributed. With generator matrix \( Q \), the transition probability matrix is given by: +Markov model with generator matrix \( Q \). The transition probability matrix is given by: $$ P(t) = \\exp(Qt) @@ -74,9 +66,7 @@ Where: ## 4. Time-Homogeneous Hidden Markov Model (THMM) -This model handles situations where the process evolves through unobserved -states that generate the observed responses. It simulates observed states with -latent transitions. +This model simulates observed states with unobserved latent state transitions. Let: @@ -94,9 +84,7 @@ $$ ## 5. Accelerated Failure Time (AFT) Models -These fully parametric models relate covariates to the logarithm of the -survival time. They assume the effect of a covariate speeds up or slows down the -event time directly, rather than acting on the hazard. +AFT models assume that the effect of covariates accelerates or decelerates time to event directly, rather than the hazard. ### Log-Normal AFT @@ -128,25 +116,5 @@ This model is especially useful when the proportional hazards assumption is not All models support censoring: -- **Uniform:** \( C_i \sim U(0, \text{cens\_par}) \) -- **Exponential:** \( C_i \sim \text{Exp}(\text{cens\_par}) \) - -## 6. Competing Risks Models - -These models handle scenarios where several distinct failure types can occur. -Each cause has its own hazard function, and the observed status indicates which -event occurred (1, 2, ...). The package includes constant-hazard and -Weibull-hazard versions. - -## 7. Mixture Cure Models - -These models posit that a subset of the population is cured and will never -experience the event of interest. The generator mixes a logistic cure component -with an exponential hazard for the uncured, returning a ``cured`` indicator -column alongside the usual time and status. - -## 8. Piecewise Exponential Model - -Here the baseline hazard is assumed constant within each of several -user-specified intervals. This allows flexible hazard shapes over time while -remaining easy to simulate. 
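To make the "easy to simulate" claim concrete: with a hazard constant at \( \lambda_k \) on the k-th interval, the cumulative hazard \( H(t) \) is piecewise linear, so an event time can be drawn by sampling \( H(T) \sim \text{Exp}(1) \) and inverting. A self-contained sketch, with illustrative interval edges and rates rather than anything taken from the package:

```python
# Sketch: draw one event time from a piecewise-constant hazard by
# inverting the piecewise-linear cumulative hazard H(t).
# The breakpoints and rates below are illustrative values only.
import numpy as np

def sample_piecewise(rng, breakpoints, rates):
    """Invert H(T) ~ Exp(1) for rates constant on [0, b1), [b1, b2), ..."""
    target = rng.exponential()              # cumulative hazard at the event
    edges = [0.0, *breakpoints]
    cum = 0.0
    for k, rate in enumerate(rates):
        width = edges[k + 1] - edges[k] if k + 1 < len(edges) else np.inf
        if cum + rate * width >= target:    # event falls inside interval k
            return edges[k] + (target - cum) / rate
        cum += rate * width
    raise AssertionError("unreachable when the final rate is positive")

rng = np.random.default_rng(0)
print(sample_piecewise(rng, breakpoints=[1.0, 3.0], rates=[0.5, 0.2, 0.1]))
```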
+- **Uniform:** \( C_i \\sim U(0, \\text{cens\\_par}) \) +- **Exponential:** \( C_i \\sim \\text{Exp}(\\text{cens\\_par}) \) diff --git a/docs/source/tutorials/basic_usage.md b/docs/source/tutorials/basic_usage.md deleted file mode 100644 index e99fe2d..0000000 --- a/docs/source/tutorials/basic_usage.md +++ /dev/null @@ -1,110 +0,0 @@ -# Basic Usage Tutorial - -This tutorial covers the fundamentals of generating survival data with gen_surv. - -## Your First Dataset - -Let's start with the simplest case - generating data from a Cox proportional hazards model: - -```python -from gen_surv import generate -import pandas as pd - -# Generate basic CPHM data - df = generate( - model="cphm", - n=200, - beta=0.7, - covariate_range=1.5, - model_cens="exponential", - cens_par=2.0, - seed=42 # For reproducibility - ) - -print(f"Dataset shape: {df.shape}") -print(f"Event rate: {df['status'].mean():.2%}") -print("\nFirst 5 rows:") -print(df.head()) -``` - -## Understanding Parameters - -### Common Parameters - -All models share these parameters: - -- `n`: Sample size (number of individuals) -- `model_cens`: Censoring type ("uniform" or "exponential") -- `cens_par`: Censoring distribution parameter -- `seed`: Random seed for reproducibility - -### Model-Specific Parameters - -Each model has unique parameters. For CPHM: - -- `beta`: Covariate effect (hazard ratio = exp(beta)) -- `covariate_range`: Range for uniform covariate generation [0, covariate_range] - -## Censoring Mechanisms - -gen_surv supports two censoring types: - -### Uniform Censoring -```python -# Censoring times uniformly distributed on [0, cens_par] -df_uniform = generate( - model="cphm", - n=100, - beta=0.5, - covariate_range=2.0, - model_cens="uniform", - cens_par=3.0 -) -``` - -### Exponential Censoring -```python -# Censoring times exponentially distributed with mean cens_par -df_exponential = generate( - model="cphm", - n=100, - beta=0.5, - covariate_range=2.0, - model_cens="exponential", - cens_par=2.0 -) -``` - -## Exploring Your Data - -Basic data exploration: - -```python -import matplotlib.pyplot as plt -import seaborn as sns - -# Event rate by covariate level -fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4)) - -# Histogram of survival times -ax1.hist(df['time'], bins=20, alpha=0.7, edgecolor='black') -ax1.set_xlabel('Time') -ax1.set_ylabel('Frequency') -ax1.set_title('Distribution of Observed Times') - -# Event rate vs covariate -df['covariate_bin'] = pd.cut(df['covariate'], bins=5) -event_rate = df.groupby('covariate_bin')['status'].mean() -event_rate.plot(kind='bar', ax=ax2, rot=45) -ax2.set_ylabel('Event Rate') -ax2.set_title('Event Rate by Covariate Level') - -plt.tight_layout() -plt.show() -``` - -## Next Steps - -- Try different models (model_comparison) -- Learn advanced features (advanced_features) -- See integration examples (integration_examples) diff --git a/docs/source/tutorials/index.md b/docs/source/tutorials/index.md deleted file mode 100644 index bc3af37..0000000 --- a/docs/source/tutorials/index.md +++ /dev/null @@ -1,13 +0,0 @@ ---- -orphan: true ---- - -# Tutorials - -Step-by-step guides for using gen_surv effectively. - -```{toctree} -:maxdepth: 2 - -basic_usage -``` diff --git a/docs/source/usage.md b/docs/source/usage.md index 69afc64..e3f856e 100644 --- a/docs/source/usage.md +++ b/docs/source/usage.md @@ -1,7 +1,3 @@ ---- -orphan: true ---- - # Getting Started This page offers a quick introduction to installing and using **gen_surv**. 
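Both censoring mechanisms listed on the theory page correspond to small helpers in `gen_surv.censoring` (`runifcens` and `rexpocens`). A brief illustration of how a censoring draw turns latent event times into the observed `(time, status)` pairs every generator returns; the latent times here come from a plain exponential draw for demonstration only:

```python
import numpy as np
from gen_surv.censoring import rexpocens, runifcens

np.random.seed(0)                              # these helpers use np.random
t = np.random.exponential(scale=2.0, size=5)   # latent event times (demo only)

c_unif = runifcens(5, cens_par=3.0)            # C_i ~ U(0, cens_par)
c_expo = rexpocens(5, cens_par=3.0)            # C_i ~ Exp(mean = cens_par)
print(c_expo)                                  # alternative censoring draw

time = np.minimum(t, c_unif)                   # observed follow-up time
status = (t <= c_unif).astype(int)             # 1 = event observed, 0 = censored
print(time, status)
```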
@@ -21,20 +17,10 @@ This will create a virtual environment and install all required packages. Generate datasets directly in Python: ```python -from gen_surv import export_dataset, generate +from gen_surv import generate # Cox Proportional Hazards example -df = generate( - model="cphm", - n=100, - model_cens="uniform", - cens_par=1.0, - beta=0.5, - covariate_range=2.0, -) - -# Save to RDS for use in R -export_dataset(df, "simulated_data.rds") +generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0) ``` You can also generate data from the command line: @@ -45,45 +31,3 @@ python -m gen_surv dataset aft_ln --n 100 > data.csv For a full description of available models and parameters, see the API reference. - -## Building the Documentation - -Documentation is written using [Sphinx](https://www.sphinx-doc.org). To build the HTML pages locally run: - -```bash -cd docs -make html -``` - -The generated files will be available under `docs/build/html`. - -## Scikit-learn Integration - -You can wrap the generator in a transformer compatible with scikit-learn: - -```python -from gen_surv import GenSurvDataGenerator - -est = GenSurvDataGenerator("cphm", n=10, beta=0.5, covariate_range=1.0) -df = est.fit_transform() -``` - -## Lifelines and scikit-survival - -Datasets generated with **gen_surv** can be directly used with -[lifelines](https://lifelines.readthedocs.io). For -[scikit-survival](https://scikit-survival.readthedocs.io) you can convert the -DataFrame using ``to_sksurv``: - -```{note} -The ``to_sksurv`` helper requires the optional dependency -``scikit-survival``. Install it with `poetry install --with dev` or -``pip install scikit-survival``. -``` - -```python -from gen_surv import to_sksurv - -struct = to_sksurv(df) -``` - diff --git a/examples/notebooks/cmm.ipynb b/examples/notebooks/cmm.ipynb deleted file mode 100644 index 303b3ff..0000000 --- a/examples/notebooks/cmm.ipynb +++ /dev/null @@ -1,45 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "metadata": {}, - "source": [ - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "from gen_surv import generate\n", - "\n", - "np.random.seed(0)\n", - "df = generate(\n", - " model=\"cmm\",\n", - " n=200,\n", - " model_cens=\"exponential\",\n", - " cens_par=2.0,\n", - " beta=[0.1, 0.2, 0.3],\n", - " covariate_range=1.0,\n", - " rate=[0.1, 1.0, 0.2, 1.0, 0.1, 1.0],\n", - ")\n", - "df.head()\n", - "plt.hist(df['stop'], bins=20, color='#4C72B0')\n", - "plt.xlabel('Time')\n", - "plt.ylabel('Frequency')\n", - "plt.title('CMM Transition Times')\n", - "plt.show()\n" - ], - "outputs": [], - "execution_count": null - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/examples/notebooks/tdcm.ipynb b/examples/notebooks/tdcm.ipynb deleted file mode 100644 index 0f9a91f..0000000 --- a/examples/notebooks/tdcm.ipynb +++ /dev/null @@ -1,47 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "metadata": {}, - "source": [ - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "from gen_surv import generate\n", - "\n", - "np.random.seed(0)\n", - "df = generate(\n", - " model=\"tdcm\",\n", - " n=200,\n", - " dist=\"weibull\",\n", - " corr=0.5,\n", - " dist_par=[1, 2, 1, 2],\n", - " model_cens=\"uniform\",\n", - " cens_par=1.0,\n", - " beta=[0.1, 0.2, 0.3],\n", - " lam=1.0,\n", - ")\n", - "df.head()\n", - 
"plt.hist(df['stop'], bins=20, color='#4C72B0')\n", - "plt.xlabel('Time')\n", - "plt.ylabel('Frequency')\n", - "plt.title('TDCM Event Times')\n", - "plt.show()\n" - ], - "outputs": [], - "execution_count": null - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/examples/notebooks/thmm.ipynb b/examples/notebooks/thmm.ipynb deleted file mode 100644 index c46463c..0000000 --- a/examples/notebooks/thmm.ipynb +++ /dev/null @@ -1,45 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "metadata": {}, - "source": [ - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "from gen_surv import generate\n", - "\n", - "np.random.seed(0)\n", - "df = generate(\n", - " model=\"thmm\",\n", - " n=200,\n", - " model_cens=\"exponential\",\n", - " cens_par=3.0,\n", - " beta=[0.1, 0.2, 0.3],\n", - " covariate_range=1.0,\n", - " rate=[0.2, 0.1, 0.3],\n", - ")\n", - "df.head()\n", - "plt.hist(df['time'], bins=20, color='#4C72B0')\n", - "plt.xlabel('Time')\n", - "plt.ylabel('Frequency')\n", - "plt.title('THMM Event Times')\n", - "plt.show()\n" - ], - "outputs": [], - "execution_count": null - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/examples/run_aft.py b/examples/run_aft.py index c2be370..5ba6388 100644 --- a/examples/run_aft.py +++ b/examples/run_aft.py @@ -1,7 +1,6 @@ -import os import sys - -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +import os +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) from gen_surv.interface import generate # Generate synthetic survival data using Log-Normal AFT model @@ -12,7 +11,7 @@ sigma=1.0, model_cens="exponential", cens_par=3.0, - seed=123, + seed=123 ) print(df.head()) diff --git a/examples/run_aft_weibull.py b/examples/run_aft_weibull.py deleted file mode 100644 index b213472..0000000 --- a/examples/run_aft_weibull.py +++ /dev/null @@ -1,94 +0,0 @@ -""" -Example demonstrating Weibull AFT model and visualization capabilities. -""" - -import os -import sys - -import matplotlib.pyplot as plt -import pandas as pd - -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) - -from gen_surv import generate -from gen_surv.visualization import ( - describe_survival, - plot_covariate_effect, - plot_hazard_comparison, - plot_survival_curve, -) - -# 1. Generate data from different models for comparison -models = { - "Weibull AFT (shape=0.5)": generate( - model="aft_weibull", - n=200, - beta=[0.5, -0.3], - shape=0.5, # Decreasing hazard - scale=2.0, - model_cens="uniform", - cens_par=5.0, - seed=42, - ), - "Weibull AFT (shape=1.0)": generate( - model="aft_weibull", - n=200, - beta=[0.5, -0.3], - shape=1.0, # Constant hazard - scale=2.0, - model_cens="uniform", - cens_par=5.0, - seed=42, - ), - "Weibull AFT (shape=2.0)": generate( - model="aft_weibull", - n=200, - beta=[0.5, -0.3], - shape=2.0, # Increasing hazard - scale=2.0, - model_cens="uniform", - cens_par=5.0, - seed=42, - ), -} - -# Print sample data -print("Sample data from Weibull AFT model (shape=2.0):") -print(models["Weibull AFT (shape=2.0)"].head()) -print("\n") - -# 2. 
Compare survival curves from different models -fig1, ax1 = plot_survival_curve( - data=pd.concat([df.assign(_model=name) for name, df in models.items()]), - group_col="_model", - title="Comparing Survival Curves with Different Weibull Shapes", -) -plt.savefig("survival_curve_comparison.png", dpi=300, bbox_inches="tight") - -# 3. Compare hazard functions -fig2, ax2 = plot_hazard_comparison( - models=models, title="Comparing Hazard Functions with Different Weibull Shapes" -) -plt.savefig("hazard_comparison.png", dpi=300, bbox_inches="tight") - -# 4. Visualize covariate effect on survival -fig3, ax3 = plot_covariate_effect( - data=models["Weibull AFT (shape=2.0)"], - covariate_col="X0", - n_groups=3, - title="Effect of X0 Covariate on Survival", -) -plt.savefig("covariate_effect.png", dpi=300, bbox_inches="tight") - -# 5. Summary statistics -for name, df in models.items(): - print(f"Summary for {name}:") - summary = describe_survival(df) - print(summary) - print("\n") - -print("Plots saved to current directory.") - -# Show plots if running interactively -if __name__ == "__main__": - plt.show() diff --git a/examples/run_cmm.py b/examples/run_cmm.py index 52a11ff..590b7a1 100644 --- a/examples/run_cmm.py +++ b/examples/run_cmm.py @@ -1,7 +1,6 @@ -import os import sys - -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +import os +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) from gen_surv import generate @@ -12,7 +11,7 @@ cens_par=2.0, qmat=[[0, 0.1], [0.05, 0]], p0=[1.0, 0.0], - seed=42, + seed=42 ) print(df.head()) diff --git a/examples/run_competing_risks.py b/examples/run_competing_risks.py deleted file mode 100644 index a22871a..0000000 --- a/examples/run_competing_risks.py +++ /dev/null @@ -1,143 +0,0 @@ -""" -Example demonstrating the Competing Risks models and visualization. 
-""" - -import os -import sys - -import matplotlib.pyplot as plt -import numpy as np - -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) - -from gen_surv import generate -from gen_surv.competing_risks import ( - cause_specific_cumulative_incidence, - gen_competing_risks, - gen_competing_risks_weibull, -) -from gen_surv.summary import compare_survival_datasets, summarize_survival_dataset - - -def plot_cause_specific_cumulative_incidence(df, time_points=None, figsize=(10, 6)): - """Plot the cause-specific cumulative incidence functions.""" - if time_points is None: - max_time = df["time"].max() - time_points = np.linspace(0, max_time, 100) - - # Get unique causes (excluding censoring) - causes = sorted([c for c in df["status"].unique() if c > 0]) - - # Create the plot - fig, ax = plt.subplots(figsize=figsize) - - for cause in causes: - cif = cause_specific_cumulative_incidence(df, time_points, cause=cause) - ax.plot(cif["time"], cif["incidence"], label=f"Cause {cause}") - - # Add overlay showing number of subjects at each time - time_bins = np.linspace(0, df["time"].max(), 10) - event_counts = np.histogram(df.loc[df["status"] > 0, "time"], bins=time_bins)[0] - - # Add a secondary y-axis for event counts - ax2 = ax.twinx() - ax2.bar( - time_bins[:-1], - event_counts, - width=time_bins[1] - time_bins[0], - alpha=0.2, - color="gray", - align="edge", - ) - ax2.set_ylabel("Number of events") - ax2.grid(False) - - # Format the main plot - ax.set_xlabel("Time") - ax.set_ylabel("Cumulative Incidence") - ax.set_title("Cause-Specific Cumulative Incidence Functions") - ax.legend() - ax.grid(alpha=0.3) - - return fig, ax - - -# 1. Generate data with 2 competing risks -print("Generating data with exponential hazards...") -data_exponential = gen_competing_risks( - n=500, - n_risks=2, - baseline_hazards=[0.5, 0.3], - betas=[[0.8, -0.5], [0.2, 0.7]], - model_cens="uniform", - cens_par=2.0, - seed=42, -) - -# 2. Generate data with Weibull hazards (different shapes) -print("Generating data with Weibull hazards...") -data_weibull = gen_competing_risks_weibull( - n=500, - n_risks=2, - shape_params=[0.8, 1.5], # Decreasing vs increasing hazard - scale_params=[2.0, 3.0], - betas=[[0.8, -0.5], [0.2, 0.7]], - model_cens="uniform", - cens_par=2.0, - seed=42, -) - -# 3. Print summary statistics for both datasets -print("\nSummary of Exponential Hazards dataset:") -summarize_survival_dataset(data_exponential) - -print("\nSummary of Weibull Hazards dataset:") -summarize_survival_dataset(data_weibull) - -# 4. Compare event distributions -print("\nEvent distribution (Exponential Hazards):") -print(data_exponential["status"].value_counts()) - -print("\nEvent distribution (Weibull Hazards):") -print(data_weibull["status"].value_counts()) - -# 5. Plot cause-specific cumulative incidence functions -print("\nPlotting cumulative incidence functions...") -time_points = np.linspace(0, 5, 100) - -fig1, ax1 = plot_cause_specific_cumulative_incidence( - data_exponential, time_points=time_points, figsize=(10, 6) -) -plt.title("Cumulative Incidence Functions (Exponential Hazards)") -plt.savefig("cr_exponential_cif.png", dpi=300, bbox_inches="tight") - -fig2, ax2 = plot_cause_specific_cumulative_incidence( - data_weibull, time_points=time_points, figsize=(10, 6) -) -plt.title("Cumulative Incidence Functions (Weibull Hazards)") -plt.savefig("cr_weibull_cif.png", dpi=300, bbox_inches="tight") - -# 6. 
Demonstrate using the unified generate() interface -print("\nUsing the unified generate() interface:") -data_unified = generate( - model="competing_risks", - n=100, - n_risks=2, - baseline_hazards=[0.5, 0.3], - betas=[[0.8, -0.5], [0.2, 0.7]], - model_cens="uniform", - cens_par=2.0, - seed=42, -) -print(data_unified.head()) - -# 7. Compare datasets -print("\nComparing datasets:") -comparison = compare_survival_datasets( - {"Exponential": data_exponential, "Weibull": data_weibull} -) -print(comparison) - -# Show plots if running interactively -if __name__ == "__main__": - plt.show() diff --git a/examples/run_cphm.py b/examples/run_cphm.py index ebcc138..c02b01b 100644 --- a/examples/run_cphm.py +++ b/examples/run_cphm.py @@ -1,7 +1,6 @@ -import os import sys - -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +import os +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) from gen_surv import generate @@ -11,8 +10,8 @@ model_cens="uniform", cens_par=1.0, beta=0.5, - covariate_range=2.0, - seed=42, + covar=2.0, + seed=42 ) print(df.head()) diff --git a/examples/run_tdcm.py b/examples/run_tdcm.py index c05ccf9..dd5204c 100644 --- a/examples/run_tdcm.py +++ b/examples/run_tdcm.py @@ -1,7 +1,6 @@ -import os import sys - -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +import os +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) from gen_surv import generate @@ -15,7 +14,7 @@ cens_par=1.0, beta=[0.1, 0.2, 0.3], lam=1.0, - seed=42, + seed=42 ) print(df.head()) diff --git a/examples/run_thmm.py b/examples/run_thmm.py index 038699d..73721ad 100644 --- a/examples/run_thmm.py +++ b/examples/run_thmm.py @@ -1,7 +1,6 @@ -import os import sys - -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +import os +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) from gen_surv import generate @@ -13,7 +12,7 @@ p0=[1.0, 0.0, 0.0], model_cens="exponential", cens_par=3.0, - seed=42, + seed=42 ) print(df.head()) diff --git a/gen_surv/__init__.py b/gen_surv/__init__.py index 65f9723..8939886 100644 --- a/gen_surv/__init__.py +++ b/gen_surv/__init__.py @@ -1,97 +1,17 @@ """Top-level package for ``gen_surv``. -This module exposes the main functions and provides access to the package version. +This module exposes the :func:`generate` function and provides access to the +package version via ``__version__``. 
""" from importlib.metadata import PackageNotFoundError, version -from .aft import gen_aft_log_logistic, gen_aft_log_normal, gen_aft_weibull - -# Helper functions -from .bivariate import sample_bivariate_distribution -from .censoring import ( - CensoringModel, - GammaCensoring, - LogNormalCensoring, - WeibullCensoring, - rexpocens, - rgammacens, - rlognormcens, - runifcens, - rweibcens, -) -from .cmm import gen_cmm -from .competing_risks import gen_competing_risks, gen_competing_risks_weibull - -# Individual generators -from .cphm import gen_cphm -from .export import export_dataset -from .integration import to_sksurv - -# Main interface from .interface import generate -from .mixture import cure_fraction_estimate, gen_mixture_cure -from .piecewise import gen_piecewise_exponential -from .sklearn_adapter import GenSurvDataGenerator -from .tdcm import gen_tdcm -from .thmm import gen_thmm - -# Visualization tools (requires matplotlib and lifelines) -try: - from .visualization import describe_survival # noqa: F401 - from .visualization import plot_covariate_effect # noqa: F401 - from .visualization import plot_hazard_comparison # noqa: F401 - from .visualization import plot_survival_curve # noqa: F401 - - _has_visualization = True -except ImportError: - _has_visualization = False try: __version__ = version("gen_surv") except PackageNotFoundError: # pragma: no cover - fallback when package not installed __version__ = "0.0.0" -__all__ = [ - # Main interface - "generate", - "__version__", - # Individual generators - "gen_cphm", - "gen_cmm", - "gen_tdcm", - "gen_thmm", - "gen_aft_log_normal", - "gen_aft_weibull", - "gen_aft_log_logistic", - "gen_competing_risks", - "gen_competing_risks_weibull", - "gen_mixture_cure", - "cure_fraction_estimate", - "gen_piecewise_exponential", - # Helpers - "sample_bivariate_distribution", - "runifcens", - "rexpocens", - "rweibcens", - "rlognormcens", - "rgammacens", - "WeibullCensoring", - "LogNormalCensoring", - "GammaCensoring", - "CensoringModel", - "export_dataset", - "to_sksurv", - "GenSurvDataGenerator", -] +__all__ = ["generate", "__version__"] -# Add visualization tools to __all__ if available -if _has_visualization: - __all__.extend( - [ - "plot_survival_curve", - "plot_hazard_comparison", - "plot_covariate_effect", - "describe_survival", - ] - ) diff --git a/gen_surv/_covariates.py b/gen_surv/_covariates.py deleted file mode 100644 index 69b7b81..0000000 --- a/gen_surv/_covariates.py +++ /dev/null @@ -1,77 +0,0 @@ -"""Utilities for generating covariate matrices with validation.""" - -from typing import Literal - -import numpy as np -from numpy.random import Generator -from numpy.typing import NDArray - -from .validation import ParameterError, ensure_positive - -_CovParams = dict[str, float] - - -def set_covariate_params( - covariate_dist: Literal["normal", "uniform", "binary"], - covariate_params: _CovParams | None, -) -> _CovParams: - """Return covariate distribution parameters with defaults filled in.""" - if covariate_params is not None: - return covariate_params - if covariate_dist == "normal": - return {"mean": 0.0, "std": 1.0} - if covariate_dist == "uniform": - return {"low": 0.0, "high": 1.0} - if covariate_dist == "binary": - return {"p": 0.5} - raise ParameterError( - "covariate_dist", - covariate_dist, - "unsupported covariate distribution; choose from 'normal', 'uniform', or 'binary'", - ) - - -def _get_float(params: _CovParams, key: str, default: float) -> float: - val = params.get(key, default) - if not isinstance(val, (int, float)): - raise 
ParameterError(f"covariate_params['{key}']", val, "must be a number") - return float(val) - - -def generate_covariates( - n: int, - n_covariates: int, - covariate_dist: Literal["normal", "uniform", "binary"], - covariate_params: _CovParams, - rng: Generator, -) -> NDArray[np.float64]: - """Generate covariate matrix according to the specified distribution.""" - if covariate_dist == "normal": - std = _get_float(covariate_params, "std", 1.0) - ensure_positive(std, "covariate_params['std']") - mean = _get_float(covariate_params, "mean", 0.0) - return rng.normal(mean, std, size=(n, n_covariates)) - if covariate_dist == "uniform": - low = _get_float(covariate_params, "low", 0.0) - high = _get_float(covariate_params, "high", 1.0) - if high <= low: - raise ParameterError( - "covariate_params['high']", - high, - "must be greater than 'low'", - ) - return rng.uniform(low, high, size=(n, n_covariates)) - if covariate_dist == "binary": - p = _get_float(covariate_params, "p", 0.5) - if not 0 <= p <= 1: - raise ParameterError( - "covariate_params['p']", - p, - "must be between 0 and 1", - ) - return rng.binomial(1, p, size=(n, n_covariates)).astype(float) - raise ParameterError( - "covariate_dist", - covariate_dist, - "unsupported covariate distribution; choose from 'normal', 'uniform', or 'binary'", - ) diff --git a/gen_surv/aft.py b/gen_surv/aft.py index afbf0a3..5c85fb9 100644 --- a/gen_surv/aft.py +++ b/gen_surv/aft.py @@ -1,212 +1,46 @@ -""" -Accelerated Failure Time (AFT) models including Weibull, Log-Normal, and Log-Logistic distributions. -""" - -from typing import List, Literal, Optional - import numpy as np import pandas as pd -from .censoring import rexpocens, runifcens -from .validation import ensure_censoring_model, ensure_positive - -def gen_aft_log_normal( - n: int, - beta: List[float], - sigma: float, - model_cens: Literal["uniform", "exponential"], - cens_par: float, - seed: Optional[int] = None, -) -> pd.DataFrame: +def gen_aft_log_normal(n, beta, sigma, model_cens, cens_par, seed=None): """ Simulate survival data under a Log-Normal Accelerated Failure Time (AFT) model. 
- Parameters - ---------- - n : int - Number of individuals - beta : list of float - Coefficients for covariates - sigma : float - Standard deviation of the log-error term - model_cens : {"uniform", "exponential"} - Censoring mechanism - cens_par : float - Parameter for censoring distribution - seed : int, optional - Random seed for reproducibility + Parameters: + - n (int): Number of individuals + - beta (list of float): Coefficients for covariates + - sigma (float): Standard deviation of the log-error term + - model_cens (str): 'uniform' or 'exponential' + - cens_par (float): Parameter for censoring distribution + - seed (int, optional): Random seed - Returns - ------- - pd.DataFrame - DataFrame with columns ['id', 'time', 'status', 'X0', ..., 'Xp'] + Returns: + - pd.DataFrame: DataFrame with columns ['id', 'time', 'status', 'X0', ..., 'Xp'] """ - rng = np.random.default_rng(seed) + if seed is not None: + np.random.seed(seed) p = len(beta) - X = rng.normal(size=(n, p)) - epsilon = rng.normal(loc=0.0, scale=sigma, size=n) + X = np.random.normal(size=(n, p)) + epsilon = np.random.normal(loc=0.0, scale=sigma, size=n) log_T = X @ np.array(beta) + epsilon T = np.exp(log_T) - ensure_censoring_model(model_cens) - rfunc = runifcens if model_cens == "uniform" else rexpocens - C = rfunc(n, cens_par, rng) - - observed_time = np.minimum(T, C) - status = (T <= C).astype(int) - - data = pd.DataFrame({"id": np.arange(n), "time": observed_time, "status": status}) - - for j in range(p): - data[f"X{j}"] = X[:, j] - - return data - - -def gen_aft_weibull( - n: int, - beta: List[float], - shape: float, - scale: float, - model_cens: Literal["uniform", "exponential"], - cens_par: float, - seed: Optional[int] = None, -) -> pd.DataFrame: - """ - Simulate survival data under a Weibull Accelerated Failure Time (AFT) model. - - The Weibull AFT model has survival function: - S(t|X) = exp(-(t/scale)^shape * exp(-X*beta)) - - Parameters - ---------- - n : int - Number of individuals - beta : list of float - Coefficients for covariates - shape : float - Weibull shape parameter (k > 0) - scale : float - Weibull scale parameter (Ξ» > 0) - model_cens : {"uniform", "exponential"} - Censoring mechanism - cens_par : float - Parameter for censoring distribution - seed : int, optional - Random seed for reproducibility - - Returns - ------- - pd.DataFrame - DataFrame with columns ['id', 'time', 'status', 'X0', ..., 'Xp'] - """ - rng = np.random.default_rng(seed) - - ensure_positive(shape, "shape") - ensure_positive(scale, "scale") - - p = len(beta) - X = rng.normal(size=(n, p)) - - # Linear predictor - eta = X @ np.array(beta) - - # Generate Weibull survival times - U = rng.uniform(size=n) - T = scale * (-np.log(U) * np.exp(-eta)) ** (1 / shape) - - # Generate censoring times - ensure_censoring_model(model_cens) - rfunc = runifcens if model_cens == "uniform" else rexpocens - C = rfunc(n, cens_par, rng) - - # Observed time is the minimum of event time and censoring time - observed_time = np.minimum(T, C) - status = (T <= C).astype(int) - - data = pd.DataFrame({"id": np.arange(n), "time": observed_time, "status": status}) - - for j in range(p): - data[f"X{j}"] = X[:, j] - - return data - - -def gen_aft_log_logistic( - n: int, - beta: List[float], - shape: float, - scale: float, - model_cens: Literal["uniform", "exponential"], - cens_par: float, - seed: Optional[int] = None, -) -> pd.DataFrame: - """ - Simulate survival data under a Log-Logistic Accelerated Failure Time (AFT) model. 
- - The Log-Logistic AFT model has survival function: - S(t|X) = 1 / (1 + (t/scale)^shape * exp(X*beta)) - - Log-logistic distribution is useful when the hazard rate first increases and then decreases. - - Parameters - ---------- - n : int - Number of individuals - beta : list of float - Coefficients for covariates - shape : float - Log-logistic shape parameter (Ξ± > 0) - scale : float - Log-logistic scale parameter (Ξ² > 0) - model_cens : {"uniform", "exponential"} - Censoring mechanism - cens_par : float - Parameter for censoring distribution - seed : int, optional - Random seed for reproducibility - - Returns - ------- - pd.DataFrame - DataFrame with columns ['id', 'time', 'status', 'X0', ..., 'Xp'] - """ - rng = np.random.default_rng(seed) - - ensure_positive(shape, "shape") - ensure_positive(scale, "scale") - - p = len(beta) - X = rng.normal(size=(n, p)) - - # Linear predictor - eta = X @ np.array(beta) - - # Generate Log-Logistic survival times - U = rng.uniform(size=n) - - # Inverse CDF method: S(t) = 1/(1 + (t/scale)^shape) - # so t = scale * (1/S - 1)^(1/shape) - # For random U ~ Uniform(0,1), we can use U as 1-S - # t = scale * (1/(1-U) - 1)^(1/shape) * exp(-eta/shape) - # simplifies to: t = scale * (U/(1-U))^(1/shape) * exp(-eta/shape) - - # Avoid numerical issues near 1 - U = np.clip(U, 0.001, 0.999) - T = scale * (U / (1 - U)) ** (1 / shape) * np.exp(-eta / shape) - - # Generate censoring times - ensure_censoring_model(model_cens) - rfunc = runifcens if model_cens == "uniform" else rexpocens - C = rfunc(n, cens_par, rng) + if model_cens == "uniform": + C = np.random.uniform(0, cens_par, size=n) + elif model_cens == "exponential": + C = np.random.exponential(scale=cens_par, size=n) + else: + raise ValueError("model_cens must be 'uniform' or 'exponential'") - # Observed time is the minimum of event time and censoring time observed_time = np.minimum(T, C) status = (T <= C).astype(int) - data = pd.DataFrame({"id": np.arange(n), "time": observed_time, "status": status}) + data = pd.DataFrame({ + "id": np.arange(n), + "time": observed_time, + "status": status + }) for j in range(p): data[f"X{j}"] = X[:, j] diff --git a/gen_surv/bivariate.py b/gen_surv/bivariate.py index 4fea539..070fbd0 100644 --- a/gen_surv/bivariate.py +++ b/gen_surv/bivariate.py @@ -1,59 +1,38 @@ -from typing import Sequence - import numpy as np -from numpy.typing import NDArray - -from .validate import validate_dg_biv_inputs - -_CHI2_SCALE = 0.5 -_CLIP_EPS = 1e-10 - -def sample_bivariate_distribution( - n: int, dist: str, corr: float, dist_par: Sequence[float] -) -> NDArray[np.float64]: - """Draw correlated samples from Weibull or exponential marginals. - - Parameters - ---------- - n : int - Number of samples to generate. - dist : {"weibull", "exponential"} - Type of marginal distributions. - corr : float - Correlation coefficient. - dist_par : Sequence[float] - Distribution parameters ``[a1, b1, a2, b2]`` for the Weibull case or - ``[lambda1, lambda2]`` for the exponential case. +def sample_bivariate_distribution(n, dist, corr, dist_par): + """ + Generate samples from a bivariate distribution with specified correlation. - Returns - ------- - NDArray[np.float64] - Array of shape ``(n, 2)`` with the sampled pairs. + Parameters: + - n (int): Number of samples + - dist (str): 'weibull' or 'exponential' + - corr (float): Correlation coefficient between [-1, 1] + - dist_par (list): Parameters for the marginals - Raises - ------ - ValidationError - If ``dist`` is unsupported or ``dist_par`` has an invalid length. 
+ Returns: + - np.ndarray of shape (n, 2) """ - - validate_dg_biv_inputs(n, dist, corr, dist_par) + if dist not in {"weibull", "exponential"}: + raise ValueError("Only 'weibull' and 'exponential' distributions are supported.") # Step 1: Generate correlated standard normals using Cholesky mean = [0, 0] cov = [[1, corr], [corr, 1]] z = np.random.multivariate_normal(mean, cov, size=n) - u = 1 - np.exp( - -_CHI2_SCALE * z**2 - ) # transform normals to uniform via chi-squared approx - u = np.clip(u, _CLIP_EPS, 1 - _CLIP_EPS) # avoid infs in tails + u = 1 - np.exp(-0.5 * z**2) # transform normals to uniform via chi-squared approx + u = np.clip(u, 1e-10, 1 - 1e-10) # avoid infs in tails # Step 2: Transform to marginals if dist == "exponential": + if len(dist_par) != 2: + raise ValueError("Exponential distribution requires 2 positive rate parameters.") x1 = -np.log(1 - u[:, 0]) / dist_par[0] x2 = -np.log(1 - u[:, 1]) / dist_par[1] - else: # dist == "weibull" + elif dist == "weibull": + if len(dist_par) != 4: + raise ValueError("Weibull distribution requires 4 positive parameters [a1, b1, a2, b2].") a1, b1, a2, b2 = dist_par x1 = (-np.log(1 - u[:, 0]) / a1) ** (1 / b1) x2 = (-np.log(1 - u[:, 1]) / a2) ** (1 / b2) diff --git a/gen_surv/censoring.py b/gen_surv/censoring.py index c93eee4..3228825 100644 --- a/gen_surv/censoring.py +++ b/gen_surv/censoring.py @@ -1,31 +1,6 @@ -from typing import Protocol - import numpy as np -from numpy.random import Generator, default_rng -from numpy.typing import NDArray - - -class CensoringFunc(Protocol): - """Protocol for censoring time generators.""" - - def __call__( - self, size: int, cens_par: float, rng: Generator | None = None - ) -> NDArray[np.float64]: - """Generate ``size`` censoring times given ``cens_par``.""" - ... - - -class CensoringModel(Protocol): - """Protocol for class-based censoring generators.""" - - def __call__(self, size: int) -> NDArray[np.float64]: - """Generate ``size`` censoring times.""" - ... - -def runifcens( - size: int, cens_par: float, rng: Generator | None = None -) -> NDArray[np.float64]: +def runifcens(size: int, cens_par: float) -> np.ndarray: """ Generate uniform censoring times. @@ -34,15 +9,11 @@ def runifcens( - cens_par (float): Upper bound for uniform distribution. Returns: - - NDArray of censoring times. + - np.ndarray of censoring times. """ - r = default_rng() if rng is None else rng - return r.uniform(0, cens_par, size) - + return np.random.uniform(0, cens_par, size) -def rexpocens( - size: int, cens_par: float, rng: Generator | None = None -) -> NDArray[np.float64]: +def rexpocens(size: int, cens_par: float) -> np.ndarray: """ Generate exponential censoring times. @@ -51,70 +22,6 @@ def rexpocens( - cens_par (float): Mean of exponential distribution. Returns: - - NDArray of censoring times. + - np.ndarray of censoring times. 
""" - r = default_rng() if rng is None else rng - return r.exponential(scale=cens_par, size=size) - - -def rweibcens( - size: int, scale: float, shape: float, rng: Generator | None = None -) -> NDArray[np.float64]: - """Generate Weibull-distributed censoring times.""" - r = default_rng() if rng is None else rng - return r.weibull(shape, size) * scale - - -def rlognormcens( - size: int, mean: float, sigma: float, rng: Generator | None = None -) -> NDArray[np.float64]: - """Generate log-normal-distributed censoring times.""" - r = default_rng() if rng is None else rng - return r.lognormal(mean, sigma, size) - - -def rgammacens( - size: int, shape: float, scale: float, rng: Generator | None = None -) -> NDArray[np.float64]: - """Generate Gamma-distributed censoring times.""" - r = default_rng() if rng is None else rng - return r.gamma(shape, scale, size) - - -class WeibullCensoring: - """Class-based generator for Weibull censoring times.""" - - def __init__(self, scale: float, shape: float) -> None: - self.scale = scale - self.shape = shape - - def __call__(self, size: int, rng: Generator | None = None) -> NDArray[np.float64]: - """Generate ``size`` censoring times from a Weibull distribution.""" - r = default_rng() if rng is None else rng - return r.weibull(self.shape, size) * self.scale - - -class LogNormalCensoring: - """Class-based generator for log-normal censoring times.""" - - def __init__(self, mean: float, sigma: float) -> None: - self.mean = mean - self.sigma = sigma - - def __call__(self, size: int, rng: Generator | None = None) -> NDArray[np.float64]: - """Generate ``size`` censoring times from a log-normal distribution.""" - r = default_rng() if rng is None else rng - return r.lognormal(self.mean, self.sigma, size) - - -class GammaCensoring: - """Class-based generator for Gamma censoring times.""" - - def __init__(self, shape: float, scale: float) -> None: - self.shape = shape - self.scale = scale - - def __call__(self, size: int, rng: Generator | None = None) -> NDArray[np.float64]: - """Generate ``size`` censoring times from a Gamma distribution.""" - r = default_rng() if rng is None else rng - return r.gamma(self.shape, self.scale, size) + return np.random.exponential(scale=cens_par, size=size) diff --git a/gen_surv/cli.py b/gen_surv/cli.py index a12c073..542ea51 100644 --- a/gen_surv/cli.py +++ b/gen_surv/cli.py @@ -1,233 +1,36 @@ -""" -Command-line interface for gen_surv. - -This module provides a command-line interface for generating survival data -using the gen_surv package. -""" - -from typing import Any, Dict, List, Optional, TypeVar, cast - +import csv +from typing import Optional import typer - from gen_surv.interface import generate app = typer.Typer(help="Generate synthetic survival datasets.") - @app.command() def dataset( model: str = typer.Argument( - ..., - help=( - "Model to simulate [cphm, cmm, tdcm, thmm, aft_ln, aft_weibull, aft_log_logistic, competing_risks, competing_risks_weibull, mixture_cure, piecewise_exponential]" - ), + ..., help="Model to simulate [cphm, cmm, tdcm, thmm, aft_ln]" ), n: int = typer.Option(100, help="Number of samples"), - model_cens: str = typer.Option( - "uniform", help="Censoring model: 'uniform' or 'exponential'" - ), - cens_par: float = typer.Option(1.0, help="Censoring parameter"), - beta: List[float] = typer.Option( - [0.5], - help="Regression coefficient(s). 
Provide multiple values for multi-parameter models.", - ), - covariate_range: Optional[float] = typer.Option( - 2.0, - "--covariate-range", - "--covar", - help="Upper bound for covariate values (for CPHM, CMM, THMM)", - ), - sigma: Optional[float] = typer.Option( - 1.0, help="Standard deviation parameter (for log-normal AFT)" - ), - shape: Optional[float] = typer.Option( - 1.5, help="Shape parameter (for Weibull AFT)" - ), - scale: Optional[float] = typer.Option( - 2.0, help="Scale parameter (for Weibull AFT)" - ), - n_risks: int = typer.Option(2, help="Number of competing risks"), - baseline_hazards: List[float] = typer.Option( - [], help="Baseline hazards for competing risks" - ), - shape_params: List[float] = typer.Option( - [], help="Shape parameters for Weibull competing risks" - ), - scale_params: List[float] = typer.Option( - [], help="Scale parameters for Weibull competing risks" - ), - cure_fraction: Optional[float] = typer.Option( - None, help="Cure fraction for mixture cure model" - ), - baseline_hazard: Optional[float] = typer.Option( - None, help="Baseline hazard for mixture cure model" - ), - breakpoints: List[float] = typer.Option( - [], help="Breakpoints for piecewise exponential model" - ), - hazard_rates: List[float] = typer.Option( - [], help="Hazard rates for piecewise exponential model" - ), - seed: Optional[int] = typer.Option(None, help="Random seed for reproducibility"), output: Optional[str] = typer.Option( None, "-o", help="Output CSV file. Prints to stdout if omitted." ), ) -> None: """Generate survival data and optionally save to CSV. - Examples: - # Generate data from CPHM model - $ gen_surv dataset cphm --n 100 --beta 0.5 --covariate-range 2.0 -o cphm_data.csv + Args: + model: Identifier of the generator to use. + n: Number of samples to create. + output: Optional path to save the CSV file. 
- # Generate data from Weibull AFT model - $ gen_surv dataset aft_weibull --n 200 --beta 0.5 --beta -0.3 --shape 1.5 --scale 2.0 -o aft_data.csv + Returns: + None """ - # Helper to unwrap Typer Option defaults when function is called directly - from typer.models import OptionInfo - - T = TypeVar("T") - - def _val(v: T | OptionInfo) -> T: - return v if not isinstance(v, OptionInfo) else cast(T, v.default) - - # Prepare arguments based on the selected model - model_str: str = _val(model) - kwargs: Dict[str, Any] = { - "model": model_str, - "n": _val(n), - "model_cens": _val(model_cens), - "cens_par": _val(cens_par), - "seed": _val(seed), - } - - # Add model-specific parameters - if model_str in ["cphm", "cmm", "thmm"]: - # These models use a single beta and covariate range - beta_values = cast(List[float], _val(beta)) - kwargs["beta"] = beta_values[0] if len(beta_values) > 0 else 0.5 - kwargs["covariate_range"] = _val(covariate_range) - - elif model_str == "aft_ln": - # Log-normal AFT model uses beta list and sigma - kwargs["beta"] = _val(beta) - kwargs["sigma"] = _val(sigma) - - elif model_str == "aft_weibull": - # Weibull AFT model uses beta list, shape, and scale - kwargs["beta"] = _val(beta) - kwargs["shape"] = _val(shape) - kwargs["scale"] = _val(scale) - - elif model_str == "aft_log_logistic": - kwargs["beta"] = _val(beta) - kwargs["shape"] = _val(shape) - kwargs["scale"] = _val(scale) - - elif model_str == "competing_risks": - kwargs["n_risks"] = _val(n_risks) - if _val(baseline_hazards): - kwargs["baseline_hazards"] = _val(baseline_hazards) - if _val(beta): - kwargs["betas"] = [_val(beta) for _ in range(_val(n_risks))] - - elif model_str == "competing_risks_weibull": - kwargs["n_risks"] = _val(n_risks) - if _val(shape_params): - kwargs["shape_params"] = _val(shape_params) - if _val(scale_params): - kwargs["scale_params"] = _val(scale_params) - if _val(beta): - kwargs["betas"] = [_val(beta) for _ in range(_val(n_risks))] - - elif model_str == "mixture_cure": - if _val(cure_fraction) is not None: - kwargs["cure_fraction"] = _val(cure_fraction) - if _val(baseline_hazard) is not None: - kwargs["baseline_hazard"] = _val(baseline_hazard) - kwargs["betas_survival"] = _val(beta) - kwargs["betas_cure"] = _val(beta) - - elif model_str == "piecewise_exponential": - kwargs["breakpoints"] = _val(breakpoints) - kwargs["hazard_rates"] = _val(hazard_rates) - kwargs["betas"] = _val(beta) - - # Generate the data - df = generate(**kwargs) - - # Output the data + df = generate(model=model, n=n) if output: df.to_csv(output, index=False) typer.echo(f"Saved dataset to {output}") else: typer.echo(df.to_csv(index=False)) - -@app.command() -def visualize( - input_file: str = typer.Argument( - ..., help="Input CSV file containing survival data" - ), - time_col: str = typer.Option("time", help="Column containing time/duration values"), - status_col: str = typer.Option( - "status", help="Column containing event indicator (1=event, 0=censored)" - ), - group_col: Optional[str] = typer.Option( - None, help="Column to use for stratification" - ), - output: str = typer.Option("survival_plot.png", help="Output image file"), -) -> None: - """Visualize survival data from a CSV file. 
- - Examples: - # Generate a Kaplan-Meier plot from a CSV file - $ gen_surv visualize data.csv --time-col time --status-col status -o km_plot.png - - # Generate a stratified plot using a grouping variable - $ gen_surv visualize data.csv --group-col X0 -o stratified_plot.png - """ - try: - import matplotlib.pyplot as plt - import pandas as pd - - from gen_surv.visualization import plot_survival_curve - except ImportError: - typer.echo( - "Error: Visualization requires matplotlib and lifelines. " - "Install them with: pip install matplotlib lifelines" - ) - raise typer.Exit(1) - - # Load the data - try: - data = pd.read_csv(input_file) - except Exception as e: - typer.echo(f"Error loading CSV file: {str(e)}") - raise typer.Exit(1) - - # Check required columns - if time_col not in data.columns: - typer.echo(f"Error: Time column '{time_col}' not found in data") - raise typer.Exit(1) - - if status_col not in data.columns: - typer.echo(f"Error: Status column '{status_col}' not found in data") - raise typer.Exit(1) - - if group_col is not None and group_col not in data.columns: - typer.echo(f"Error: Group column '{group_col}' not found in data") - raise typer.Exit(1) - - # Create the plot - fig, ax = plot_survival_curve( - data=data, time_col=time_col, status_col=status_col, group_col=group_col - ) - - # Save the plot - plt.savefig(output, dpi=300, bbox_inches="tight") - plt.close(fig) - typer.echo(f"Plot saved to {output}") - - if __name__ == "__main__": app() diff --git a/gen_surv/cmm.py b/gen_surv/cmm.py index 785ea6d..689f2db 100644 --- a/gen_surv/cmm.py +++ b/gen_surv/cmm.py @@ -1,21 +1,10 @@ -from typing import Sequence, TypedDict - -import numpy as np import pandas as pd - -from gen_surv.censoring import CensoringFunc, rexpocens, runifcens -from gen_surv.validation import validate_gen_cmm_inputs - - -class EventTimes(TypedDict): - t12: float - t13: float - t23: float +import numpy as np +from gen_surv.validate import validate_gen_cmm_inputs +from gen_surv.censoring import runifcens, rexpocens -def generate_event_times( - z1: float, beta: Sequence[float], rate: Sequence[float] -) -> EventTimes: +def generate_event_times(z1: float, beta: list, rate: list) -> dict: """ Generate event times for a continuous-time multi-state Markov model. @@ -28,26 +17,17 @@ def generate_event_times( - dict: {'t12': float, 't13': float, 't23': float} """ u = np.random.uniform() - t12 = (-np.log(1 - u) / (rate[0] * np.exp(beta[0] * z1))) ** (1 / rate[1]) + t12 = (-np.log(1 - u) / (rate[0] * np.exp(beta[0] * z1)))**(1 / rate[1]) u = np.random.uniform() - t13 = (-np.log(1 - u) / (rate[2] * np.exp(beta[1] * z1))) ** (1 / rate[3]) + t13 = (-np.log(1 - u) / (rate[2] * np.exp(beta[1] * z1)))**(1 / rate[3]) u = np.random.uniform() - t23 = (-np.log(1 - u) / (rate[4] * np.exp(beta[2] * z1))) ** (1 / rate[5]) + t23 = (-np.log(1 - u) / (rate[4] * np.exp(beta[2] * z1)))**(1 / rate[5]) return {"t12": t12, "t13": t13, "t23": t23} - -def gen_cmm( - n: int, - model_cens: str, - cens_par: float, - beta: Sequence[float], - covariate_range: float, - rate: Sequence[float], - seed: int | None = None, -) -> pd.DataFrame: +def gen_cmm(n, model_cens, cens_par, beta, covar, rate): """ Generate survival data using a continuous-time Markov model (CMM). @@ -56,38 +36,35 @@ def gen_cmm( - model_cens (str): "uniform" or "exponential". - cens_par (float): Parameter for censoring. - beta (list): Regression coefficients (length 3). - - covariate_range (float): Upper bound for the covariate values. 
+ - covar (float): Covariate range (uniformly sampled from [0, covar]). - rate (list): Transition rates (length 6). Returns: - - pd.DataFrame with columns: id, start, stop, status, X0, transition + - pd.DataFrame with columns: id, start, stop, status, covariate, transition """ - validate_gen_cmm_inputs(n, model_cens, cens_par, beta, covariate_range, rate) - - rng = np.random.default_rng(seed) - rfunc: CensoringFunc = runifcens if model_cens == "uniform" else rexpocens - - z1 = rng.uniform(0, covariate_range, size=n) - c = rfunc(n, cens_par, rng) - - u = rng.uniform(size=(3, n)) - t12 = (-np.log(1 - u[0]) / (rate[0] * np.exp(beta[0] * z1))) ** (1 / rate[1]) - t13 = (-np.log(1 - u[1]) / (rate[2] * np.exp(beta[1] * z1))) ** (1 / rate[3]) - - first_event = np.minimum(t12, t13) - censored = first_event >= c - - status = (~censored).astype(int) - transition = np.where(censored, np.nan, np.where(t12 <= t13, 1, 2)) - stop = np.where(censored, c, first_event) - - return pd.DataFrame( - { - "id": np.arange(1, n + 1), - "start": np.zeros(n), - "stop": stop, - "status": status, - "X0": z1, - "transition": transition, - } - ) + validate_gen_cmm_inputs(n, model_cens, cens_par, beta, covar, rate) + + rfunc = runifcens if model_cens == "uniform" else rexpocens + rows = [] + + for k in range(n): + z1 = np.random.uniform(0, covar) + c = rfunc(1, cens_par)[0] + events = generate_event_times(z1, beta, rate) + + t12, t13, t23 = events["t12"], events["t13"], events["t23"] + min_event_time = min(t12, t13, c) + + if min_event_time < c: + if t12 <= t13: + transition = 1 # 1 -> 2 + rows.append([k + 1, 0, t12, 1, z1, transition]) + else: + transition = 2 # 1 -> 3 + rows.append([k + 1, 0, t13, 1, z1, transition]) + else: + # Censored before any event + rows.append([k + 1, 0, c, 0, z1, np.nan]) + + return pd.DataFrame(rows, columns=["id", "start", "stop", "status", "covariate", "transition"]) + diff --git a/gen_surv/competing_risks.py b/gen_surv/competing_risks.py deleted file mode 100644 index f9eb74c..0000000 --- a/gen_surv/competing_risks.py +++ /dev/null @@ -1,697 +0,0 @@ -""" -Competing Risks models for survival data simulation. - -This module provides functions to generate survival data with -competing risks under different hazard specifications. -""" - -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union - -import numpy as np -import pandas as pd - -from .censoring import rexpocens, runifcens -from .validation import ( - NumericSequenceError, - ParameterError, - ensure_censoring_model, - ensure_in_choices, - ensure_numeric_sequence, - ensure_positive, - ensure_positive_int, - ensure_positive_sequence, - ensure_probability, - ensure_sequence_length, -) - -if TYPE_CHECKING: # pragma: no cover - used only for type hints - from matplotlib.axes import Axes - from matplotlib.figure import Figure - - -def _prepare_covariates( - rng: np.random.Generator, - n: int, - n_risks: int, - betas: Optional[Union[List[List[float]], np.ndarray]], - covariate_dist: Literal["normal", "uniform", "binary"], - covariate_params: Optional[Dict[str, float]], -) -> tuple[np.ndarray, np.ndarray, int]: - """Generate covariates and validate associated parameters. - - Returns - ------- - betas : ndarray - Coefficient matrix of shape ``(n_risks, n_covariates)``. - X : ndarray - Generated covariate matrix of shape ``(n, n_covariates)``. - n_covariates : int - Number of covariates. 
- """ - - ensure_in_choices(covariate_dist, "covariate_dist", {"normal", "uniform", "binary"}) - n_covariates = 2 - - params: Dict[str, float] - if covariate_params is None: - if covariate_dist == "normal": - params = {"mean": 0.0, "std": 1.0} - elif covariate_dist == "uniform": - params = {"low": 0.0, "high": 1.0} - else: - params = {"p": 0.5} - else: - params = dict(covariate_params) - if covariate_dist == "normal": - mean = params.get("mean") - std = params.get("std") - if mean is None or std is None: - raise ParameterError( - "covariate_params", params, "must include 'mean' and 'std'" - ) - ensure_positive(std, "covariate_params['std']") - mean_f = float(mean) - std_f = float(std) - elif covariate_dist == "uniform": - low = params.get("low") - high = params.get("high") - if low is None or high is None: - raise ParameterError( - "covariate_params", params, "must include 'low' and 'high'" - ) - low_f = float(low) - high_f = float(high) - if high_f <= low_f: - raise ParameterError( - "covariate_params['high']", high_f, "must be greater than 'low'" - ) - else: # binary - p = params.get("p") - if p is None: - raise ParameterError("covariate_params", params, "must include 'p'") - p_f = float(p) - ensure_probability(p_f, "covariate_params['p']") - - if betas is None: - betas_arr = rng.normal(0, 0.5, size=(n_risks, n_covariates)) - else: - try: - betas_arr = np.asarray(betas, dtype=float) - except (TypeError, ValueError) as exc: - raise NumericSequenceError("betas", betas) from exc - ensure_sequence_length(betas_arr, n_risks, "betas") - for j in range(n_risks): - ensure_numeric_sequence(betas_arr[j], f"betas[{j}]") - nonfinite = np.where(~np.isfinite(betas_arr[j]))[0] - if nonfinite.size: - idx = int(nonfinite[0]) - raise NumericSequenceError(f"betas[{j}]", betas_arr[j][idx], idx) - n_covariates = betas_arr.shape[1] - - if covariate_dist == "normal": - X = rng.normal(mean_f, std_f, size=(n, n_covariates)) - elif covariate_dist == "uniform": - X = rng.uniform(low_f, high_f, size=(n, n_covariates)) - else: # binary - X = rng.binomial(1, p_f, size=(n, n_covariates)) - - return betas_arr, X, n_covariates - - -def gen_competing_risks( - n: int, - n_risks: int = 2, - baseline_hazards: Optional[Union[List[float], np.ndarray]] = None, - betas: Optional[Union[List[List[float]], np.ndarray]] = None, - covariate_dist: Literal["normal", "uniform", "binary"] = "normal", - covariate_params: Optional[Dict[str, float]] = None, - max_time: Optional[float] = 10.0, - model_cens: Literal["uniform", "exponential"] = "uniform", - cens_par: float = 5.0, - seed: Optional[int] = None, -) -> pd.DataFrame: - """ - Generate survival data with competing risks. - - Parameters - ---------- - n : int - Number of subjects. - n_risks : int, default=2 - Number of competing risks. - baseline_hazards : list of float or array, optional - Baseline hazard rates for each risk. If None, uses [0.5, 0.3, ...] - with decreasing values for subsequent risks. - betas : list of list of float or array, optional - Coefficients for covariates, one list per risk. - Shape should be (n_risks, n_covariates). - If None, generates random coefficients. - covariate_dist : {"normal", "uniform", "binary"}, default="normal" - Distribution to generate covariates from. - covariate_params : dict, optional - Parameters for covariate distribution: - - "normal": {"mean": float, "std": float} - - "uniform": {"low": float, "high": float} - - "binary": {"p": float} - If None, uses defaults based on distribution. 
- max_time : float, optional, default=10.0 - Maximum simulation time. Set to None for no limit. - model_cens : {"uniform", "exponential"}, default="uniform" - Censoring mechanism. - cens_par : float, default=5.0 - Parameter for censoring distribution. - seed : int, optional - Random seed for reproducibility. - - Returns - ------- - pd.DataFrame - DataFrame with columns: - - "id": Subject identifier - - "time": Time to event or censoring - - "status": Event indicator (0=censored, 1,2,...=competing events) - - "X0", "X1", ...: Covariates - - Examples - -------- - >>> from gen_surv.competing_risks import gen_competing_risks - >>> - >>> # Simple example with 2 competing risks - >>> df = gen_competing_risks( - ... n=100, - ... n_risks=2, - ... baseline_hazards=[0.5, 0.3], - ... betas=[[0.8, -0.5], [0.2, 0.7]], - ... seed=42 - ... ) - >>> - >>> # Distribution of event types - >>> df["status"].value_counts() - """ - rng = np.random.default_rng(seed) - - ensure_positive_int(n, "n") - ensure_positive_int(n_risks, "n_risks") - ensure_censoring_model(model_cens) - ensure_positive(cens_par, "cens_par") - if max_time is not None: - ensure_positive(max_time, "max_time") - - # Set default baseline hazards if not provided - if baseline_hazards is None: - baseline_hazards = np.array([0.5 / (i + 1) for i in range(n_risks)]) - else: - baseline_hazards = np.asarray(baseline_hazards, dtype=float) - ensure_sequence_length(baseline_hazards, n_risks, "baseline_hazards") - ensure_positive_sequence(baseline_hazards, "baseline_hazards") - - betas, X, n_covariates = _prepare_covariates( - rng, n, n_risks, betas, covariate_dist, covariate_params - ) - - # Calculate linear predictors for each risk - linear_predictors = np.zeros((n, n_risks)) - for j in range(n_risks): - linear_predictors[:, j] = X @ betas[j] - - # Calculate hazard rates - hazard_rates = np.zeros_like(linear_predictors) - for j in range(n_risks): - hazard_rates[:, j] = baseline_hazards[j] * np.exp(linear_predictors[:, j]) - - # Generate event times for each risk - event_times = np.zeros((n, n_risks)) - for j in range(n_risks): - # Use exponential distribution with rate = hazard - event_times[:, j] = rng.exponential(1 / hazard_rates[:, j]) - - # Generate censoring times - rfunc = runifcens if model_cens == "uniform" else rexpocens - cens_times = rfunc(n, cens_par, rng) - - # Find the minimum time for each subject (first event or censoring) - min_event_times = np.min(event_times, axis=1) - observed_times = np.minimum(min_event_times, cens_times) - - # Determine event type (0 = censored, 1...n_risks = event type) - status = np.zeros(n, dtype=int) - for i in range(n): - if min_event_times[i] <= cens_times[i]: - # Find which risk occurred first - risk_index = np.argmin(event_times[i]) - status[i] = risk_index + 1 # 1-based indexing for event types - - # Ensure at least two event types are present for small n - if len(np.unique(status)) <= 1 and n_risks > 1: - status[0] = 1 - if n > 1: - status[1] = 2 - - # Cap times at max_time if specified - if max_time is not None: - over_max = observed_times > max_time - observed_times[over_max] = max_time - status[over_max] = 0 # Censored if beyond max_time - - # Create DataFrame - data = pd.DataFrame({"id": np.arange(n), "time": observed_times, "status": status}) - - # Add covariates - for j in range(n_covariates): - data[f"X{j}"] = X[:, j] - - return data - - -def gen_competing_risks_weibull( - n: int, - n_risks: int = 2, - shape_params: Optional[Union[List[float], np.ndarray]] = None, - scale_params: 
Optional[Union[List[float], np.ndarray]] = None, - betas: Optional[Union[List[List[float]], np.ndarray]] = None, - covariate_dist: Literal["normal", "uniform", "binary"] = "normal", - covariate_params: Optional[Dict[str, float]] = None, - max_time: Optional[float] = 10.0, - model_cens: Literal["uniform", "exponential"] = "uniform", - cens_par: float = 5.0, - seed: Optional[int] = None, -) -> pd.DataFrame: - """ - Generate survival data with competing risks using Weibull hazards. - - Parameters - ---------- - n : int - Number of subjects. - n_risks : int, default=2 - Number of competing risks. - shape_params : list of float or array, optional - Shape parameters for Weibull distribution, one per risk. - If None, uses [1.2, 0.8, ...] alternating values. - scale_params : list of float or array, optional - Scale parameters for Weibull distribution, one per risk. - If None, uses [2.0, 3.0, ...] increasing values. - betas : list of list of float or array, optional - Coefficients for covariates, one list per risk. - Shape should be (n_risks, n_covariates). - If None, generates random coefficients. - covariate_dist : {"normal", "uniform", "binary"}, default="normal" - Distribution to generate covariates from. - covariate_params : dict, optional - Parameters for covariate distribution: - - "normal": {"mean": float, "std": float} - - "uniform": {"low": float, "high": float} - - "binary": {"p": float} - If None, uses defaults based on distribution. - max_time : float, optional, default=10.0 - Maximum simulation time. Set to None for no limit. - model_cens : {"uniform", "exponential"}, default="uniform" - Censoring mechanism. - cens_par : float, default=5.0 - Parameter for censoring distribution. - seed : int, optional - Random seed for reproducibility. - - Returns - ------- - pd.DataFrame - DataFrame with columns: - - "id": Subject identifier - - "time": Time to event or censoring - - "status": Event indicator (0=censored, 1,2,...=competing events) - - "X0", "X1", ...: Covariates - - Examples - -------- - >>> from gen_surv.competing_risks import gen_competing_risks_weibull - >>> - >>> # Example with 2 competing risks with different shapes - >>> df = gen_competing_risks_weibull( - ... n=100, - ... n_risks=2, - ... shape_params=[0.8, 1.5], # Decreasing vs increasing hazard - ... scale_params=[2.0, 3.0], - ... betas=[[0.8, -0.5], [0.2, 0.7]], - ... seed=42 - ... 
) - """ - rng = np.random.default_rng(seed) - - ensure_positive_int(n, "n") - ensure_positive_int(n_risks, "n_risks") - ensure_censoring_model(model_cens) - ensure_positive(cens_par, "cens_par") - if max_time is not None: - ensure_positive(max_time, "max_time") - - # Set default shape and scale parameters if not provided - if shape_params is None: - shape_params = np.array([1.2 if i % 2 == 0 else 0.8 for i in range(n_risks)]) - else: - shape_params = np.asarray(shape_params, dtype=float) - ensure_sequence_length(shape_params, n_risks, "shape_params") - ensure_positive_sequence(shape_params, "shape_params") - - if scale_params is None: - scale_params = np.array([2.0 + i for i in range(n_risks)]) - else: - scale_params = np.asarray(scale_params, dtype=float) - ensure_sequence_length(scale_params, n_risks, "scale_params") - ensure_positive_sequence(scale_params, "scale_params") - - betas, X, n_covariates = _prepare_covariates( - rng, n, n_risks, betas, covariate_dist, covariate_params - ) - - # Calculate linear predictors for each risk - linear_predictors = np.zeros((n, n_risks)) - for j in range(n_risks): - linear_predictors[:, j] = X @ betas[j] - - # Generate event times for each risk using Weibull distribution - event_times = np.zeros((n, n_risks)) - for j in range(n_risks): - # Adjust the scale parameter using the linear predictor - adjusted_scale = scale_params[j] * np.exp( - -linear_predictors[:, j] / shape_params[j] - ) - - # Generate random uniform between 0 and 1 - u = rng.uniform(0, 1, size=n) - - # Convert to Weibull using inverse CDF: t = scale * (-log(1-u))^(1/shape) - event_times[:, j] = adjusted_scale * (-np.log(1 - u)) ** (1 / shape_params[j]) - - # Generate censoring times - rfunc = runifcens if model_cens == "uniform" else rexpocens - cens_times = rfunc(n, cens_par, rng) - - # Find the minimum time for each subject (first event or censoring) - min_event_times = np.min(event_times, axis=1) - observed_times = np.minimum(min_event_times, cens_times) - - # Determine event type (0 = censored, 1...n_risks = event type) - status = np.zeros(n, dtype=int) - for i in range(n): - if min_event_times[i] <= cens_times[i]: - # Find which risk occurred first - risk_index = np.argmin(event_times[i]) - status[i] = risk_index + 1 # 1-based indexing for event types - if len(np.unique(status)) <= 1 and n_risks > 1: - status[0] = 1 - if n > 1: - status[1] = 2 - - # Cap times at max_time if specified - if max_time is not None: - over_max = observed_times > max_time - observed_times[over_max] = max_time - status[over_max] = 0 # Censored if beyond max_time - - # Create DataFrame - data = pd.DataFrame({"id": np.arange(n), "time": observed_times, "status": status}) - - # Add covariates - for j in range(n_covariates): - data[f"X{j}"] = X[:, j] - - return data - - -def cause_specific_cumulative_incidence( - data: pd.DataFrame, - time_points: Union[List[float], np.ndarray], - time_col: str = "time", - status_col: str = "status", - cause: int = 1, -) -> pd.DataFrame: - """ - Calculate the cause-specific cumulative incidence function at specified time points. - - Parameters - ---------- - data : pd.DataFrame - DataFrame with competing risks data. - time_points : list of float or array - Time points at which to calculate the cumulative incidence. - time_col : str, default="time" - Name of the time column. - status_col : str, default="status" - Name of the status column (0=censored, 1,2,...=competing events). - cause : int, default=1 - The cause/event type for which to calculate the incidence. 
- - Returns - ------- - pd.DataFrame - DataFrame with time points and corresponding cumulative incidence values. - - Notes - ----- - The cumulative incidence function for cause j is defined as: - F_j(t) = P(T <= t, cause = j) - - This is the probability of experiencing the event of type j before time t. - """ - # Validate the cause value - unique_causes = set(data[status_col].unique()) - {0} # Exclude censoring - if cause not in unique_causes: - raise ParameterError( - "cause", cause, f"not found in the data. Available causes: {unique_causes}" - ) - - sorted_data = data.sort_values(by=time_col).copy() - times = sorted_data[time_col].to_numpy() - status = sorted_data[status_col].to_numpy() - - unique_times, idx = np.unique(times, return_index=True) - counts = np.diff(np.append(idx, len(times))) - at_risk = len(times) - idx - - d_all = np.zeros_like(unique_times, dtype=int) - d_cause = np.zeros_like(unique_times, dtype=int) - inverse = np.repeat(np.arange(len(unique_times)), counts) - for i, s in enumerate(status): - if s > 0: - d_all[inverse[i]] += 1 - if s == cause: - d_cause[inverse[i]] += 1 - - surv = 1.0 - cif_vals = np.zeros_like(unique_times, dtype=float) - ci = 0.0 - for i, t in enumerate(unique_times): - prev_surv = surv - surv *= 1 - d_all[i] / at_risk[i] - ci += prev_surv * d_cause[i] / at_risk[i] - cif_vals[i] = ci - - result = [] - for t in time_points: - if t <= 0: - result.append({"time": t, "incidence": 0.0}) - elif t >= unique_times[-1]: - result.append({"time": t, "incidence": cif_vals[-1]}) - else: - idx = np.searchsorted(unique_times, t, side="right") - 1 - result.append({"time": t, "incidence": cif_vals[idx]}) - - return pd.DataFrame(result) - - -def competing_risks_summary( - data: pd.DataFrame, - time_col: str = "time", - status_col: str = "status", - covariate_cols: list[str] | None = None, -) -> dict[str, Any]: - """ - Provide a summary of a competing risks dataset. - - Parameters - ---------- - data : pd.DataFrame - DataFrame with competing risks data. - time_col : str, default="time" - Name of the time column. - status_col : str, default="status" - Name of the status column (0=censored, 1,2,...=competing events). - covariate_cols : list of str, optional - List of covariate columns to include in the summary. - If None, all columns except time_col and status_col are considered. - - Returns - ------- - Dict[str, Any] - Dictionary with summary statistics. 
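
Editor's note (illustration, not part of the patch): a sketch of how the two deleted helpers above compose; the argument values are arbitrary.

    from gen_surv.competing_risks import (
        cause_specific_cumulative_incidence,
        gen_competing_risks,
    )

    df = gen_competing_risks(n=500, n_risks=2, seed=42)
    cif = cause_specific_cumulative_incidence(df, time_points=[1.0, 2.0, 5.0], cause=1)
    # cif has columns "time" and "incidence"; incidence is non-decreasing in time
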
- - Examples - -------- - >>> from gen_surv.competing_risks import gen_competing_risks, competing_risks_summary - >>> - >>> # Generate data - >>> df = gen_competing_risks(n=100, n_risks=3, seed=42) - >>> - >>> # Get summary - >>> summary = competing_risks_summary(df) - >>> print(f"Number of events by cause: {summary['events_by_cause']}") - >>> print(f"Median time to first event: {summary['median_time']}") - """ - # Determine covariate columns if not provided - if covariate_cols is None: - covariate_cols = [ - col for col in data.columns if col not in [time_col, status_col, "id"] - ] - - # Basic counts - n_subjects = len(data) - n_events = (data[status_col] > 0).sum() - n_censored = n_subjects - n_events - censoring_rate = n_censored / n_subjects - - # Events by cause - causes = sorted(data[data[status_col] > 0][status_col].unique()) - events_by_cause = {} - for cause in causes: - n_cause = (data[status_col] == cause).sum() - events_by_cause[int(cause)] = { - "count": int(n_cause), - "proportion": float(n_cause / n_subjects), - "proportion_of_events": float(n_cause / n_events) if n_events > 0 else 0, - } - - # Time statistics - time_stats = { - "min": float(data[time_col].min()), - "max": float(data[time_col].max()), - "median": float(data[time_col].median()), - "mean": float(data[time_col].mean()), - } - - # Median time to each type of event - median_time_by_cause = {} - for cause in causes: - cause_times = data[data[status_col] == cause][time_col] - if not cause_times.empty: - median_time_by_cause[int(cause)] = float(cause_times.median()) - - # Covariate statistics - covariate_stats: dict[str, dict[str, float | int | dict[str, float]]] = {} - for col in covariate_cols: - col_data = data[col] - - # Check if numeric - if pd.api.types.is_numeric_dtype(col_data): - covariate_stats[col] = { - "mean": float(col_data.mean()), - "median": float(col_data.median()), - "std": float(col_data.std()), - "min": float(col_data.min()), - "max": float(col_data.max()), - } - else: - # Categorical statistics - value_counts = col_data.value_counts(normalize=True).to_dict() - covariate_stats[col] = { - "categories": len(value_counts), - "distribution": {str(k): float(v) for k, v in value_counts.items()}, - } - - # Compile final summary - summary = { - "n_subjects": n_subjects, - "n_events": n_events, - "n_censored": n_censored, - "censoring_rate": censoring_rate, - "n_causes": len(causes), - "causes": list(map(int, causes)), - "events_by_cause": events_by_cause, - "time_stats": time_stats, - "median_time_by_cause": median_time_by_cause, - "covariate_stats": covariate_stats, - } - - return summary - - -def plot_cause_specific_hazards( - data: pd.DataFrame, - time_points: np.ndarray | None = None, - time_col: str = "time", - status_col: str = "status", - bandwidth: float = 0.5, - figsize: tuple[float, float] = (10, 6), -) -> tuple["Figure", "Axes"]: - """ - Plot cause-specific hazard functions. - - Parameters - ---------- - data : pd.DataFrame - DataFrame with competing risks data. - time_points : array, optional - Time points at which to estimate hazards. - If None, uses 100 equally spaced points from 0 to max time. - time_col : str, default="time" - Name of the time column. - status_col : str, default="status" - Name of the status column (0=censored, 1,2,...=competing events). - bandwidth : float, default=0.5 - Bandwidth for kernel density estimation. - figsize : tuple, default=(10, 6) - Figure size (width, height) in inches. - - Returns - ------- - tuple - Figure and axes objects. 
-
-    Notes
-    -----
-    This function requires matplotlib and scipy.
-    """
-    try:
-        import matplotlib.pyplot as plt
-        from scipy.stats import gaussian_kde
-    except ImportError:
-        raise ImportError(
-            "This function requires matplotlib and scipy. "
-            "Install them with: pip install matplotlib scipy"
-        )
-
-    # Determine time points if not provided
-    if time_points is None:
-        max_time = data[time_col].max()
-        time_points = np.linspace(0, max_time, 100)
-
-    # Get unique causes (excluding censoring)
-    causes = sorted([c for c in data[status_col].unique() if c > 0])
-
-    # Create plot
-    fig, ax = plt.subplots(figsize=figsize)
-
-    times_sorted = np.sort(data[time_col].to_numpy())
-    total = len(times_sorted)
-
-    # Plot hazard for each cause
-    for cause in causes:
-        cause_data = data[data[status_col] == cause]
-        if len(cause_data) < 5:
-            continue
-
-        kde = gaussian_kde(cause_data[time_col], bw_method=bandwidth)
-
-        at_risk = total - np.searchsorted(times_sorted, time_points, side="left")
-        at_risk = np.maximum(at_risk, 1)
-
-        hazard = kde(time_points) * total / at_risk
-        ax.plot(time_points, hazard, label=f"Cause {cause}")
-
-    # Format plot
-    ax.set_xlabel("Time")
-    ax.set_ylabel("Hazard Rate")
-    ax.set_title("Cause-Specific Hazard Functions")
-    ax.legend()
-    ax.grid(alpha=0.3)
-
-    return fig, ax
diff --git a/gen_surv/cphm.py b/gen_surv/cphm.py
index 0f09c6d..97f52a6 100644
--- a/gen_surv/cphm.py
+++ b/gen_surv/cphm.py
@@ -1,59 +1,28 @@
-"""
-Cox Proportional Hazards Model (CPHM) data generation.
-
-This module provides functions to generate survival data following the
-Cox Proportional Hazards Model with various censoring mechanisms.
-"""
-
-from typing import Literal
-
 import numpy as np
 import pandas as pd
-from numpy.typing import NDArray
-
-from gen_surv.censoring import CensoringFunc, rexpocens, runifcens
-from gen_surv.validation import validate_gen_cphm_inputs
+from gen_surv.validate import validate_gen_cphm_inputs
+from gen_surv.censoring import runifcens, rexpocens
-
-def generate_cphm_data(
-    n: int,
-    rfunc: CensoringFunc,
-    cens_par: float,
-    beta: float,
-    covariate_range: float,
-    seed: int | None = None,
-) -> NDArray[np.float64]:
+def generate_cphm_data(n, rfunc, cens_par, beta, covariate_range):
     """
     Generate data from a Cox Proportional Hazards Model (CPHM).
 
-    Parameters
-    ----------
-    n : int
-        Number of samples to generate.
-    rfunc : callable
-        Function to generate censoring times, must accept (size, cens_par).
-    cens_par : float
-        Parameter passed to the censoring function.
-    beta : float
-        Coefficient for the covariate.
-    covariate_range : float
-        Range for the covariate (uniformly sampled from [0, covariate_range]).
-    seed : int, optional
-        Random seed for reproducibility.
+    Parameters:
+    - n (int): Number of samples to generate.
+    - rfunc (callable): Function to generate censoring times, must accept (size, cens_par).
+    - cens_par (float): Parameter passed to the censoring function.
+    - beta (float): Coefficient for the covariate.
+    - covariate_range (float): Range for the covariate (uniformly sampled from [0, covariate_range]).
- Returns - ------- - NDArray[np.float64] - Array with shape ``(n, 3)``: ``[time, status, X0]`` + Returns: + - np.ndarray: Array with shape (n, 3): [time, status, covariate] """ - rng = np.random.default_rng(seed) - - data: NDArray[np.float64] = np.zeros((n, 3), dtype=float) + data = np.zeros((n, 3)) for k in range(n): - z = rng.uniform(0, covariate_range) - c = rfunc(1, cens_par, rng)[0] - x = rng.exponential(scale=1 / np.exp(beta * z)) + z = np.random.uniform(0, covariate_range) + c = rfunc(1, cens_par)[0] + x = np.random.exponential(scale=1 / np.exp(beta * z)) time = min(x, c) status = int(x <= c) @@ -63,53 +32,27 @@ def generate_cphm_data( return data -def gen_cphm( - n: int, - model_cens: Literal["uniform", "exponential"], - cens_par: float, - beta: float, - covariate_range: float, - seed: int | None = None, -) -> pd.DataFrame: +def gen_cphm(n: int, model_cens: str, cens_par: float, beta: float, covar: float) -> pd.DataFrame: """ - Generate survival data following a Cox Proportional Hazards Model. + Convenience wrapper to generate CPHM survival data. - Parameters - ---------- - n : int - Number of observations. - model_cens : {"uniform", "exponential"} - Type of censoring mechanism. - cens_par : float - Parameter for the censoring model. - beta : float - Coefficient for the covariate. - covariate_range : float - Upper bound for the covariate values (uniform between 0 and covariate_range). - seed : int, optional - Random seed for reproducibility. + Parameters: + - n (int): Number of observations. + - model_cens (str): "uniform" or "exponential". + - cens_par (float): Parameter for the censoring model. + - beta (float): Coefficient for the covariate. + - covar (float): Covariate range (uniform between 0 and covar). - Returns - ------- - pd.DataFrame - DataFrame with columns ["time", "status", "X0"] - - time: observed event or censoring time - - status: event indicator (1=event, 0=censored) - - X0: predictor variable - - Examples - -------- - >>> from gen_surv.cphm import gen_cphm - >>> df = gen_cphm(n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0) - >>> df.head() - time status X0 - 0 0.23 1.0 1.42 - 1 0.78 0.0 0.89 - ... + Returns: + - pd.DataFrame: Columns are ["time", "status", "covariate"] """ - validate_gen_cphm_inputs(n, model_cens, cens_par, covariate_range) + validate_gen_cphm_inputs(n, model_cens, cens_par, covar) + + rfunc = { + "uniform": runifcens, + "exponential": rexpocens + }[model_cens] - rfunc = {"uniform": runifcens, "exponential": rexpocens}[model_cens] + data = generate_cphm_data(n, rfunc, cens_par, beta, covar) - data = generate_cphm_data(n, rfunc, cens_par, beta, covariate_range, seed) - return pd.DataFrame(data, columns=["time", "status", "X0"]) + return pd.DataFrame(data, columns=["time", "status", "covariate"]) diff --git a/gen_surv/export.py b/gen_surv/export.py deleted file mode 100644 index 4a23fc2..0000000 --- a/gen_surv/export.py +++ /dev/null @@ -1,48 +0,0 @@ -"""Data export utilities for gen_surv. - -This module provides helper functions to save generated -survival datasets in various formats. -""" - -from __future__ import annotations - -import os -from typing import Optional - -import pandas as pd -import pyreadr - -from .validation import ensure_in_choices - - -def export_dataset(df: pd.DataFrame, path: str, fmt: Optional[str] = None) -> None: - """Save a DataFrame to disk. - - Parameters - ---------- - df : pd.DataFrame - DataFrame containing survival data. - path : str - File path to write to. 
The extension is used to infer the format - when ``fmt`` is ``None``. - fmt : {"csv", "json", "feather", "rds"}, optional - Format to use. If omitted, inferred from ``path``. - - Raises - ------ - ChoiceError - If the format is not one of the supported types. - """ - if fmt is None: - fmt = os.path.splitext(path)[1].lstrip(".").lower() - - ensure_in_choices(fmt, "fmt", {"csv", "json", "feather", "ft", "rds"}) - - if fmt == "csv": - df.to_csv(path, index=False) - elif fmt == "json": - df.to_json(path, orient="table") - elif fmt in {"feather", "ft"}: - df.reset_index(drop=True).to_feather(path) - elif fmt == "rds": - pyreadr.write_rds(path, df.reset_index(drop=True)) diff --git a/gen_surv/integration.py b/gen_surv/integration.py deleted file mode 100644 index 6967086..0000000 --- a/gen_surv/integration.py +++ /dev/null @@ -1,38 +0,0 @@ -from __future__ import annotations - -import numpy as np -import pandas as pd -from numpy.typing import NDArray - - -def to_sksurv( - df: pd.DataFrame, time_col: str = "time", event_col: str = "status" -) -> NDArray[np.void]: - """Convert a DataFrame to a scikit-survival structured array. - - Parameters - ---------- - df : pd.DataFrame - DataFrame containing survival data. - time_col : str, default "time" - Column storing durations. - event_col : str, default "status" - Column storing event indicators (1=event, 0=censored). - - Returns - ------- - numpy.ndarray - Structured array suitable for scikit-survival estimators. - - Notes - ----- - The ``sksurv`` package is imported lazily inside the function. It must be - installed separately, for instance with ``pip install scikit-survival``. - """ - - try: - from sksurv.util import Surv - except ImportError as exc: # pragma: no cover - optional dependency - raise ImportError("scikit-survival is required for this feature.") from exc - - return Surv.from_dataframe(event_col, time_col, df) diff --git a/gen_surv/interface.py b/gen_surv/interface.py index 88d1c28..549ad45 100644 --- a/gen_surv/interface.py +++ b/gen_surv/interface.py @@ -3,90 +3,41 @@ Example: >>> from gen_surv import generate - >>> df = generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0) + >>> df = generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0) """ -from collections.abc import Callable -from typing import Dict, Literal - +from typing import Any import pandas as pd -from gen_surv.aft import gen_aft_log_logistic, gen_aft_log_normal, gen_aft_weibull -from gen_surv.cmm import gen_cmm -from gen_surv.competing_risks import gen_competing_risks, gen_competing_risks_weibull from gen_surv.cphm import gen_cphm -from gen_surv.mixture import gen_mixture_cure -from gen_surv.piecewise import gen_piecewise_exponential +from gen_surv.cmm import gen_cmm from gen_surv.tdcm import gen_tdcm from gen_surv.thmm import gen_thmm - -from .validation import ensure_in_choices - -# Type definitions for model names -ModelType = Literal[ - "cphm", - "cmm", - "tdcm", - "thmm", - "aft_ln", - "aft_weibull", - "aft_log_logistic", - "competing_risks", - "competing_risks_weibull", - "mixture_cure", - "piecewise_exponential", -] +from gen_surv.aft import gen_aft_log_normal -# Interface for generator callables -DataGenerator = Callable[..., pd.DataFrame] - - -# Map model names to their generator functions -_model_map: Dict[ModelType, DataGenerator] = { +_model_map = { "cphm": gen_cphm, "cmm": gen_cmm, "tdcm": gen_tdcm, "thmm": gen_thmm, "aft_ln": gen_aft_log_normal, - "aft_weibull": gen_aft_weibull, - 
"aft_log_logistic": gen_aft_log_logistic, - "competing_risks": gen_competing_risks, - "competing_risks_weibull": gen_competing_risks_weibull, - "mixture_cure": gen_mixture_cure, - "piecewise_exponential": gen_piecewise_exponential, } -def generate(model: ModelType, **kwargs: object) -> pd.DataFrame: +def generate(model: str, **kwargs: Any) -> pd.DataFrame: """Generate survival data from a specific model. Args: model: Name of the generator to run. Must be one of ``cphm``, ``cmm``, - ``tdcm``, ``thmm``, ``aft_ln``, ``aft_weibull``, ``aft_log_logistic``, - ``competing_risks``, ``competing_risks_weibull``, ``mixture_cure``, - or ``piecewise_exponential``. - **kwargs: Arguments forwarded to the chosen generator. These vary by model. - - - cphm: n, model_cens, cens_par, beta, covariate_range - - cmm: n, model_cens, cens_par, beta, covariate_range, rate - - tdcm: n, dist, corr, dist_par, model_cens, cens_par, beta, lam - - thmm: n, model_cens, cens_par, beta, covariate_range, rate - - aft_ln: n, beta, sigma, model_cens, cens_par, seed - - aft_weibull: n, beta, shape, scale, model_cens, cens_par, seed - - aft_log_logistic: n, beta, shape, scale, model_cens, cens_par, seed - - competing_risks: n, n_risks, baseline_hazards, betas, covariate_dist, etc. - - competing_risks_weibull: n, n_risks, shape_params, scale_params, betas, etc. - - mixture_cure: n, cure_fraction, baseline_hazard, betas_survival, - betas_cure, etc. - - piecewise_exponential: n, breakpoints, hazard_rates, betas, etc. + ``tdcm``, ``thmm`` or ``aft_ln``. + **kwargs: Arguments forwarded to the chosen generator. Returns: - pd.DataFrame: Simulated survival data with columns specific to the chosen model. - All models include time/duration and status columns. - - Raises: - ChoiceError: If an unknown model name is provided. + pd.DataFrame: Simulated survival data. """ - ensure_in_choices(model, "model", _model_map.keys()) + model = model.lower() + if model not in _model_map: + raise ValueError(f"Unknown model '{model}'. Choose from {list(_model_map.keys())}.") + return _model_map[model](**kwargs) diff --git a/gen_surv/mixture.py b/gen_surv/mixture.py deleted file mode 100644 index 2eee6eb..0000000 --- a/gen_surv/mixture.py +++ /dev/null @@ -1,305 +0,0 @@ -""" -Mixture Cure Models for survival data simulation. - -This module provides functions to generate survival data with a cure fraction, -i.e., a proportion of subjects who are immune to the event of interest. 
-""" - -from typing import Literal - -import numpy as np -import pandas as pd -from numpy.random import Generator -from numpy.typing import NDArray - -_TAIL_FRACTION: float = 0.1 -_SMOOTH_MIN_TAIL: int = 3 - -from ._covariates import generate_covariates, set_covariate_params -from .censoring import rexpocens, runifcens -from .validation import ( - LengthError, - ParameterError, - ensure_censoring_model, - ensure_in_choices, - ensure_numeric_sequence, - ensure_positive, - ensure_positive_int, -) - - -def _prepare_betas( - betas_survival: list[float] | None, - betas_cure: list[float] | None, - n_covariates: int, - rng: Generator, -) -> tuple[NDArray[np.float64], NDArray[np.float64], int]: - if betas_survival is None: - betas_survival_arr = rng.normal(0, 0.5, size=n_covariates) - else: - ensure_numeric_sequence(betas_survival, "betas_survival") - betas_survival_arr = np.asarray(betas_survival, dtype=float) - n_covariates = len(betas_survival_arr) - - if betas_cure is None: - betas_cure_arr = rng.normal(0, 0.5, size=n_covariates) - else: - ensure_numeric_sequence(betas_cure, "betas_cure") - betas_cure_arr = np.asarray(betas_cure, dtype=float) - if len(betas_cure_arr) != n_covariates: - raise LengthError("betas_cure", len(betas_cure_arr), n_covariates) - - return betas_survival_arr, betas_cure_arr, n_covariates - - -def _cure_status( - lp_cure: NDArray[np.float64], cure_fraction: float, rng: Generator -) -> NDArray[np.int64]: - if not 0 < cure_fraction < 1: - raise ParameterError("cure_fraction", cure_fraction, "must be between 0 and 1") - cure_probs = 1 / ( - 1 + np.exp(-(np.log(cure_fraction / (1 - cure_fraction)) + lp_cure)) - ) - return rng.binomial(1, cure_probs).astype(np.int64) - - -def _survival_times( - cured: NDArray[np.int64], - lp_survival: NDArray[np.float64], - baseline_hazard: float, - max_time: float | None, - rng: Generator, -) -> NDArray[np.float64]: - ensure_positive(baseline_hazard, "baseline_hazard") - if max_time is not None: - ensure_positive(max_time, "max_time") - n = cured.size - times = np.zeros(n, dtype=float) - non_cured = cured == 0 - adjusted_hazard = baseline_hazard * np.exp(lp_survival[non_cured]) - times[non_cured] = rng.exponential(scale=1 / adjusted_hazard) - if max_time is not None: - times[~non_cured] = max_time * 100 - else: - times[~non_cured] = np.inf - return times - - -def _apply_censoring( - survival_times: NDArray[np.float64], - model_cens: str, - cens_par: float, - max_time: float | None, - rng: Generator, -) -> tuple[NDArray[np.float64], NDArray[np.int64]]: - ensure_censoring_model(model_cens) - ensure_positive(cens_par, "cens_par") - if max_time is not None: - ensure_positive(max_time, "max_time") - rfunc = runifcens if model_cens == "uniform" else rexpocens - cens_times = rfunc(len(survival_times), cens_par, rng) - observed = np.minimum(survival_times, cens_times) - status = (survival_times <= cens_times).astype(int) - if max_time is not None: - over_max = observed > max_time - observed[over_max] = max_time - status[over_max] = 0 - return observed, status - - -def gen_mixture_cure( - n: int, - cure_fraction: float, - baseline_hazard: float = 0.5, - betas_survival: list[float] | None = None, - betas_cure: list[float] | None = None, - n_covariates: int = 2, - covariate_dist: Literal["normal", "uniform", "binary"] = "normal", - covariate_params: dict[str, float] | None = None, - model_cens: Literal["uniform", "exponential"] = "uniform", - cens_par: float = 5.0, - max_time: float | None = 10.0, - seed: int | None = None, -) -> pd.DataFrame: - """ - 
Generate survival data with a cure fraction using a mixture cure model. - - Parameters - ---------- - n : int - Number of subjects. - cure_fraction : float - Baseline probability of being cured (immune to the event). - Should be between 0 and 1. - baseline_hazard : float, default=0.5 - Baseline hazard rate for the non-cured population. - betas_survival : list of float, optional - Coefficients for covariates in the survival component. - If None, generates random coefficients. - betas_cure : list of float, optional - Coefficients for covariates in the cure component. - If None, generates random coefficients. - n_covariates : int, default=2 - Number of covariates to generate if betas is None. - covariate_dist : {"normal", "uniform", "binary"}, default="normal" - Distribution to generate covariates from. - covariate_params : dict, optional - Parameters for covariate distribution: - - "normal": {"mean": float, "std": float} - - "uniform": {"low": float, "high": float} - - "binary": {"p": float} - If None, uses defaults based on distribution. - model_cens : {"uniform", "exponential"}, default="uniform" - Censoring mechanism. - cens_par : float, default=5.0 - Parameter for censoring distribution. - max_time : float, optional, default=10.0 - Maximum simulation time. Set to None for no limit. - seed : int, optional - Random seed for reproducibility. - - Returns - ------- - pd.DataFrame - DataFrame with columns: - - "id": Subject identifier - - "time": Time to event or censoring - - "status": Event indicator (1=event, 0=censored) - - "cured": Indicator of cure status (1=cured, 0=not cured) - - "X0", "X1", ...: Covariates - - Examples - -------- - >>> from gen_surv.mixture import gen_mixture_cure - >>> - >>> # Generate data with 30% baseline cure fraction - >>> df = gen_mixture_cure( - ... n=100, - ... cure_fraction=0.3, - ... betas_survival=[0.8, -0.5], - ... betas_cure=[-0.5, 0.8], - ... seed=42 - ... 
) - >>> - >>> # Check cure proportion - >>> print(f"Cured subjects: {df['cured'].mean():.2%}") - """ - rng = np.random.default_rng(seed) - - ensure_positive_int(n, "n") - ensure_positive_int(n_covariates, "n_covariates") - ensure_positive(baseline_hazard, "baseline_hazard") - ensure_positive(cens_par, "cens_par") - if max_time is not None: - ensure_positive(max_time, "max_time") - if not 0 <= cure_fraction <= 1: - raise ParameterError("cure_fraction", cure_fraction, "must be between 0 and 1") - - ensure_in_choices(covariate_dist, "covariate_dist", {"normal", "uniform", "binary"}) - covariate_params = set_covariate_params(covariate_dist, covariate_params) - betas_survival_arr, betas_cure_arr, n_covariates = _prepare_betas( - betas_survival, betas_cure, n_covariates, rng - ) - X = generate_covariates(n, n_covariates, covariate_dist, covariate_params, rng) - lp_survival = X @ betas_survival_arr - lp_cure = X @ betas_cure_arr - cured = _cure_status(lp_cure, cure_fraction, rng) - survival_times = _survival_times(cured, lp_survival, baseline_hazard, max_time, rng) - - ensure_censoring_model(model_cens) - observed_times, status = _apply_censoring( - survival_times, model_cens, cens_par, max_time, rng - ) - - data = pd.DataFrame( - {"id": np.arange(n), "time": observed_times, "status": status, "cured": cured} - ) - - for j in range(n_covariates): - data[f"X{j}"] = X[:, j] - - return data - - -def cure_fraction_estimate( - data: pd.DataFrame, - time_col: str = "time", - status_col: str = "status", - bandwidth: float = 0.1, -) -> float: - """ - Estimate the cure fraction from observed data using non-parametric methods. - - Parameters - ---------- - data : pd.DataFrame - DataFrame with survival data. - time_col : str, default="time" - Name of the time column. - status_col : str, default="status" - Name of the status column (1=event, 0=censored). - bandwidth : float, default=0.1 - Bandwidth parameter for smoothing the tail of the survival curve. - - Returns - ------- - float - Estimated cure fraction. - - Notes - ----- - This function uses a non-parametric approach to estimate the cure fraction - based on the plateau of the survival curve. It may not be accurate for - small sample sizes or heavy censoring. 
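
Editor's note (illustration, not part of the patch): the intended round trip between the two deleted functions above; with a large sample the plateau estimate should land near the baseline cure fraction, subject to the caveats noted above.

    df = gen_mixture_cure(n=2000, cure_fraction=0.3, seed=7)
    est = cure_fraction_estimate(df)  # non-parametric plateau estimate of the cure fraction
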
- """ - if time_col not in data.columns or status_col not in data.columns: - missing = [c for c in (time_col, status_col) if c not in data.columns] - raise ParameterError( - "data", - data.columns.tolist(), - f"missing required column(s): {', '.join(missing)}", - ) - ensure_positive(bandwidth, "bandwidth") - # Sort data by time - sorted_data = data.sort_values(by=time_col).copy() - - # Calculate Kaplan-Meier estimate - times = sorted_data[time_col].values - status = sorted_data[status_col].values - n = len(times) - - if n == 0: - return 0.0 - - # Calculate survival function - survival = np.ones(n) - - for i in range(n): - if i > 0: - survival[i] = survival[i - 1] - - # Count subjects at risk at this time - at_risk = n - i - - if status[i] == 1: # Event - survival[i] *= 1 - 1 / at_risk - - # Estimate cure fraction as the plateau of the survival curve - # Use the last portion of the survival curve if enough data points - tail_size = max(int(n * _TAIL_FRACTION), 1) - tail_survival = survival[-tail_size:] - - # Apply smoothing if there are enough data points - if tail_size > _SMOOTH_MIN_TAIL: - # Use kernel smoothing - weights = np.exp( - -((np.arange(tail_size) - tail_size + 1) ** 2) - / (2 * bandwidth * tail_size) ** 2 - ) - weights = weights / weights.sum() - cure_fraction = float(np.sum(tail_survival * weights)) - else: - # Just use the last survival probability - cure_fraction = float(survival[-1]) - - return cure_fraction diff --git a/gen_surv/piecewise.py b/gen_surv/piecewise.py deleted file mode 100644 index 391701f..0000000 --- a/gen_surv/piecewise.py +++ /dev/null @@ -1,315 +0,0 @@ -""" -Piecewise Exponential survival models. - -This module provides functions for generating survival data from piecewise -exponential distributions with time-dependent hazards. -""" - -from typing import Literal - -import numpy as np -import pandas as pd -from numpy.typing import NDArray - -from ._covariates import generate_covariates, set_covariate_params -from .censoring import rexpocens, runifcens -from .validation import ( - ParameterError, - ensure_censoring_model, - ensure_in_choices, - ensure_numeric_sequence, - ensure_positive, - ensure_positive_int, - ensure_positive_sequence, - ensure_sequence_length, -) - - -def _validate_piecewise_params( - breakpoints: list[float], hazard_rates: list[float] -) -> None: - """Validate breakpoint and hazard rate sequences.""" - ensure_sequence_length(hazard_rates, len(breakpoints) + 1, "hazard_rates") - ensure_positive_sequence(breakpoints, "breakpoints") - ensure_positive_sequence(hazard_rates, "hazard_rates") - if np.any(np.diff(breakpoints) <= 0): - raise ParameterError( - "breakpoints", - breakpoints, - "must be a strictly increasing sequence", - ) - - -def gen_piecewise_exponential( - n: int, - breakpoints: list[float], - hazard_rates: list[float], - betas: list[float] | NDArray[np.float64] | None = None, - n_covariates: int = 2, - covariate_dist: Literal["normal", "uniform", "binary"] = "normal", - covariate_params: dict[str, float] | None = None, - model_cens: Literal["uniform", "exponential"] = "uniform", - cens_par: float = 5.0, - seed: int | None = None, -) -> pd.DataFrame: - """ - Generate survival data using a piecewise exponential distribution. - - Parameters - ---------- - n : int - Number of subjects. - breakpoints : list of float - Time points where hazard rates change. Must be in ascending order. - The first interval is [0, breakpoints[0]), the second is [breakpoints[0], breakpoints[1]), etc. 
- hazard_rates : list of float - Hazard rates for each interval. Length should be len(breakpoints) + 1. - betas : list or array, optional - Coefficients for covariates. If None, generates random coefficients. - n_covariates : int, default=2 - Number of covariates to generate if betas is None. - covariate_dist : {"normal", "uniform", "binary"}, default="normal" - Distribution to generate covariates from. - covariate_params : dict, optional - Parameters for covariate distribution: - - "normal": {"mean": float, "std": float} - - "uniform": {"low": float, "high": float} - - "binary": {"p": float} - If None, uses defaults based on distribution. - model_cens : {"uniform", "exponential"}, default="uniform" - Censoring mechanism. - cens_par : float, default=5.0 - Parameter for censoring distribution. - seed : int, optional - Random seed for reproducibility. - - Returns - ------- - pd.DataFrame - DataFrame with columns: - - "id": Subject identifier - - "time": Time to event or censoring - - "status": Event indicator (1=event, 0=censored) - - "X0", "X1", ...: Covariates - - Examples - -------- - >>> from gen_surv.piecewise import gen_piecewise_exponential - >>> - >>> # Generate data with 3 intervals (increasing hazard) - >>> df = gen_piecewise_exponential( - ... n=100, - ... breakpoints=[1.0, 3.0], - ... hazard_rates=[0.2, 0.5, 1.0], - ... betas=[0.8, -0.5], - ... seed=42 - ... ) - """ - rng = np.random.default_rng(seed) - - ensure_positive_int(n, "n") - ensure_positive_int(n_covariates, "n_covariates") - ensure_positive(cens_par, "cens_par") - - # Validate inputs - _validate_piecewise_params(breakpoints, hazard_rates) - - ensure_censoring_model(model_cens) - ensure_in_choices(covariate_dist, "covariate_dist", {"normal", "uniform", "binary"}) - covariate_params = set_covariate_params(covariate_dist, covariate_params) - - # Set default betas if not provided - if betas is None: - betas = rng.normal(0, 0.5, size=n_covariates) - else: - ensure_numeric_sequence(betas, "betas") - betas = np.array(betas, dtype=float) - n_covariates = len(betas) - - # Generate covariates - X = generate_covariates(n, n_covariates, covariate_dist, covariate_params, rng) - - # Calculate linear predictor - linear_predictor = X @ betas - - # Generate survival times using piecewise exponential distribution - survival_times = np.zeros(n) - - for i in range(n): - # Adjust hazard rates by the covariate effect - adjusted_hazard_rates = [h * np.exp(linear_predictor[i]) for h in hazard_rates] - - # Generate random uniform between 0 and 1 - u = rng.uniform(0, 1) - - # Calculate survival time using inverse CDF method for piecewise exponential - remaining_time = -np.log(u) # Initial time remaining (for standard exponential) - total_time = 0.0 - - # Start with the first interval [0, breakpoints[0]) - interval_width = breakpoints[0] - hazard = adjusted_hazard_rates[0] - time_to_consume = remaining_time / hazard - - if time_to_consume < interval_width: - # Event occurs in first interval - survival_times[i] = time_to_consume - continue - - # Event occurs after first interval - total_time += interval_width - remaining_time -= hazard * interval_width - - # Go through middle intervals [breakpoints[j-1], breakpoints[j]) - for j in range(1, len(breakpoints)): - interval_width = breakpoints[j] - breakpoints[j - 1] - hazard = adjusted_hazard_rates[j] - time_to_consume = remaining_time / hazard - - if time_to_consume < interval_width: - # Event occurs in this interval - survival_times[i] = total_time + time_to_consume - break - - # Event occurs after 
this interval - total_time += interval_width - remaining_time -= hazard * interval_width - - # If we've gone through all intervals and still no event, - # use the last hazard rate for the remainder - if remaining_time > 0: - hazard = adjusted_hazard_rates[-1] - survival_times[i] = total_time + remaining_time / hazard - - # Generate censoring times - rfunc = runifcens if model_cens == "uniform" else rexpocens - cens_times = rfunc(n, cens_par, rng) - - # Determine observed time and status - observed_times = np.minimum(survival_times, cens_times) - status = (survival_times <= cens_times).astype(int) - - # Create DataFrame - data = pd.DataFrame({"id": np.arange(n), "time": observed_times, "status": status}) - - # Add covariates - for j in range(n_covariates): - data[f"X{j}"] = X[:, j] - - return data - - -def piecewise_hazard_function( - t: float | NDArray[np.float64], - breakpoints: list[float], - hazard_rates: list[float], -) -> float | NDArray[np.float64]: - """ - Calculate the hazard function value at time t for a piecewise exponential distribution. - - Parameters - ---------- - t : float or array - Time point(s) at which to evaluate the hazard function. - breakpoints : list of float - Time points where hazard rates change. - hazard_rates : list of float - Hazard rates for each interval. - - Returns - ------- - float or array - Hazard function value(s) at time t. - """ - _validate_piecewise_params(breakpoints, hazard_rates) - - # Convert scalar input to array for consistent processing - scalar_input = np.isscalar(t) - t_array = np.atleast_1d(t) - result = np.zeros_like(t_array) - - # Assign hazard rates based on time intervals - result[t_array < 0] = 0 # Hazard is 0 for negative times - - # First interval: [0, breakpoints[0]) - mask = (t_array >= 0) & (t_array < breakpoints[0]) - result[mask] = hazard_rates[0] - - # Middle intervals: [breakpoints[j-1], breakpoints[j]) - for j in range(1, len(breakpoints)): - mask = (t_array >= breakpoints[j - 1]) & (t_array < breakpoints[j]) - result[mask] = hazard_rates[j] - - # Last interval: [breakpoints[-1], infinity) - mask = t_array >= breakpoints[-1] - result[mask] = hazard_rates[-1] - - return result[0] if scalar_input else result - - -def piecewise_survival_function( - t: float | NDArray[np.float64], - breakpoints: list[float], - hazard_rates: list[float], -) -> float | NDArray[np.float64]: - """ - Calculate the survival function at time t for a piecewise exponential distribution. - - Parameters - ---------- - t : float or array - Time point(s) at which to evaluate the survival function. - breakpoints : list of float - Time points where hazard rates change. - hazard_rates : list of float - Hazard rates for each interval. - - Returns - ------- - float or array - Survival function value(s) at time t. 
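
Editor's note (illustration, not part of the patch): a worked value for piecewise_survival_function as implemented below, reusing the breakpoints and hazard rates from the gen_piecewise_exponential example above.

    # Cumulative hazard at t=2.0: 0.2*1.0 over [0,1) plus 0.5*1.0 over [1,2) = 0.7
    s = piecewise_survival_function(2.0, breakpoints=[1.0, 3.0], hazard_rates=[0.2, 0.5, 1.0])
    # s == exp(-0.7) ≈ 0.4966
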
- """ - _validate_piecewise_params(breakpoints, hazard_rates) - - # Convert scalar input to array for consistent processing - scalar_input = np.isscalar(t) - t_array = np.atleast_1d(t) - result = np.ones_like(t_array) - - # For each time point, calculate the survival function - for i, ti in enumerate(t_array): - if ti <= 0: - continue # Survival probability is 1 at time 0 or earlier - - cumulative_hazard = 0.0 - - # First interval: [0, min(ti, breakpoints[0])) - first_interval_end = min(ti, breakpoints[0]) if breakpoints else ti - cumulative_hazard += hazard_rates[0] * first_interval_end - - if ti <= breakpoints[0]: - result[i] = np.exp(-cumulative_hazard) - continue - - # Middle intervals: [breakpoints[j-1], min(ti, breakpoints[j])) - for j in range(1, len(breakpoints)): - if ti <= breakpoints[j - 1]: - break - - interval_start = breakpoints[j - 1] - interval_end = min(ti, breakpoints[j]) - interval_width = interval_end - interval_start - - cumulative_hazard += hazard_rates[j] * interval_width - - if ti <= breakpoints[j]: - break - - # Last interval: [breakpoints[-1], ti) - if ti > breakpoints[-1]: - last_interval_width = ti - breakpoints[-1] - cumulative_hazard += hazard_rates[-1] * last_interval_width - - result[i] = np.exp(-cumulative_hazard) - - return result[0] if scalar_input else result diff --git a/gen_surv/sklearn_adapter.py b/gen_surv/sklearn_adapter.py deleted file mode 100644 index de6e180..0000000 --- a/gen_surv/sklearn_adapter.py +++ /dev/null @@ -1,68 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Optional, Protocol - -import pandas as pd - -from .interface import ModelType, generate -from .validation import ensure_in_choices - - -class BaseEstimatorProto(Protocol): - """Protocol capturing the minimal scikit-learn estimator interface.""" - - def get_params(self, deep: bool = ...) -> dict[str, object]: ... - - def set_params(self, **params: object) -> "BaseEstimatorProto": ... 
- - -if TYPE_CHECKING: # pragma: no cover - import for type checkers only - from sklearn.base import BaseEstimator as SklearnBase -else: # pragma: no cover - runtime import with fallback - try: - from sklearn.base import BaseEstimator as SklearnBase - except Exception: - - class SklearnBase: # noqa: D401 - simple runtime stub - """Minimal stub if scikit-learn is not installed.""" - - def get_params(self, deep: bool = True) -> dict[str, object]: - return {} - - def set_params(self, **params: object) -> "SklearnBase": - return self - - -class GenSurvDataGenerator(SklearnBase, BaseEstimatorProto): - """Scikit-learn compatible wrapper around :func:`gen_surv.generate`.""" - - def __init__( - self, model: ModelType, return_type: str = "df", **kwargs: object - ) -> None: - ensure_in_choices(return_type, "return_type", {"df", "dict"}) - self.model = model - self.return_type = return_type - self.kwargs = kwargs - - def fit( - self, X: Optional[pd.DataFrame] = None, y: Optional[pd.Series] = None - ) -> "GenSurvDataGenerator": - return self - - def transform( - self, X: Optional[pd.DataFrame] = None - ) -> pd.DataFrame | dict[str, list[object]]: - df = generate(self.model, **self.kwargs) - if self.return_type == "df": - return df - if self.return_type == "dict": - return df.to_dict(orient="list") - raise AssertionError("Unreachable due to validation") - - def fit_transform( - self, - X: Optional[pd.DataFrame] = None, - y: Optional[pd.Series] = None, - **fit_params: object, - ) -> pd.DataFrame | dict[str, list[object]]: - return self.fit(X, y).transform(X) diff --git a/gen_surv/summary.py b/gen_surv/summary.py deleted file mode 100644 index 1150793..0000000 --- a/gen_surv/summary.py +++ /dev/null @@ -1,508 +0,0 @@ -""" -Utilities for summarizing and validating survival datasets. - -This module provides functions to summarize survival data, -check data quality, and identify potential issues. -""" - -from typing import Any - -import pandas as pd - -from .validation import ParameterError - - -def summarize_survival_dataset( - data: pd.DataFrame, - time_col: str = "time", - status_col: str = "status", - id_col: str | None = None, - covariate_cols: list[str] | None = None, - verbose: bool = True, -) -> dict[str, Any]: - """ - Generate a comprehensive summary of a survival dataset. - - Parameters - ---------- - data : pd.DataFrame - DataFrame containing survival data. - time_col : str, default="time" - Name of the column containing time-to-event values. - status_col : str, default="status" - Name of the column containing event indicators (1=event, 0=censored). - id_col : str, optional - Name of the column containing subject identifiers. - covariate_cols : list of str, optional - List of column names to include as covariates in the summary. - If None, all columns except time_col, status_col, and id_col are considered. - verbose : bool, default=True - Whether to print the summary to console. - - Returns - ------- - dict[str, Any] - Dictionary containing all summary statistics. - - Examples - -------- - >>> from gen_surv import generate - >>> from gen_surv.summary import summarize_survival_dataset - >>> - >>> # Generate example data - >>> df = generate(model="cphm", n=100, model_cens="uniform", - ... 
cens_par=1.0, beta=0.5, covariate_range=2.0) - >>> - >>> # Summarize the dataset - >>> summary = summarize_survival_dataset(df) - """ - # Validate input columns - for col in [time_col, status_col]: - if col not in data.columns: - raise ParameterError("column", col, "not found in data") - - if id_col is not None and id_col not in data.columns: - raise ParameterError("id_col", id_col, "not found in data") - - # Determine covariate columns - if covariate_cols is None: - exclude_cols = {time_col, status_col} - if id_col is not None: - exclude_cols.add(id_col) - covariate_cols = [col for col in data.columns if col not in exclude_cols] - else: - missing_cols = [col for col in covariate_cols if col not in data.columns] - if missing_cols: - raise ParameterError("covariate_cols", missing_cols, "not found in data") - - # Basic dataset information - n_subjects = len(data) - if id_col is not None: - n_unique_ids = data[id_col].nunique() - else: - n_unique_ids = n_subjects - - # Event information - n_events = data[status_col].sum() - n_censored = n_subjects - n_events - event_rate = n_events / n_subjects - - # Time statistics - time_min = data[time_col].min() - time_max = data[time_col].max() - time_mean = data[time_col].mean() - time_median = data[time_col].median() - - # Data quality checks - n_missing_time = data[time_col].isna().sum() - n_missing_status = data[status_col].isna().sum() - n_negative_time = (data[time_col] < 0).sum() - n_invalid_status = data[~data[status_col].isin([0, 1])].shape[0] - - # Covariate summaries - covariate_stats = {} - for col in covariate_cols: - col_data = data[col] - is_numeric = pd.api.types.is_numeric_dtype(col_data) - - if is_numeric: - covariate_stats[col] = { - "type": "numeric", - "min": col_data.min(), - "max": col_data.max(), - "mean": col_data.mean(), - "median": col_data.median(), - "std": col_data.std(), - "missing": col_data.isna().sum(), - "unique_values": col_data.nunique(), - } - else: - # Categorical/string - covariate_stats[col] = { - "type": "categorical", - "n_categories": col_data.nunique(), - "top_categories": col_data.value_counts().head(5).to_dict(), - "missing": col_data.isna().sum(), - } - - # Compile the summary - summary = { - "dataset_info": { - "n_subjects": n_subjects, - "n_unique_ids": n_unique_ids, - "n_covariates": len(covariate_cols), - }, - "event_info": { - "n_events": n_events, - "n_censored": n_censored, - "event_rate": event_rate, - }, - "time_info": { - "min": time_min, - "max": time_max, - "mean": time_mean, - "median": time_median, - }, - "data_quality": { - "missing_time": n_missing_time, - "missing_status": n_missing_status, - "negative_time": n_negative_time, - "invalid_status": n_invalid_status, - "overall_quality": ( - "good" - if ( - n_missing_time - + n_missing_status - + n_negative_time - + n_invalid_status - ) - == 0 - else "issues_detected" - ), - }, - "covariates": covariate_stats, - } - - # Print summary if requested - if verbose: - _print_summary(summary, time_col, status_col, id_col, covariate_cols) - - return summary - - -def check_survival_data_quality( - data: pd.DataFrame, - time_col: str = "time", - status_col: str = "status", - id_col: str | None = None, - min_time: float = 0.0, - max_time: float | None = None, - status_values: list[int] | None = None, - fix_issues: bool = False, -) -> tuple[pd.DataFrame, dict[str, Any]]: - """ - Check for common issues in survival data and optionally fix them. - - Parameters - ---------- - data : pd.DataFrame - DataFrame containing survival data. 
-    time_col : str, default="time"
-        Name of the column containing time-to-event values.
-    status_col : str, default="status"
-        Name of the column containing event indicators.
-    id_col : str, optional
-        Name of the column containing subject identifiers.
-    min_time : float, default=0.0
-        Minimum acceptable value for time column.
-    max_time : float, optional
-        Maximum acceptable value for time column.
-    status_values : list of int, optional
-        List of valid status values. Default is [0, 1].
-    fix_issues : bool, default=False
-        Whether to attempt fixing issues (returns a modified DataFrame).
-
-    Returns
-    -------
-    tuple[pd.DataFrame, dict[str, Any]]
-        Tuple containing (possibly fixed) DataFrame and issues report.
-
-    Examples
-    --------
-    >>> import numpy as np
-    >>> from gen_surv import generate
-    >>> from gen_surv.summary import check_survival_data_quality
-    >>>
-    >>> # Generate example data with some issues
-    >>> df = generate(model="cphm", n=100, model_cens="uniform",
-    ...               cens_par=1.0, beta=0.5, covariate_range=2.0)
-    >>> # Introduce some issues
-    >>> df.loc[0, "time"] = np.nan
-    >>> df.loc[1, "status"] = 2  # Invalid status
-    >>>
-    >>> # Check and fix issues
-    >>> fixed_df, issues = check_survival_data_quality(df, fix_issues=True)
-    >>> print(issues)
-    """
-    if status_values is None:
-        status_values = [0, 1]
-
-    # Make a copy to avoid modifying the original
-    if fix_issues:
-        data = data.copy()
-
-    # Initialize issues report
-    issues: dict[str, dict[str, int | None]] = {
-        "missing_data": {"time": 0, "status": 0, "id": 0 if id_col else None},
-        "invalid_values": {
-            "negative_time": 0,
-            "excessive_time": 0,
-            "invalid_status": 0,
-        },
-        "duplicates": {"duplicate_rows": 0, "duplicate_ids": 0 if id_col else None},
-        "modifications": {"rows_dropped": 0, "values_fixed": 0},
-    }
-
-    # Check for missing values
-    issues["missing_data"]["time"] = data[time_col].isna().sum()
-    issues["missing_data"]["status"] = data[status_col].isna().sum()
-    if id_col:
-        issues["missing_data"]["id"] = data[id_col].isna().sum()
-
-    # Check for invalid values
-    issues["invalid_values"]["negative_time"] = (data[time_col] < min_time).sum()
-    if max_time is not None:
-        issues["invalid_values"]["excessive_time"] = (data[time_col] > max_time).sum()
-    issues["invalid_values"]["invalid_status"] = data[
-        ~data[status_col].isin(status_values)
-    ].shape[0]
-
-    # Check for duplicates
-    issues["duplicates"]["duplicate_rows"] = data.duplicated().sum()
-    if id_col:
-        issues["duplicates"]["duplicate_ids"] = data[id_col].duplicated().sum()
-
-    # Fix issues if requested
-    if fix_issues:
-        original_rows = len(data)
-        modified_values = 0
-
-        # Handle missing values
-        data = data.dropna(subset=[time_col, status_col])
-
-        # Handle invalid values
-        if min_time > 0:
-            # Set negative or too small times to min_time
-            mask = data[time_col] < min_time
-            if mask.any():
-                data.loc[mask, time_col] = min_time
-                modified_values += mask.sum()
-
-        if max_time is not None:
-            # Cap excessively large times
-            mask = data[time_col] > max_time
-            if mask.any():
-                data.loc[mask, time_col] = max_time
-                modified_values += mask.sum()
-
-        # Fix invalid status values
-        mask = ~data[status_col].isin(status_values)
-        if mask.any():
-            # Default to censored (0) for invalid status
-            data.loc[mask, status_col] = 0
-            modified_values += mask.sum()
-
-        # Remove duplicates
-        data = data.drop_duplicates()
-
-        # Update modification counts
-        issues["modifications"]["rows_dropped"] = original_rows - len(data)
-        issues["modifications"]["values_fixed"] = modified_values
-
-    return data, issues
-
-
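The two helpers above are meant to be chained: repair first, then summarize. A minimal sketch of that flow, assuming a checkout where gen_surv.summary still exists (i.e. prior to this revert) and the pre-revert CPHM keyword covariate_range:

    import numpy as np
    from gen_surv import generate
    from gen_surv.summary import check_survival_data_quality, summarize_survival_dataset

    # Generate a clean CPHM dataset, then corrupt two rows on purpose.
    df = generate(model="cphm", n=100, model_cens="uniform",
                  cens_par=1.0, beta=0.5, covariate_range=2.0)
    df.loc[0, "time"] = np.nan  # missing time
    df.loc[1, "status"] = 2     # invalid status code

    # Repair the issues, then summarize the cleaned data.
    fixed_df, issues = check_survival_data_quality(df, fix_issues=True)
    print(issues["modifications"])  # expect {'rows_dropped': 1, 'values_fixed': 1}
    summary = summarize_survival_dataset(fixed_df, verbose=False)
    print(summary["data_quality"]["overall_quality"])  # expect "good"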
-def _print_summary(
-    summary: dict[str, Any],
-    time_col: str,
-    status_col: str,
-    id_col: str | None,
-    covariate_cols: list[str],
-) -> None:
-    """
-    Print a formatted summary of survival data.
-
-    Parameters
-    ----------
-    summary : dict[str, Any]
-        Summary dictionary from summarize_survival_dataset.
-    time_col : str
-        Name of the time column.
-    status_col : str
-        Name of the status column.
-    id_col : str, optional
-        Name of the ID column.
-    covariate_cols : list[str]
-        List of covariate column names.
-    """
-    print("=" * 60)
-    print("SURVIVAL DATASET SUMMARY")
-    print("=" * 60)
-
-    # Dataset info
-    print("\nDATASET INFORMATION:")
-    print(f"  Subjects: {summary['dataset_info']['n_subjects']}")
-    if id_col:
-        print(f"  Unique IDs: {summary['dataset_info']['n_unique_ids']}")
-    print(f"  Covariates: {summary['dataset_info']['n_covariates']}")
-
-    # Event info
-    print("\nEVENT INFORMATION:")
-    print(
-        f"  Events: {summary['event_info']['n_events']} "
-        + f"({summary['event_info']['event_rate']:.1%})"
-    )
-    print(
-        f"  Censored: {summary['event_info']['n_censored']} "
-        + f"({1 - summary['event_info']['event_rate']:.1%})"
-    )
-
-    # Time info
-    print(f"\nTIME VARIABLE ({time_col}):")
-    print(
-        f"  Range: {summary['time_info']['min']:.2f} to {summary['time_info']['max']:.2f}"
-    )
-    print(f"  Mean: {summary['time_info']['mean']:.2f}")
-    print(f"  Median: {summary['time_info']['median']:.2f}")
-
-    # Data quality
-    print("\nDATA QUALITY:")
-    quality_issues = (
-        summary["data_quality"]["missing_time"]
-        + summary["data_quality"]["missing_status"]
-        + summary["data_quality"]["negative_time"]
-        + summary["data_quality"]["invalid_status"]
-    )
-
-    if quality_issues == 0:
-        print("  ✓ No issues detected")
-    else:
-        print("  ✗ Issues detected:")
-        if summary["data_quality"]["missing_time"] > 0:
-            print(
-                f"    - Missing time values: {summary['data_quality']['missing_time']}"
-            )
-        if summary["data_quality"]["missing_status"] > 0:
-            print(
-                f"    - Missing status values: {summary['data_quality']['missing_status']}"
-            )
-        if summary["data_quality"]["negative_time"] > 0:
-            print(
-                f"    - Negative time values: {summary['data_quality']['negative_time']}"
-            )
-        if summary["data_quality"]["invalid_status"] > 0:
-            print(
-                f"    - Invalid status values: {summary['data_quality']['invalid_status']}"
-            )
-
-    # Covariates
-    print("\nCOVARIATES:")
-    if not covariate_cols:
-        print("  No covariates found")
-    else:
-        for col, stats in summary["covariates"].items():
-            print(f"  {col}:")
-            if stats["type"] == "numeric":
-                print("    Type: Numeric")
-                print(f"    Range: {stats['min']:.2f} to {stats['max']:.2f}")
-                print(f"    Mean: {stats['mean']:.2f}")
-                print(f"    Missing: {stats['missing']}")
-            else:
-                print("    Type: Categorical")
-                print(f"    Categories: {stats['n_categories']}")
-                print(f"    Missing: {stats['missing']}")
-
-    print("\n" + "=" * 60)
-
-
-def compare_survival_datasets(
-    datasets: dict[str, pd.DataFrame],
-    time_col: str = "time",
-    status_col: str = "status",
-    covariate_cols: list[str] | None = None,
-) -> pd.DataFrame:
-    """
-    Compare multiple survival datasets and summarize their differences.
-
-    Parameters
-    ----------
-    datasets : dict[str, pd.DataFrame]
-        Dictionary mapping dataset names to DataFrames.
-    time_col : str, default="time"
-        Name of the time column in each dataset.
-    status_col : str, default="status"
-        Name of the status column in each dataset.
-    covariate_cols : list[str], optional
-        List of covariate columns to compare. If None, compares all common columns.
- - Returns - ------- - pd.DataFrame - Comparison table with datasets as columns and metrics as rows. - - Examples - -------- - >>> from gen_surv import generate - >>> from gen_surv.summary import compare_survival_datasets - >>> - >>> # Generate datasets with different parameters - >>> datasets = { - ... "CPHM": generate(model="cphm", n=100, model_cens="uniform", - ... cens_par=1.0, beta=0.5, covariate_range=2.0), - ... "Weibull AFT": generate(model="aft_weibull", n=100, beta=[0.5], - ... shape=1.5, scale=1.0, model_cens="uniform", cens_par=1.0) - ... } - >>> - >>> # Compare datasets - >>> comparison = compare_survival_datasets(datasets) - >>> print(comparison) - """ - if not datasets: - raise ParameterError("datasets", datasets, "at least one dataset is required") - - # Find common columns if covariate_cols not specified - if covariate_cols is None: - all_columns = [set(df.columns) for df in datasets.values()] - common_columns = set.intersection(*all_columns) - common_columns -= {time_col, status_col} # Remove time and status - covariate_cols = sorted(list(common_columns)) - - # Calculate summaries for each dataset - summaries = {} - for name, data in datasets.items(): - summaries[name] = summarize_survival_dataset( - data, time_col, status_col, covariate_cols=covariate_cols, verbose=False - ) - - # Construct the comparison DataFrame - comparison_data = {} - - # Dataset info - comparison_data["n_subjects"] = { - name: summary["dataset_info"]["n_subjects"] - for name, summary in summaries.items() - } - comparison_data["n_events"] = { - name: summary["event_info"]["n_events"] for name, summary in summaries.items() - } - comparison_data["event_rate"] = { - name: summary["event_info"]["event_rate"] for name, summary in summaries.items() - } - - # Time info - comparison_data["time_min"] = { - name: summary["time_info"]["min"] for name, summary in summaries.items() - } - comparison_data["time_max"] = { - name: summary["time_info"]["max"] for name, summary in summaries.items() - } - comparison_data["time_mean"] = { - name: summary["time_info"]["mean"] for name, summary in summaries.items() - } - comparison_data["time_median"] = { - name: summary["time_info"]["median"] for name, summary in summaries.items() - } - - # Covariate info (means for numeric) - for col in covariate_cols: - for name, summary in summaries.items(): - if col in summary["covariates"]: - col_stats = summary["covariates"][col] - if col_stats["type"] == "numeric": - if f"{col}_mean" not in comparison_data: - comparison_data[f"{col}_mean"] = {} - comparison_data[f"{col}_mean"][name] = col_stats["mean"] - - # Create the DataFrame - comparison_df = pd.DataFrame(comparison_data).T - - return comparison_df diff --git a/gen_surv/tdcm.py b/gen_surv/tdcm.py index a92e24c..b349137 100644 --- a/gen_surv/tdcm.py +++ b/gen_surv/tdcm.py @@ -1,74 +1,55 @@ -from typing import Sequence - import numpy as np import pandas as pd -from numpy.typing import NDArray - +from gen_surv.validate import validate_gen_tdcm_inputs from gen_surv.bivariate import sample_bivariate_distribution -from gen_surv.censoring import CensoringFunc, rexpocens, runifcens -from gen_surv.validation import validate_gen_tdcm_inputs - +from gen_surv.censoring import runifcens, rexpocens -def generate_censored_observations( - n: int, - dist_par: Sequence[float], - model_cens: str, - cens_par: float, - beta: Sequence[float], - lam: float, - b: NDArray[np.float64], -) -> NDArray[np.float64]: +def generate_censored_observations(n, dist_par, model_cens, cens_par, beta, lam, b): """ 
Generate censored TDCM observations. Parameters: - - - n (int): Number of individuals - - dist_par (list): Not directly used here (kept for API compatibility) - - model_cens (str): "uniform" or "exponential" - - cens_par (float): Parameter for the censoring model - - beta (list): Length-2 list of regression coefficients - - lam (float): Rate parameter - - b (np.ndarray): Covariate matrix with 2 columns [., z1] + - n (int): Number of individuals + - dist_par (list): Not directly used here (kept for API compatibility) + - model_cens (str): "uniform" or "exponential" + - cens_par (float): Parameter for the censoring model + - beta (list): Length-2 list of regression coefficients + - lam (float): Rate parameter + - b (np.ndarray): Covariate matrix with 2 columns [., z1] Returns: - - np.ndarray: Shape (n, 6) with columns: - [id, start, stop, status, covariate1 (z1), covariate2 (z2)] + - np.ndarray: Shape (n, 6) with columns: + [id, start, stop, status, covariate1 (z1), covariate2 (z2)] """ - rfunc: CensoringFunc = runifcens if model_cens == "uniform" else rexpocens + rfunc = runifcens if model_cens == "uniform" else rexpocens + observations = np.zeros((n, 6)) + + for k in range(n): + z1 = b[k, 1] + c = rfunc(1, cens_par)[0] + u = np.random.uniform() - z1 = b[:, 1] - x = lam * b[:, 0] * np.exp(beta[0] * z1) - u = np.random.uniform(size=n) - c = rfunc(n, cens_par) + # Determine path based on u threshold + threshold = 1 - np.exp(-lam * b[k, 0] * np.exp(beta[0] * z1)) + if u < threshold: + t = -np.log(1 - u) / (lam * np.exp(beta[0] * z1)) + z2 = 0 + else: + t = ( + -np.log(1 - u) + + lam * b[k, 0] * np.exp(beta[0] * z1) * (1 - np.exp(beta[1])) + ) / (lam * np.exp(beta[0] * z1 + beta[1])) + z2 = 1 - threshold = 1 - np.exp(-x) - exp_b0_z1 = np.exp(beta[0] * z1) - log_term = -np.log(1 - u) - t1 = log_term / (lam * exp_b0_z1) - t2 = (log_term + x * (1 - np.exp(beta[1]))) / (lam * np.exp(beta[0] * z1 + beta[1])) - mask = u < threshold - t = np.where(mask, t1, t2) - z2 = (~mask).astype(float) + time = min(t, c) + status = int(t <= c) - time = np.minimum(t, c) - status = (t <= c).astype(float) + observations[k] = [k + 1, 0, time, status, z1, z2] - ids = np.arange(1, n + 1, dtype=float) - zeros = np.zeros(n, dtype=float) - return np.column_stack((ids, zeros, time, status, z1, z2)) + return observations -def gen_tdcm( - n: int, - dist: str, - corr: float, - dist_par: Sequence[float], - model_cens: str, - cens_par: float, - beta: Sequence[float], - lam: float, -) -> pd.DataFrame: +def gen_tdcm(n, dist, corr, dist_par, model_cens, cens_par, beta, lam): """ Generate TDCM (Time-Dependent Covariate Model) survival data. 
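The branch on u in the restored loop inverts the TDCM survival function piece by piece: draws below the threshold land before the covariate switch (z2 = 0), draws above it land after (z2 = 1). A standalone sketch of that step for one subject, mirroring the formulas above (all parameter values are illustrative, not from the package):

    import numpy as np

    rng = np.random.default_rng(0)
    lam, beta = 0.5, [0.3, 0.7]  # rate and the two coefficients the loop uses
    b0, z1 = 0.8, 1.2            # first column of the bivariate draw, covariate

    u = rng.uniform()
    threshold = 1 - np.exp(-lam * b0 * np.exp(beta[0] * z1))
    if u < threshold:
        # Pre-switch branch: the time-dependent covariate never turns on.
        t = -np.log(1 - u) / (lam * np.exp(beta[0] * z1))
        z2 = 0
    else:
        # Post-switch branch: the hazard term is rescaled by exp(beta[1]).
        t = (-np.log(1 - u) + lam * b0 * np.exp(beta[0] * z1) * (1 - np.exp(beta[1]))) / (
            lam * np.exp(beta[0] * z1 + beta[1])
        )
        z2 = 1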
@@ -90,10 +71,6 @@ def gen_tdcm( # Generate covariate matrix from bivariate distribution b = sample_bivariate_distribution(n, dist, corr, dist_par) - data = generate_censored_observations( - n, dist_par, model_cens, cens_par, beta, lam, b - ) + data = generate_censored_observations(n, dist_par, model_cens, cens_par, beta, lam, b) - return pd.DataFrame( - data, columns=["id", "start", "stop", "status", "covariate", "tdcov"] - ) + return pd.DataFrame(data, columns=["id", "start", "stop", "status", "covariate", "tdcov"]) diff --git a/gen_surv/thmm.py b/gen_surv/thmm.py index 4a832c0..22feebf 100644 --- a/gen_surv/thmm.py +++ b/gen_surv/thmm.py @@ -1,26 +1,9 @@ -from typing import Sequence, TypedDict - import numpy as np import pandas as pd +from gen_surv.validate import validate_gen_thmm_inputs +from gen_surv.censoring import runifcens, rexpocens -from gen_surv.censoring import CensoringFunc, rexpocens, runifcens -from gen_surv.validation import validate_gen_thmm_inputs - - -class TransitionTimes(TypedDict): - c: float - t12: float - t13: float - t23: float - - -def calculate_transitions( - z1: float, - cens_par: float, - beta: Sequence[float], - rate: Sequence[float], - rfunc: CensoringFunc, -) -> TransitionTimes: +def calculate_transitions(z1: float, cens_par: float, beta: list, rate: list, rfunc) -> dict: """ Calculate transition and censoring times for THMM. @@ -46,14 +29,7 @@ def calculate_transitions( return {"c": c, "t12": t12, "t13": t13, "t23": t23} -def gen_thmm( - n: int, - model_cens: str, - cens_par: float, - beta: Sequence[float], - covariate_range: float, - rate: Sequence[float], -) -> pd.DataFrame: +def gen_thmm(n, model_cens, cens_par, beta, covar, rate): """ Generate THMM (Time-Homogeneous Markov Model) survival data. @@ -62,18 +38,18 @@ def gen_thmm( - model_cens (str): "uniform" or "exponential". - cens_par (float): Censoring parameter. - beta (list): Length-3 regression coefficients. - - covariate_range (float): Upper bound for the covariate values. + - covar (float): Covariate upper bound. - rate (list): Length-3 transition rates. Returns: - - pd.DataFrame: Columns = ["id", "time", "state", "X0"] + - pd.DataFrame: Columns = ["id", "time", "state", "covariate"] """ - validate_gen_thmm_inputs(n, model_cens, cens_par, beta, covariate_range, rate) - rfunc: CensoringFunc = runifcens if model_cens == "uniform" else rexpocens + validate_gen_thmm_inputs(n, model_cens, cens_par, beta, covar, rate) + rfunc = runifcens if model_cens == "uniform" else rexpocens records = [] for k in range(n): - z1 = np.random.uniform(0, covariate_range) + z1 = np.random.uniform(0, covar) trans = calculate_transitions(z1, cens_par, beta, rate, rfunc) t12, t13, c = trans["t12"], trans["t13"], trans["c"] @@ -87,4 +63,4 @@ def gen_thmm( records.append([k + 1, time, state, z1]) - return pd.DataFrame(records, columns=["id", "time", "state", "X0"]) + return pd.DataFrame(records, columns=["id", "time", "state", "covariate"]) diff --git a/gen_surv/validate.py b/gen_surv/validate.py index f15f291..427f5c3 100644 --- a/gen_surv/validate.py +++ b/gen_surv/validate.py @@ -1,6 +1,193 @@ -"""Compatibility wrapper for validation utilities. +def validate_gen_cphm_inputs(n: int, model_cens: str, cens_par: float, covar: float): + """ + Validates input parameters for CPHM data generation. -This module re-exports symbols from :mod:`gen_surv.validation`. -""" + Parameters: + - n (int): Number of data points to generate. + - model_cens (str): Censoring model, must be "uniform" or "exponential". 
+    - cens_par (float): Parameter for the censoring model, must be > 0.
+    - covar (float): Covariate value, must be > 0.
-from .validation import *  # noqa: F401,F403
+
+    Raises:
+    - ValueError: If any input is invalid.
+    """
+    if n <= 0:
+        raise ValueError("Argument 'n' must be greater than 0")
+    if model_cens not in {"uniform", "exponential"}:
+        raise ValueError(
+            "Argument 'model_cens' must be one of 'uniform' or 'exponential'")
+    if cens_par <= 0:
+        raise ValueError("Argument 'cens_par' must be greater than 0")
+    if covar <= 0:
+        raise ValueError("Argument 'covar' must be greater than 0")
+
+
+def validate_gen_cmm_inputs(n: int, model_cens: str, cens_par: float, beta: list, covar: float, rate: list):
+    """
+    Validate inputs for generating CMM (Continuous-Time Markov Model) data.
+
+    Parameters:
+    - n (int): Number of individuals.
+    - model_cens (str): Censoring model, must be "uniform" or "exponential".
+    - cens_par (float): Parameter for censoring distribution, must be > 0.
+    - beta (list): Regression coefficients, must have length 3.
+    - covar (float): Covariate value, must be > 0.
+    - rate (list): Transition rates, must have length 6.
+
+    Raises:
+    - ValueError: If any parameter is invalid.
+    """
+    if n <= 0:
+        raise ValueError("Argument 'n' must be greater than 0")
+    if model_cens not in {"uniform", "exponential"}:
+        raise ValueError(
+            "Argument 'model_cens' must be one of 'uniform' or 'exponential'")
+    if cens_par <= 0:
+        raise ValueError("Argument 'cens_par' must be greater than 0")
+    if len(beta) != 3:
+        raise ValueError("Argument 'beta' must be a list of length 3")
+    if covar <= 0:
+        raise ValueError("Argument 'covar' must be greater than 0")
+    if len(rate) != 6:
+        raise ValueError("Argument 'rate' must be a list of length 6")
+
+
+def validate_gen_tdcm_inputs(n: int, dist: str, corr: float, dist_par: list,
+                             model_cens: str, cens_par: float, beta: list, lam: float):
+    """
+    Validate inputs for generating TDCM (Time-Dependent Covariate Model) data.
+
+    Parameters:
+    - n (int): Number of observations.
+    - dist (str): "weibull" or "exponential".
+    - corr (float): Correlation coefficient.
+    - dist_par (list): Distribution parameters.
+    - model_cens (str): "uniform" or "exponential".
+    - cens_par (float): Censoring parameter.
+    - beta (list): Regression coefficients, must have length 3 (the generator uses the first two).
+    - lam (float): Lambda parameter, must be > 0.
+
+    Raises:
+    - ValueError: For any invalid input.
+ """ + if n <= 0: + raise ValueError("Argument 'n' must be greater than 0") + + if dist not in {"weibull", "exponential"}: + raise ValueError( + "Argument 'dist' must be one of 'weibull' or 'exponential'") + + if dist == "weibull": + if not (0 < corr <= 1): + raise ValueError("With dist='weibull', 'corr' must be in (0,1]") + if len(dist_par) != 4 or any(p <= 0 for p in dist_par): + raise ValueError( + "With dist='weibull', 'dist_par' must be a positive list of length 4") + + if dist == "exponential": + if not (-1 <= corr <= 1): + raise ValueError( + "With dist='exponential', 'corr' must be in [-1,1]") + if len(dist_par) != 2 or any(p <= 0 for p in dist_par): + raise ValueError( + "With dist='exponential', 'dist_par' must be a positive list of length 2") + + if model_cens not in {"uniform", "exponential"}: + raise ValueError( + "Argument 'model_cens' must be one of 'uniform' or 'exponential'") + + if cens_par <= 0: + raise ValueError("Argument 'cens_par' must be greater than 0") + + if not isinstance(beta, list) or len(beta) != 3: + raise ValueError("Argument 'beta' must be a list of length 3") + + if lam <= 0: + raise ValueError("Argument 'lambda' must be greater than 0") + + +def validate_gen_thmm_inputs(n: int, model_cens: str, cens_par: float, beta: list, covar: float, rate: list): + """ + Validate inputs for generating THMM (Time-Homogeneous Markov Model) data. + + Parameters: + - n (int): Number of samples, must be > 0. + - model_cens (str): Must be "uniform" or "exponential". + - cens_par (float): Must be > 0. + - beta (list): List of length 3 (regression coefficients). + - covar (float): Positive covariate value. + - rate (list): List of length 3 (transition rates). + + Raises: + - ValueError if any input is invalid. + """ + if not isinstance(n, int) or n <= 0: + raise ValueError("Argument 'n' must be a positive integer.") + + if model_cens not in {"uniform", "exponential"}: + raise ValueError( + "Argument 'model_cens' must be one of 'uniform' or 'exponential'") + + if not isinstance(cens_par, (int, float)) or cens_par <= 0: + raise ValueError("Argument 'cens_par' must be a positive number.") + + if not isinstance(beta, list) or len(beta) != 3: + raise ValueError("Argument 'beta' must be a list of length 3.") + + if not isinstance(covar, (int, float)) or covar <= 0: + raise ValueError("Argument 'covar' must be greater than 0.") + + if not isinstance(rate, list) or len(rate) != 3: + raise ValueError("Argument 'rate' must be a list of length 3.") + + +def validate_dg_biv_inputs(n: int, dist: str, corr: float, dist_par: list): + """ + Validate inputs for the sample_bivariate_distribution function. + + Parameters: + - n (int): Number of samples to generate. + - dist (str): Must be "weibull" or "exponential". + - corr (float): Must be between -1 and 1. + - dist_par (list): Must contain positive values, and correct length for the distribution. + + Raises: + - ValueError if any input is invalid. 
+ """ + if not isinstance(n, int) or n <= 0: + raise ValueError("Argument 'n' must be a positive integer.") + + if dist not in {"weibull", "exponential"}: + raise ValueError("Argument 'dist' must be one of 'weibull' or 'exponential'.") + + if not isinstance(corr, (int, float)) or not (-1 < corr < 1): + raise ValueError("Argument 'corr' must be a numeric value between -1 and 1.") + + if not isinstance(dist_par, list) or len(dist_par) == 0: + raise ValueError("Argument 'dist_par' must be a non-empty list of positive values.") + + if any(p <= 0 for p in dist_par): + raise ValueError("All elements in 'dist_par' must be greater than 0.") + + if dist == "exponential" and len(dist_par) != 2: + raise ValueError("Exponential distribution requires exactly 2 positive parameters.") + + if dist == "weibull" and len(dist_par) != 4: + raise ValueError("Weibull distribution requires exactly 4 positive parameters.") + + +def validate_gen_aft_log_normal_inputs(n, beta, sigma, model_cens, cens_par): + if not isinstance(n, int) or n <= 0: + raise ValueError("n must be a positive integer") + + if not isinstance(beta, (list, tuple)) or not all(isinstance(b, (int, float)) for b in beta): + raise ValueError("beta must be a list of numbers") + + if not isinstance(sigma, (int, float)) or sigma <= 0: + raise ValueError("sigma must be a positive number") + + if model_cens not in ("uniform", "exponential"): + raise ValueError("model_cens must be 'uniform' or 'exponential'") + + if not isinstance(cens_par, (int, float)) or cens_par <= 0: + raise ValueError("cens_par must be a positive number") diff --git a/gen_surv/validation.py b/gen_surv/validation.py deleted file mode 100644 index fad569a..0000000 --- a/gen_surv/validation.py +++ /dev/null @@ -1,415 +0,0 @@ -"""Input validation utilities. - -This module unifies the low-level validation helpers and the higher-level -checks used by the data generators. 
-""" - -from __future__ import annotations - -from collections.abc import Sequence -from numbers import Integral, Real -from typing import Any, Iterable - -import numpy as np -from numpy.typing import NDArray - - -class ValidationError(ValueError): - """Base class for input validation errors.""" - - -class PositiveIntegerError(ValidationError): - """Raised when a value expected to be a positive integer is invalid.""" - - def __init__(self, name: str, value: Any) -> None: - super().__init__( - f"Argument '{name}' must be a positive integer; got {value!r} of type {type(value).__name__}" - ) - - -class PositiveValueError(ValidationError): - """Raised when a value expected to be positive is invalid.""" - - def __init__(self, name: str, value: Any) -> None: - super().__init__( - f"Argument '{name}' must be greater than 0; got {value!r} of type {type(value).__name__}" - ) - - -class ChoiceError(ValidationError): - """Raised when a value is not among an allowed set of choices.""" - - def __init__(self, name: str, value: Any, choices: Iterable[str]) -> None: - choices_str = "', '".join(sorted(choices)) - super().__init__( - f"Argument '{name}' must be one of '{choices_str}'; got {value!r} of type {type(value).__name__}" - ) - - -class LengthError(ValidationError): - """Raised when a sequence does not have the expected length.""" - - def __init__(self, name: str, actual: int, expected: int) -> None: - super().__init__( - f"Argument '{name}' must be a sequence of length {expected}; got length {actual}" - ) - - -class NumericSequenceError(ValidationError): - """Raised when a sequence contains non-numeric elements.""" - - def __init__(self, name: str, value: Any, index: int | None = None) -> None: - if index is None: - super().__init__(f"All elements in '{name}' must be numeric; got {value!r}") - else: - super().__init__( - f"All elements in '{name}' must be numeric; found {value!r} of type {type(value).__name__} at index {index}" - ) - - -class PositiveSequenceError(ValidationError): - """Raised when a sequence contains non-positive elements.""" - - def __init__(self, name: str, value: Any, index: int) -> None: - super().__init__( - f"All elements in '{name}' must be greater than 0; found {value!r} at index {index}" - ) - - -class ListOfListsError(ValidationError): - """Raised when a value is not a list of lists.""" - - def __init__(self, name: str, value: Any) -> None: - super().__init__( - f"Argument '{name}' must be a list of lists; got {value!r} of type {type(value).__name__}" - ) - - -class ParameterError(ValidationError): - """Raised when a parameter falls outside its allowed range.""" - - def __init__(self, name: str, value: Any, constraint: str) -> None: - super().__init__( - f"Invalid value for '{name}': {value!r} (type {type(value).__name__}). 
{constraint}" - ) - - -_ALLOWED_CENSORING = {"uniform", "exponential"} - - -def ensure_positive_int(value: int, name: str) -> None: - """Ensure ``value`` is a positive integer.""" - if not isinstance(value, Integral) or isinstance(value, bool) or value <= 0: - raise PositiveIntegerError(name, value) - - -def ensure_positive(value: float | int, name: str) -> None: - """Ensure ``value`` is a positive number.""" - if not isinstance(value, Real) or isinstance(value, bool) or value <= 0: - raise PositiveValueError(name, value) - - -def ensure_probability(value: float | int, name: str) -> None: - """Ensure ``value`` lies in the closed interval [0, 1].""" - if ( - not isinstance(value, Real) - or isinstance(value, bool) - or not (0 <= float(value) <= 1) - ): - raise ParameterError(name, value, "must be between 0 and 1") - - -def ensure_in_choices(value: str, name: str, choices: Iterable[str]) -> None: - """Ensure ``value`` is one of the allowed options. - - Parameters - ---------- - value: - Value provided by the user. - name: - Name of the argument being validated. Used in error messages. - choices: - Iterable of valid string options. - - Raises - ------ - ChoiceError - If ``value`` is not present in ``choices``. - """ - if value not in choices: - raise ChoiceError(name, value, choices) - - -def ensure_sequence_length(seq: Sequence[Any], length: int, name: str) -> None: - """Ensure a sequence has an expected number of elements. - - Parameters - ---------- - seq: - Sequence-like object (e.g., ``list`` or ``tuple``). - length: - Required number of elements in ``seq``. - name: - Parameter name for error reporting. - - Raises - ------ - LengthError - If ``seq`` does not contain exactly ``length`` elements. - """ - if len(seq) != length: - raise LengthError(name, len(seq), length) - - -def _to_float_array(seq: Sequence[Any], name: str) -> NDArray[np.float64]: - """Convert ``seq`` to a NumPy float64 array or raise an error.""" - try: - arr = np.asarray(seq, dtype=float) - except (TypeError, ValueError) as exc: - for idx, val in enumerate(seq): - if isinstance(val, (bool, np.bool_)) or not isinstance(val, (int, float)): - raise NumericSequenceError(name, val, idx) from exc - raise NumericSequenceError(name, seq) from exc - - for idx, val in enumerate(seq): - if isinstance(val, (bool, np.bool_)): - raise NumericSequenceError(name, val, idx) - - return arr - - -def ensure_numeric_sequence(seq: Sequence[Any], name: str) -> None: - """Validate that a sequence consists solely of numbers. - - Parameters - ---------- - seq: - Sequence whose elements should all be ``int`` or ``float``. - name: - Parameter name for error reporting. - - Raises - ------ - NumericSequenceError - If any element cannot be interpreted as a numeric value. - """ - _to_float_array(seq, name) - - -def ensure_positive_sequence(seq: Sequence[float], name: str) -> None: - """Validate that a sequence contains only positive numbers. - - Parameters - ---------- - seq: - Sequence of numeric values. - name: - Parameter name for error reporting. - - Raises - ------ - PositiveSequenceError - If any element is less than or equal to zero. The offending value and - its index are reported in the error message. - """ - arr = _to_float_array(seq, name) - nonpos = np.where((arr <= 0) | ~np.isfinite(arr))[0] - if nonpos.size: - idx = int(nonpos[0]) - raise PositiveSequenceError(name, seq[idx], idx) - - -def ensure_censoring_model(model_cens: str) -> None: - """Validate that the censoring model is supported. 
- - Parameters - ---------- - model_cens: - Censoring model name provided by the user. - - Raises - ------ - ChoiceError - If ``model_cens`` is not one of ``"uniform"`` or ``"exponential"``. - """ - ensure_in_choices(model_cens, "model_cens", _ALLOWED_CENSORING) - - -# Generator-specific validation helpers - -_BETA_LEN = 3 -_CMM_RATE_LEN = 6 -_THMM_RATE_LEN = 3 -_WEIBULL_DIST_PAR_LEN = 4 -_EXP_DIST_PAR_LEN = 2 - - -def _validate_base(n: int, model_cens: str, cens_par: float) -> None: - """Common checks for sample size and censoring model.""" - ensure_positive_int(n, "n") - ensure_censoring_model(model_cens) - ensure_positive(cens_par, "cens_par") - - -def _validate_beta(beta: Sequence[float]) -> None: - """Ensure beta is a numeric sequence of length three.""" - ensure_sequence_length(beta, _BETA_LEN, "beta") - ensure_numeric_sequence(beta, "beta") - - -def _validate_aft_common( - n: int, beta: Sequence[float], model_cens: str, cens_par: float -) -> None: - """Shared validation logic for AFT generators.""" - _validate_base(n, model_cens, cens_par) - ensure_numeric_sequence(beta, "beta") - - -def validate_gen_cphm_inputs( - n: int, model_cens: str, cens_par: float, covariate_range: float -) -> None: - """Validate input parameters for CPHM data generation.""" - _validate_base(n, model_cens, cens_par) - ensure_positive(covariate_range, "covariate_range") - - -def validate_gen_cmm_inputs( - n: int, - model_cens: str, - cens_par: float, - beta: Sequence[float], - covariate_range: float, - rate: Sequence[float], -) -> None: - """Validate inputs for generating CMM (Continuous-Time Markov Model) data.""" - _validate_base(n, model_cens, cens_par) - _validate_beta(beta) - ensure_positive(covariate_range, "covariate_range") - ensure_sequence_length(rate, _CMM_RATE_LEN, "rate") - - -def validate_gen_tdcm_inputs( - n: int, - dist: str, - corr: float, - dist_par: Sequence[float], - model_cens: str, - cens_par: float, - beta: Sequence[float], - lam: float, -) -> None: - """Validate inputs for generating TDCM (Time-Dependent Covariate Model) data.""" - _validate_base(n, model_cens, cens_par) - ensure_in_choices(dist, "dist", {"weibull", "exponential"}) - - if dist == "weibull": - if not (0 < corr <= 1): - raise ParameterError("corr", corr, "with dist='weibull' must be in (0,1]") - ensure_sequence_length(dist_par, _WEIBULL_DIST_PAR_LEN, "dist_par") - ensure_positive_sequence(dist_par, "dist_par") - - if dist == "exponential": - if not (-1 <= corr <= 1): - raise ParameterError( - "corr", corr, "with dist='exponential' must be in [-1,1]" - ) - ensure_sequence_length(dist_par, _EXP_DIST_PAR_LEN, "dist_par") - ensure_positive_sequence(dist_par, "dist_par") - - _validate_beta(beta) - ensure_positive(lam, "lambda") - - -def validate_gen_thmm_inputs( - n: int, - model_cens: str, - cens_par: float, - beta: Sequence[float], - covariate_range: float, - rate: Sequence[float], -) -> None: - """Validate inputs for generating THMM (Time-Homogeneous Markov Model) data.""" - _validate_base(n, model_cens, cens_par) - _validate_beta(beta) - ensure_positive(covariate_range, "covariate_range") - ensure_sequence_length(rate, _THMM_RATE_LEN, "rate") - - -def validate_dg_biv_inputs( - n: int, dist: str, corr: float, dist_par: Sequence[float] -) -> None: - """Validate inputs for the :func:`sample_bivariate_distribution` helper.""" - ensure_positive_int(n, "n") - ensure_in_choices(dist, "dist", {"weibull", "exponential"}) - - if not isinstance(corr, (int, float)) or not (-1 < corr < 1): - raise ParameterError("corr", corr, "must be 
a numeric value between -1 and 1") - - ensure_positive_sequence(dist_par, "dist_par") - if dist == "exponential": - ensure_sequence_length(dist_par, _EXP_DIST_PAR_LEN, "dist_par") - if dist == "weibull": - ensure_sequence_length(dist_par, _WEIBULL_DIST_PAR_LEN, "dist_par") - - -def validate_gen_aft_log_normal_inputs( - n: int, - beta: Sequence[float], - sigma: float, - model_cens: str, - cens_par: float, -) -> None: - """Validate parameters for the log-normal AFT generator.""" - _validate_aft_common(n, beta, model_cens, cens_par) - ensure_positive(sigma, "sigma") - - -def validate_gen_aft_weibull_inputs( - n: int, - beta: Sequence[float], - shape: float, - scale: float, - model_cens: str, - cens_par: float, -) -> None: - """Validate parameters for the Weibull AFT generator.""" - _validate_aft_common(n, beta, model_cens, cens_par) - ensure_positive(shape, "shape") - ensure_positive(scale, "scale") - - -def validate_gen_aft_log_logistic_inputs( - n: int, - beta: Sequence[float], - shape: float, - scale: float, - model_cens: str, - cens_par: float, -) -> None: - """Validate parameters for the log-logistic AFT generator.""" - _validate_aft_common(n, beta, model_cens, cens_par) - ensure_positive(shape, "shape") - ensure_positive(scale, "scale") - - -def validate_competing_risks_inputs( - n: int, - n_risks: int, - baseline_hazards: Sequence[float] | None, - betas: Sequence[Sequence[float]] | None, - model_cens: str, - cens_par: float, -) -> None: - """Validate parameters for competing risks data generation.""" - _validate_base(n, model_cens, cens_par) - ensure_positive_int(n_risks, "n_risks") - - if baseline_hazards is not None: - ensure_sequence_length(baseline_hazards, n_risks, "baseline_hazards") - ensure_positive_sequence(baseline_hazards, "baseline_hazards") - - if betas is not None: - if not isinstance(betas, list) or any(not isinstance(b, list) for b in betas): - raise ListOfListsError("betas", betas) - for b in betas: - ensure_numeric_sequence(b, "betas") diff --git a/gen_surv/visualization.py b/gen_surv/visualization.py deleted file mode 100644 index 3a80c11..0000000 --- a/gen_surv/visualization.py +++ /dev/null @@ -1,375 +0,0 @@ -""" -Visualization utilities for survival data. - -This module provides functions to visualize survival data generated by -gen_surv, -including Kaplan-Meier survival curves and other commonly used plots in -survival analysis. -""" - -import matplotlib.pyplot as plt -import pandas as pd -from matplotlib.axes import Axes -from matplotlib.figure import Figure - - -def plot_survival_curve( - data: pd.DataFrame, - time_col: str = "time", - status_col: str = "status", - group_col: str | None = None, - confidence_intervals: bool = True, - title: str = "Kaplan-Meier Survival Curve", - figsize: tuple[float, float] = (10, 6), - ci_alpha: float = 0.2, -) -> tuple[Figure, Axes]: - """ - Plot Kaplan-Meier survival curves from simulated data. - - Parameters - ---------- - data : pd.DataFrame - DataFrame containing the survival data. - time_col : str, default="time" - Name of the column containing event/censoring times. - status_col : str, default="status" - Name of the column containing event indicators (1=event, 0=censored). - group_col : str, optional - Name of the column to use for stratification (creates separate curves). - confidence_intervals : bool, default=True - Whether to display confidence intervals around the survival curves. - title : str, default="Kaplan-Meier Survival Curve" - Plot title. 
-    figsize : tuple, default=(10, 6)
-        Figure size (width, height) in inches.
-    ci_alpha : float, default=0.2
-        Transparency level for confidence interval bands.
-
-    Returns
-    -------
-    fig : Figure
-        Matplotlib figure object.
-    ax : Axes
-        Matplotlib axes object.
-
-    Examples
-    --------
-    >>> import matplotlib.pyplot as plt
-    >>> import pandas as pd
-    >>> from gen_surv import generate
-    >>> from gen_surv.visualization import plot_survival_curve
-    >>>
-    >>> # Generate data
-    >>> df = generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0)
-    >>>
-    >>> # Create a categorical group based on covariate
-    >>> df["group"] = pd.cut(df["covariate"], bins=2, labels=["Low", "High"])
-    >>>
-    >>> # Plot survival curves by group
-    >>> fig, ax = plot_survival_curve(df, group_col="group")
-    >>> plt.show()
-    """
-    # Import lifelines here to avoid making it a hard dependency
-    try:
-        from lifelines import KaplanMeierFitter
-        from lifelines.plotting import add_at_risk_counts
-    except ImportError as exc:
-        raise ImportError(
-            "This function requires the lifelines package. "
-            "Install it with: pip install lifelines"
-        ) from exc
-
-    fig, ax = plt.subplots(figsize=figsize)
-
-    # Create separate KM curves for each group (if specified)
-    if group_col is not None:
-        groups = data[group_col].unique()
-        cmap = plt.get_cmap("tab10")
-        colors = [cmap(i) for i in range(len(groups))]
-
-        for i, group in enumerate(groups):
-            mask = data[group_col] == group
-            group_data = data[mask]
-
-            kmf = KaplanMeierFitter()
-            kmf.fit(
-                group_data[time_col],
-                group_data[status_col],
-                label=f"{group_col}={group}",
-            )
-
-            kmf.plot_survival_function(
-                ax=ax, ci_show=confidence_intervals, color=colors[i], ci_alpha=ci_alpha
-            )
-
-        # Add at-risk counts below the plot
-        add_at_risk_counts(kmf, ax=ax)
-    else:
-        # Single KM curve for all data
-        kmf = KaplanMeierFitter()
-        kmf.fit(data[time_col], data[status_col])
-
-        kmf.plot_survival_function(
-            ax=ax, ci_show=confidence_intervals, ci_alpha=ci_alpha
-        )
-
-        # Add at-risk counts below the plot
-        add_at_risk_counts(kmf, ax=ax)
-
-    # Customize plot appearance
-    ax.set_title(title)
-    ax.set_xlabel("Time")
-    ax.set_ylabel("Survival Probability")
-    ax.grid(alpha=0.3)
-    ax.set_ylim(0, 1.05)
-
-    plt.tight_layout()
-    return fig, ax
-
-
-def plot_hazard_comparison(
-    models: dict[str, pd.DataFrame],
-    time_col: str = "time",
-    status_col: str = "status",
-    title: str = "Hazard Function Comparison",
-    figsize: tuple[float, float] = (10, 6),
-    bandwidth: float = 0.5,
-) -> tuple[Figure, Axes]:
-    """
-    Compare hazard functions from multiple generated datasets.
-
-    Parameters
-    ----------
-    models : dict
-        Dictionary mapping model names to their respective DataFrames.
-    time_col : str, default="time"
-        Name of the column containing event/censoring times.
-    status_col : str, default="status"
-        Name of the column containing event indicators (1=event, 0=censored).
-    title : str, default="Hazard Function Comparison"
-        Plot title.
-    figsize : tuple, default=(10, 6)
-        Figure size (width, height) in inches.
-    bandwidth : float, default=0.5
-        Bandwidth parameter for kernel density estimation of the hazard function.
-
-    Returns
-    -------
-    fig : Figure
-        Matplotlib figure object.
-    ax : Axes
-        Matplotlib axes object.
-
-    Examples
-    --------
-    >>> import matplotlib.pyplot as plt
-    >>> from gen_surv import generate
-    >>> from gen_surv.visualization import plot_hazard_comparison
-    >>>
-    >>> # Generate data from multiple models
-    >>> models = {
-    ...     "CPHM": generate(model="cphm", n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0),
-    ...     "AFT Weibull": generate(model="aft_weibull", n=100, beta=[0.5], shape=1.5, scale=2.0,
-    ...                             model_cens="uniform", cens_par=1.0)
-    ... }
-    >>>
-    >>> # Compare hazard functions
-    >>> fig, ax = plot_hazard_comparison(models)
-    >>> plt.show()
-    """
-    # Import lifelines here to avoid making it a hard dependency
-    try:
-        from lifelines import NelsonAalenFitter
-    except ImportError as exc:
-        raise ImportError(
-            "This function requires the lifelines package. "
-            "Install it with: pip install lifelines"
-        ) from exc
-
-    fig, ax = plt.subplots(figsize=figsize)
-
-    for model_name, df in models.items():
-        naf = NelsonAalenFitter()
-        naf.fit(df[time_col], df[status_col])
-
-        # Get smoothed hazard estimate
-        hazard = naf.smoothed_hazard_(bandwidth=bandwidth)
-
-        # Plot hazard function
-        ax.plot(hazard.index, hazard.values, label=model_name, alpha=0.8)
-
-    # Customize plot appearance
-    ax.set_title(title)
-    ax.set_xlabel("Time")
-    ax.set_ylabel("Hazard Rate")
-    ax.grid(alpha=0.3)
-    ax.legend()
-
-    plt.tight_layout()
-    return fig, ax
-
-
-def plot_covariate_effect(
-    data: pd.DataFrame,
-    covariate_col: str,
-    time_col: str = "time",
-    status_col: str = "status",
-    n_groups: int = 3,
-    title: str = "Effect of Covariate on Survival",
-    figsize: tuple[float, float] = (10, 6),
-    ci_alpha: float = 0.2,
-) -> tuple[Figure, Axes]:
-    """
-    Visualize the effect of a continuous covariate on survival by discretizing
-    it.
-
-    Parameters
-    ----------
-    data : pd.DataFrame
-        DataFrame containing the survival data.
-    covariate_col : str
-        Name of the covariate column to visualize.
-    time_col : str, default="time"
-        Name of the column containing event/censoring times.
-    status_col : str, default="status"
-        Name of the column containing event indicators (1=event, 0=censored).
-    n_groups : int, default=3
-        Number of groups to divide the covariate into (e.g., 3 for tertiles).
-    title : str, default="Effect of Covariate on Survival"
-        Plot title.
-    figsize : tuple, default=(10, 6)
-        Figure size (width, height) in inches.
-    ci_alpha : float, default=0.2
-        Transparency level for confidence interval bands.
-
-    Returns
-    -------
-    fig : Figure
-        Matplotlib figure object.
-    ax : Axes
-        Matplotlib axes object.
-
-    Examples
-    --------
-    >>> import matplotlib.pyplot as plt
-    >>> from gen_surv import generate
-    >>> from gen_surv.visualization import plot_covariate_effect
-    >>>
-    >>> # Generate data with a continuous covariate
-    >>> df = generate(model="cphm", n=200, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0)
-    >>>
-    >>> # Visualize the effect of the covariate on survival
-    >>> fig, ax = plot_covariate_effect(df, covariate_col="covariate", n_groups=3)
-    >>> plt.show()
-    """
-    # Add a categorical version of the covariate
-    group_labels = [f"Q{i + 1}" for i in range(n_groups)]
-    data = data.copy()
-    data["_group"] = pd.qcut(data[covariate_col], q=n_groups, labels=group_labels)
-
-    # Get the median value of each group for the legend
-    group_medians = data.groupby("_group")[covariate_col].median()
-
-    # Create more informative labels
-    label_map = {
-        group: f"{group} ({covariate_col}≈{median:.2f})"
-        for group, median in group_medians.items()
-    }
-
-    data["_label"] = data["_group"].map(label_map)
-
-    # Create the plot
-    fig, ax = plot_survival_curve(
-        data=data,
-        time_col=time_col,
-        status_col=status_col,
-        group_col="_label",
-        confidence_intervals=True,
-        title=title,
-        figsize=figsize,
-        ci_alpha=ci_alpha,
-    )
-
-    return fig, ax
-
-
-def describe_survival(
-    data: pd.DataFrame, time_col: str = "time", status_col: str = "status"
-) -> pd.DataFrame:
-    """
-    Generate a summary of survival data including median survival time,
-    event counts, and other descriptive statistics.
-
-    Parameters
-    ----------
-    data : pd.DataFrame
-        DataFrame containing the survival data.
-    time_col : str, default="time"
-        Name of the column containing event/censoring times.
-    status_col : str, default="status"
-        Name of the column containing event indicators (1=event, 0=censored).
-
-    Returns
-    -------
-    pd.DataFrame
-        Summary statistics dataframe.
-
-    Examples
-    --------
-    >>> from gen_surv import generate
-    >>> from gen_surv.visualization import describe_survival
-    >>>
-    >>> # Generate data
-    >>> df = generate(model="cphm", n=200, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0)
-    >>>
-    >>> # Get survival summary
-    >>> summary = describe_survival(df)
-    >>> print(summary)
-    """
-    # Import lifelines here to avoid making it a hard dependency
-    try:
-        from lifelines import KaplanMeierFitter
-    except ImportError as exc:
-        raise ImportError(
-            "This function requires the lifelines package. "
-            "Install it with: pip install lifelines"
-        ) from exc
-
-    n_total = len(data)
-    n_events = data[status_col].sum()
-    n_censored = n_total - n_events
-    event_rate = n_events / n_total
-
-    # Calculate median and other percentiles
-    kmf = KaplanMeierFitter()
-    kmf.fit(data[time_col], data[status_col])
-    median = kmf.median_survival_time_
-
-    # Time ranges
-    time_min = data[time_col].min()
-    time_max = data[time_col].max()
-    time_mean = data[time_col].mean()
-
-    # Create summary DataFrame
-    summary = pd.DataFrame(
-        {
-            "Metric": [
-                "Total Observations",
-                "Number of Events",
-                "Number Censored",
-                "Event Rate",
-                "Median Survival Time",
-                "Min Time",
-                "Max Time",
-                "Mean Time",
-            ],
-            "Value": [
-                n_total,
-                n_events,
-                n_censored,
-                f"{event_rate:.2%}",
-                f"{median:.4f}",
-                f"{time_min:.4f}",
-                f"{time_max:.4f}",
-                f"{time_mean:.4f}",
-            ],
-        }
-    )
-
-    return summary
diff --git a/pyproject.toml b/pyproject.toml
index 3b1c1af..5048a06 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,72 +8,27 @@ readme = "README.md"
 packages = [{ include = "gen_surv" }]
 homepage = "https://github.com/DiogoRibeiro7/genSurvPy"
 repository = "https://github.com/DiogoRibeiro7/genSurvPy"
-documentation = "https://gensurvpy.readthedocs.io/en/latest/"
-keywords = ["survival-analysis", "simulation", "cox-model", "markov-model", "time-dependent", "statistics"]
-classifiers = [
-    "Development Status :: 5 - Production/Stable",
-    "Intended Audience :: Science/Research",
-    "Intended Audience :: Healthcare Industry",
-    "Topic :: Scientific/Engineering :: Medical Science Apps.",
-    "Topic :: Scientific/Engineering :: Mathematics",
-    "Programming Language :: Python :: 3",
-    "Programming Language :: Python :: 3.10",
-    "Programming Language :: Python :: 3.11",
-    "Programming Language :: Python :: 3.12",
-    "License :: OSI Approved :: MIT License",
-]
+documentation = "https://gensurvpy.readthedocs.io/en/stable/"

 [tool.poetry.dependencies]
-python = ">=3.10,<3.13"
+python = "^3.9"
 numpy = "^1.26"
 pandas = "^2.2.3"
+pytest-cov = "^6.1.1"
+invoke = "^2.2.0"
 typer = "^0.12.3"
-matplotlib = "~3.8"
-lifelines = "^0.30"
-pyarrow = "^14"
-pyreadr = "^0.5"
+tomli = "^2.2.1"

 [tool.poetry.group.dev.dependencies]
 pytest = "^8.3.5"
-pytest-cov = "^6.1.1"
 python-semantic-release = "^9.21.0"
 mypy = "^1.15.0"
 invoke = "^2.2.0"
 hypothesis = "^6.98"
 tomli = "^2.2.1"
-black = "^24.1.0"
-isort = "^5.13.2"
-flake8 = "^6.1.0"
-scikit-survival = "^0.24.1"
-pre-commit = "^3.8"
-
-[tool.poetry.extras]
-dev = [
-    "pytest",
-    "pytest-cov",
-    "python-semantic-release",
-    "mypy",
-    "invoke",
-    "hypothesis",
-    "tomli",
-    "black",
-    "isort",
-    "flake8",
-    "scikit-survival",
-    "pre-commit",
-]

 [tool.poetry.group.docs.dependencies]
-sphinx = ">=6.0"
-sphinx-rtd-theme = "^1.3.0"
-myst-parser = ">=1.0.0,<4.0.0"
-sphinx-autodoc-typehints = "^1.24.0"
-sphinx-copybutton = "^0.5.2"
-sphinx-design = "^0.5.0"
-linkify-it-py = ">=2.0"
-
-[tool.poetry.scripts]
-gen_surv = "gen_surv.cli:app"
+myst-parser = "<4.0.0"

 [tool.semantic_release]
 version_source = "tag"
@@ -84,48 +39,6 @@ upload_to_repository = false
 branch = "main"
 build_command = ""

-[tool.black]
-line-length = 88
-target-version = ['py310']
-include = '\.pyi?$'
-
-[tool.isort]
-profile = "black"
-line_length = 88
-
-[tool.flake8]
-max-line-length = 88
-extend-ignore = ["E203", "W503", "E501", "W291", "W293", "W391", "E402", "E302", "E305"]
-
-[tool.mypy]
-python_version = "3.10"
-warn_return_any = true
-warn_unused_configs = true
-disallow_untyped_defs = true
-disallow_incomplete_defs = true
-
-[[tool.mypy.overrides]]
-module = [
-    "typer",
-    "typer.models",
-    "matplotlib",
-    "matplotlib.*",
-    "lifelines",
-    "lifelines.*",
-    "sklearn",
-    "sklearn.*",
-    "numpy",
-    "numpy.*",
-    "pandas",
-    "pandas.*",
-    "scipy",
-    "scipy.*",
-    "pyreadr",
-    "sksurv",
-    "sksurv.*",
-]
-ignore_missing_imports = true
-
 [build-system]
 requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"
diff --git a/scripts/check_version_match.py b/scripts/check_version_match.py
index 36cf8f0..9c3b7dd 100755
--- a/scripts/check_version_match.py
+++ b/scripts/check_version_match.py
@@ -1,9 +1,8 @@
 #!/usr/bin/env python3
 """Check that pyproject version matches the latest git tag. Optionally fix it by tagging."""
+from pathlib import Path
 import subprocess
 import sys
-from pathlib import Path
-from typing import Any, cast

 if sys.version_info >= (3, 11):
     import tomllib as tomli
@@ -12,13 +11,11 @@

 ROOT = Path(__file__).resolve().parents[1]

-
 def pyproject_version() -> str:
     pyproject_path = ROOT / "pyproject.toml"
     with pyproject_path.open("rb") as f:
-        data: Any = tomli.load(f)
-    return cast(str, data["tool"]["poetry"]["version"])
-
+        data = tomli.load(f)
+    return data["tool"]["poetry"]["version"]

 def latest_tag() -> str:
     try:
@@ -29,14 +26,12 @@ def latest_tag() -> str:
     except subprocess.CalledProcessError:
         return ""

-
 def create_tag(version: str) -> None:
     print(f"Tagging repository with version: v{version}")
     subprocess.run(["git", "tag", f"v{version}"], cwd=ROOT, check=True)
     subprocess.run(["git", "push", "origin", f"v{version}"], cwd=ROOT, check=True)
     print(f"✅ Git tag v{version} created and pushed.")

-
 def main() -> int:
     fix = "--fix" in sys.argv
     version = pyproject_version()
@@ -63,6 +58,5 @@ def main() -> int:
     print(f"✔️ Version matches latest tag: {version}")
     return 0

-
 if __name__ == "__main__":
     sys.exit(main())
diff --git a/tasks.py b/tasks.py
index 1a93900..d37bc3a 100644
--- a/tasks.py
+++ b/tasks.py
@@ -1,6 +1,9 @@
+from invoke.tasks import task
+from invoke import Context, task
+from typing import Any
 import shlex

-from invoke import Context, task
+

 @task
@@ -22,10 +25,13 @@ def test(c: Context) -> None:
     # Build the command string. You can adjust '--cov=gen_surv' if you
     # need to cover a different package or add extra pytest flags.
     command = (
-        "poetry run pytest " "--cov=gen_surv " "--cov-report=term " "--cov-report=xml"
+        "poetry run pytest "
+        "--cov=gen_surv "
+        "--cov-report=term "
+        "--cov-report=xml"
     )

-    # Run pytest.
+    # Run pytest.
     # - warn=True: capture non-zero exit codes without aborting Invoke.
     # - pty=False: pytest doesn't require an interactive TTY here.
     result = c.run(command, warn=True, pty=False)
@@ -35,15 +41,9 @@ def test(c: Context) -> None:
         print("✔️ All tests passed.")
     else:
         print("❌ Some tests failed.")
-        exit_code = (
-            result.exited
-            if result is not None and hasattr(result, "exited")
-            else "Unknown"
-        )
+        exit_code = result.exited if result is not None and hasattr(result, "exited") else "Unknown"
         print(f"Exit code: {exit_code}")
-        stderr_output = (
-            result.stderr if result is not None and hasattr(result, "stderr") else None
-        )
+        stderr_output = result.stderr if result is not None and hasattr(result, "stderr") else None
         if stderr_output:
             print("Error output:")
             print(stderr_output)
@@ -76,7 +76,6 @@ def checkversion(c: Context) -> None:
         print("❌ Version mismatch detected.")
         print(result.stderr)

-
 @task
 def docs(c: Context) -> None:
     """Build the Sphinx documentation.
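Every task in this file leans on the same defensive pattern around Invoke's Context.run: pass warn=True so a non-zero exit does not raise, then inspect the Result by hand. A condensed, hypothetical task showing just that pattern (the task name `smoke` and its command are illustrative, not part of this patch):

    from invoke import Context, task

    @task
    def smoke(c: Context) -> None:
        # warn=True: capture a failing exit status instead of raising.
        # pty=False: no interactive TTY is needed for this command.
        result = c.run("poetry run pytest -q", warn=True, pty=False)
        if result is not None and getattr(result, "ok", False):
            print("✔️ Command succeeded.")
        else:
            exit_code = getattr(result, "exited", "Unknown") if result is not None else "Unknown"
            print(f"❌ Command failed (exit code: {exit_code}).")

Tasks defined this way are run by name from the shell, e.g. `invoke test` or `invoke docs`.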
@@ -104,15 +103,9 @@ def docs(c: Context) -> None:
         print("✔️ Documentation built successfully.")
     else:
         print("❌ Documentation build failed.")
-        exit_code = (
-            result.exited
-            if result is not None and hasattr(result, "exited")
-            else "Unknown"
-        )
+        exit_code = result.exited if result is not None and hasattr(result, "exited") else "Unknown"
         print(f"Exit code: {exit_code}")
-        stderr_output = (
-            result.stderr if result is not None and hasattr(result, "stderr") else None
-        )
+        stderr_output = result.stderr if result is not None and hasattr(result, "stderr") else None
         if stderr_output:
             print("Error output:")
             print(stderr_output)
@@ -145,15 +138,9 @@ def stubs(c: Context) -> None:
         print("✔️ Type stubs generated successfully in 'stubs/'.")
     else:
         print("❌ Stub generation failed.")
-        exit_code = (
-            result.exited
-            if result is not None and hasattr(result, "exited")
-            else "Unknown"
-        )
+        exit_code = result.exited if result is not None and hasattr(result, "exited") else "Unknown"
         print(f"Exit code: {exit_code}")
-        stderr_output = (
-            result.stderr if result is not None and hasattr(result, "stderr") else None
-        )
+        stderr_output = result.stderr if result is not None and hasattr(result, "stderr") else None
         if stderr_output:
             print("Error output:")
             print(stderr_output)
@@ -183,25 +170,16 @@ def build(c: Context) -> None:
     # Report the result of the build process.
     if result is not None and getattr(result, "ok", False):
-        print(
-            "✔️ Build completed successfully. Artifacts are in the 'dist/' directory."
-        )
+        print("✔️ Build completed successfully. Artifacts are in the 'dist/' directory.")
     else:
         print("❌ Build failed.")
-        exit_code = (
-            result.exited
-            if result is not None and hasattr(result, "exited")
-            else "Unknown"
-        )
+        exit_code = result.exited if result is not None and hasattr(result, "exited") else "Unknown"
         print(f"Exit code: {exit_code}")
-        stderr_output = (
-            result.stderr if result is not None and hasattr(result, "stderr") else None
-        )
+        stderr_output = result.stderr if result is not None and hasattr(result, "stderr") else None
         if stderr_output:
             print("Error output:")
             print(stderr_output)

-
 @task
 def publish(c: Context) -> None:
     """Build and upload the package to PyPI.
@@ -232,7 +210,6 @@ def publish(c: Context) -> None:
         else:
             print("No stderr output captured.")

-
 @task
 def clean(c: Context) -> None:
     """Remove build artifacts and caches.
@@ -277,8 +254,7 @@ def clean(c: Context) -> None:
         if result.stderr:
             print("Error output:")
             print(result.stderr)
-
-
+
 @task
 def gitpush(c: Context) -> None:
     """Commit and push all staged changes.
@@ -297,17 +273,9 @@ def gitpush(c: Context) -> None:
     result_add = c.run("git add .", warn=True, pty=False)
     if result_add is None or not getattr(result_add, "ok", False):
         print("❌ Failed to stage changes (git add).")
-        exit_code = (
-            result_add.exited
-            if result_add is not None and hasattr(result_add, "exited")
-            else "Unknown"
-        )
+        exit_code = result_add.exited if result_add is not None and hasattr(result_add, "exited") else "Unknown"
         print(f"Exit code: {exit_code}")
-        stderr_output = (
-            result_add.stderr
-            if result_add is not None and hasattr(result_add, "stderr")
-            else None
-        )
+        stderr_output = result_add.stderr if result_add is not None and hasattr(result_add, "stderr") else None
         if stderr_output:
             print("Error output:")
             print(stderr_output)
@@ -345,17 +313,9 @@ def gitpush(c: Context) -> None:
         print("✔️ Changes pushed successfully.")
     else:
         print("❌ Push failed.")
-        exit_code = (
-            getattr(result_push, "exited", "Unknown")
-            if result_push is not None
-            else "Unknown"
-        )
+        exit_code = getattr(result_push, "exited", "Unknown") if result_push is not None else "Unknown"
         print(f"Exit code: {exit_code}")
-        stderr_output = (
-            getattr(result_push, "stderr", None)
-            if result_push is not None
-            else None
-        )
+        stderr_output = getattr(result_push, "stderr", None) if result_push is not None else None
         if stderr_output:
             print("Error output:")
             print(stderr_output)
diff --git a/tests/test_aft.py b/tests/test_aft.py
index 1484432..2688144 100644
--- a/tests/test_aft.py
+++ b/tests/test_aft.py
@@ -1,157 +1,14 @@
-"""
-Tests for Accelerated Failure Time (AFT) models.
-"""
-
 import pandas as pd
-import pytest
-from hypothesis import given
-from hypothesis import strategies as st
-
-from gen_surv.aft import gen_aft_log_logistic, gen_aft_log_normal, gen_aft_weibull
-from gen_surv.validation import ChoiceError, PositiveValueError
-
-
-def test_gen_aft_log_logistic_runs():
-    """Test that the Log-Logistic AFT generator runs without errors."""
-    df = gen_aft_log_logistic(
-        n=10,
-        beta=[0.5, -0.2],
-        shape=1.5,
-        scale=2.0,
-        model_cens="uniform",
-        cens_par=5.0,
-        seed=42,
-    )
-    assert isinstance(df, pd.DataFrame)
-    assert not df.empty
-    assert "time" in df.columns
-    assert "status" in df.columns
-    assert "X0" in df.columns
-    assert "X1" in df.columns
-    assert set(df["status"].unique()).issubset({0, 1})
-
-
-def test_gen_aft_log_logistic_invalid_shape():
-    """Test that the Log-Logistic AFT generator raises error
-    for invalid shape."""
-    with pytest.raises(PositiveValueError):
-        gen_aft_log_logistic(
-            n=10,
-            beta=[0.5, -0.2],
-            shape=-1.0,  # Invalid negative shape
-            scale=2.0,
-            model_cens="uniform",
-            cens_par=5.0,
-        )
-
-
-def test_gen_aft_log_logistic_invalid_scale():
-    """Test that the Log-Logistic AFT generator raises error
-    for invalid scale."""
-    with pytest.raises(PositiveValueError):
-        gen_aft_log_logistic(
-            n=10,
-            beta=[0.5, -0.2],
-            shape=1.5,
-            scale=0.0,  # Invalid zero scale
-            model_cens="uniform",
-            cens_par=5.0,
-        )
-
-
-@given(
-    n=st.integers(min_value=1, max_value=20),
-    shape=st.floats(
-        min_value=0.1, max_value=5.0, allow_nan=False, allow_infinity=False
-    ),
-    scale=st.floats(
-        min_value=0.1, max_value=5.0, allow_nan=False, allow_infinity=False
-    ),
-    cens_par=st.floats(
-        min_value=0.1, max_value=10.0, allow_nan=False, allow_infinity=False
-    ),
-    seed=st.integers(min_value=0, max_value=1000),
-)
-def test_gen_aft_log_logistic_properties(n, shape, scale, cens_par, seed):
-    """Property-based test for the Log-Logistic AFT generator."""
-    df = gen_aft_log_logistic(
-        n=n,
- beta=[0.5, -0.2], - shape=shape, - scale=scale, - model_cens="uniform", - cens_par=cens_par, - seed=seed, - ) - assert df.shape[0] == n - assert set(df["status"].unique()).issubset({0, 1}) - assert (df["time"] >= 0).all() - assert df.filter(regex="^X[0-9]+$").shape[1] == 2 - - -def test_gen_aft_log_logistic_reproducibility(): - """Test that the Log-Logistic AFT generator is reproducible - with the same seed.""" - df1 = gen_aft_log_logistic( - n=10, - beta=[0.5, -0.2], - shape=1.5, - scale=2.0, - model_cens="uniform", - cens_par=5.0, - seed=42, - ) - - df2 = gen_aft_log_logistic( - n=10, - beta=[0.5, -0.2], - shape=1.5, - scale=2.0, - model_cens="uniform", - cens_par=5.0, - seed=42, - ) - - pd.testing.assert_frame_equal(df1, df2) - - df3 = gen_aft_log_logistic( - n=10, - beta=[0.5, -0.2], - shape=1.5, - scale=2.0, - model_cens="uniform", - cens_par=5.0, - seed=43, # Different seed - ) - - with pytest.raises(AssertionError): - pd.testing.assert_frame_equal(df1, df3) - +from gen_surv.aft import gen_aft_log_normal def test_gen_aft_log_normal_runs(): - """Test that the log-normal AFT generator runs without errors.""" df = gen_aft_log_normal( - n=10, beta=[0.5, -0.2], sigma=1.0, model_cens="uniform", cens_par=5.0, seed=42 - ) - assert isinstance(df, pd.DataFrame) - assert not df.empty - assert "time" in df.columns - assert "status" in df.columns - assert "X0" in df.columns - assert "X1" in df.columns - assert set(df["status"].unique()).issubset({0, 1}) - - -def test_gen_aft_weibull_runs(): - """Test that the Weibull AFT generator runs without errors.""" - df = gen_aft_weibull( n=10, beta=[0.5, -0.2], - shape=1.5, - scale=2.0, + sigma=1.0, model_cens="uniform", cens_par=5.0, - seed=42, + seed=42 ) assert isinstance(df, pd.DataFrame) assert not df.empty @@ -159,111 +16,4 @@ def test_gen_aft_weibull_runs(): assert "status" in df.columns assert "X0" in df.columns assert "X1" in df.columns - assert set(df["status"].unique()).issubset({0, 1}) - - -def test_gen_aft_weibull_invalid_shape(): - """Test that the Weibull AFT generator raises error for invalid shape.""" - with pytest.raises(PositiveValueError): - gen_aft_weibull( - n=10, - beta=[0.5, -0.2], - shape=-1.0, # Invalid negative shape - scale=2.0, - model_cens="uniform", - cens_par=5.0, - ) - - -def test_gen_aft_weibull_invalid_scale(): - """Test that the Weibull AFT generator raises error for invalid scale.""" - with pytest.raises(PositiveValueError): - gen_aft_weibull( - n=10, - beta=[0.5, -0.2], - shape=1.5, - scale=0.0, # Invalid zero scale - model_cens="uniform", - cens_par=5.0, - ) - - -def test_gen_aft_weibull_invalid_cens_model(): - """Test that the Weibull AFT generator raises error for invalid censoring model.""" - with pytest.raises(ChoiceError): - gen_aft_weibull( - n=10, - beta=[0.5, -0.2], - shape=1.5, - scale=2.0, - model_cens="invalid", # Invalid censoring model - cens_par=5.0, - ) - - -@given( - n=st.integers(min_value=1, max_value=20), - shape=st.floats( - min_value=0.1, max_value=5.0, allow_nan=False, allow_infinity=False - ), - scale=st.floats( - min_value=0.1, max_value=5.0, allow_nan=False, allow_infinity=False - ), - cens_par=st.floats( - min_value=0.1, max_value=10.0, allow_nan=False, allow_infinity=False - ), - seed=st.integers(min_value=0, max_value=1000), -) -def test_gen_aft_weibull_properties(n, shape, scale, cens_par, seed): - """Property-based test for the Weibull AFT generator.""" - df = gen_aft_weibull( - n=n, - beta=[0.5, -0.2], - shape=shape, - scale=scale, - model_cens="uniform", - cens_par=cens_par, - 
seed=seed, - ) - assert df.shape[0] == n - assert set(df["status"].unique()).issubset({0, 1}) - assert (df["time"] >= 0).all() - assert df.filter(regex="^X[0-9]+$").shape[1] == 2 - - -def test_gen_aft_weibull_reproducibility(): - """Test that the Weibull AFT generator is reproducible with the same seed.""" - df1 = gen_aft_weibull( - n=10, - beta=[0.5, -0.2], - shape=1.5, - scale=2.0, - model_cens="uniform", - cens_par=5.0, - seed=42, - ) - - df2 = gen_aft_weibull( - n=10, - beta=[0.5, -0.2], - shape=1.5, - scale=2.0, - model_cens="uniform", - cens_par=5.0, - seed=42, - ) - - pd.testing.assert_frame_equal(df1, df2) - - df3 = gen_aft_weibull( - n=10, - beta=[0.5, -0.2], - shape=1.5, - scale=2.0, - model_cens="uniform", - cens_par=5.0, - seed=43, # Different seed - ) - - with pytest.raises(AssertionError): - pd.testing.assert_frame_equal(df1, df3) + assert set(df["status"].unique()).issubset({0, 1}) \ No newline at end of file diff --git a/tests/test_aft_property.py b/tests/test_aft_property.py index 73a4061..18a6412 100644 --- a/tests/test_aft_property.py +++ b/tests/test_aft_property.py @@ -1,18 +1,11 @@ -from hypothesis import given -from hypothesis import strategies as st - +from hypothesis import given, strategies as st from gen_surv.aft import gen_aft_log_normal - @given( n=st.integers(min_value=1, max_value=20), - sigma=st.floats( - min_value=0.1, max_value=2.0, allow_nan=False, allow_infinity=False - ), - cens_par=st.floats( - min_value=0.1, max_value=10.0, allow_nan=False, allow_infinity=False - ), - seed=st.integers(min_value=0, max_value=1000), + sigma=st.floats(min_value=0.1, max_value=2.0, allow_nan=False, allow_infinity=False), + cens_par=st.floats(min_value=0.1, max_value=10.0, allow_nan=False, allow_infinity=False), + seed=st.integers(min_value=0, max_value=1000) ) def test_gen_aft_log_normal_properties(n, sigma, cens_par, seed): df = gen_aft_log_normal( @@ -21,7 +14,7 @@ def test_gen_aft_log_normal_properties(n, sigma, cens_par, seed): sigma=sigma, model_cens="uniform", cens_par=cens_par, - seed=seed, + seed=seed ) assert df.shape[0] == n assert set(df["status"].unique()).issubset({0, 1}) diff --git a/tests/test_bivariate.py b/tests/test_bivariate.py index df0bd5d..b011130 100644 --- a/tests/test_bivariate.py +++ b/tests/test_bivariate.py @@ -1,8 +1,6 @@ import numpy as np -import pytest - from gen_surv.bivariate import sample_bivariate_distribution -from gen_surv.validation import ChoiceError, LengthError +import pytest def test_sample_bivariate_exponential_shape(): @@ -13,18 +11,16 @@ def test_sample_bivariate_exponential_shape(): def test_sample_bivariate_invalid_dist(): - """Unsupported distributions should raise ChoiceError.""" - with pytest.raises(ChoiceError): + """Unsupported distributions should raise ValueError.""" + with pytest.raises(ValueError): sample_bivariate_distribution(10, "invalid", 0.0, [1, 1]) - def test_sample_bivariate_exponential_param_length_error(): - """Exponential distribution with wrong param length should raise LengthError.""" - with pytest.raises(LengthError): + """Exponential distribution with wrong param length should raise ValueError.""" + with pytest.raises(ValueError): sample_bivariate_distribution(5, "exponential", 0.0, [1.0]) - def test_sample_bivariate_weibull_param_length_error(): - """Weibull distribution with wrong param length should raise LengthError.""" - with pytest.raises(LengthError): + """Weibull distribution with wrong param length should raise ValueError.""" + with pytest.raises(ValueError): sample_bivariate_distribution(5, 
"weibull", 0.0, [1.0, 1.0]) diff --git a/tests/test_censoring.py b/tests/test_censoring.py deleted file mode 100644 index 33ce3d5..0000000 --- a/tests/test_censoring.py +++ /dev/null @@ -1,72 +0,0 @@ -import numpy as np - -from gen_surv.censoring import ( - GammaCensoring, - LogNormalCensoring, - WeibullCensoring, - rexpocens, - rgammacens, - rlognormcens, - runifcens, - rweibcens, -) - - -def test_runifcens_range(): - times = runifcens(5, 2.0) - assert isinstance(times, np.ndarray) - assert len(times) == 5 - assert np.all(times >= 0) - assert np.all(times <= 2.0) - - -def test_rexpocens_nonnegative(): - times = rexpocens(5, 2.0) - assert isinstance(times, np.ndarray) - assert len(times) == 5 - assert np.all(times >= 0) - - -def test_rweibcens_nonnegative(): - times = rweibcens(5, 1.0, 1.5) - assert isinstance(times, np.ndarray) - assert len(times) == 5 - assert np.all(times >= 0) - - -def test_rlognormcens_positive(): - times = rlognormcens(5, 0.0, 1.0) - assert isinstance(times, np.ndarray) - assert len(times) == 5 - assert np.all(times > 0) - - -def test_rgammacens_positive(): - times = rgammacens(5, 2.0, 1.0) - assert isinstance(times, np.ndarray) - assert len(times) == 5 - assert np.all(times > 0) - - -def test_weibull_censoring_class(): - model = WeibullCensoring(scale=1.0, shape=1.5) - times = model(5) - assert isinstance(times, np.ndarray) - assert len(times) == 5 - assert np.all(times >= 0) - - -def test_lognormal_censoring_class(): - model = LogNormalCensoring(mean=0.0, sigma=1.0) - times = model(5) - assert isinstance(times, np.ndarray) - assert len(times) == 5 - assert np.all(times > 0) - - -def test_gamma_censoring_class(): - model = GammaCensoring(shape=2.0, scale=1.0) - times = model(5) - assert isinstance(times, np.ndarray) - assert len(times) == 5 - assert np.all(times > 0) diff --git a/tests/test_cli.py b/tests/test_cli.py index 5d6f321..d5fd8d7 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,21 +1,12 @@ -import runpy -import sys -from typing import Any - import pandas as pd -import pytest - -from gen_surv.cli import dataset, visualize +from gen_surv.cli import dataset +import runpy def test_cli_dataset_stdout(monkeypatch, capsys): - """ - Test that the 'dataset' CLI command prints the generated CSV data to stdout when no output file is specified. - This test patches the 'generate' function to return a simple DataFrame, invokes the CLI command directly, - and asserts that the expected CSV header appears in the captured standard output. - """ + """Dataset command prints CSV to stdout when no output file is given.""" - def fake_generate(**_: Any): + def fake_generate(model: str, n: int): return pd.DataFrame({"time": [1.0], "status": [1], "X0": [0.1], "X1": [0.2]}) # Patch the generate function used in the CLI to avoid heavy computation. 
@@ -36,15 +27,14 @@ def fake_app(): # Patch the CLI app before the module is executed monkeypatch.setattr("gen_surv.cli.app", fake_app) - monkeypatch.setattr(sys, "argv", ["gen_surv", "dataset", "cphm"]) + monkeypatch.setattr("sys.argv", ["gen_surv", "dataset", "cphm"]) runpy.run_module("gen_surv.__main__", run_name="__main__") assert called - def test_cli_dataset_file_output(monkeypatch, tmp_path): """Dataset command writes CSV to file when output path is provided.""" - def fake_generate(**_: Any): + def fake_generate(model: str, n: int): return pd.DataFrame({"time": [1.0], "status": [1], "X0": [0.1], "X1": [0.2]}) monkeypatch.setattr("gen_surv.cli.generate", fake_generate) @@ -53,187 +43,3 @@ def fake_generate(**_: Any): assert out_file.exists() content = out_file.read_text() assert "time,status,X0,X1" in content - - -def test_dataset_weibull_parameters(monkeypatch): - """Parameters for aft_weibull model are forwarded correctly.""" - captured = {} - - def fake_generate(**kwargs): - captured.update(kwargs) - return pd.DataFrame({"time": [1], "status": [0]}) - - monkeypatch.setattr("gen_surv.cli.generate", fake_generate) - dataset( - model="aft_weibull", n=3, beta=[0.1, 0.2], shape=1.1, scale=2.2, output=None - ) - assert captured["model"] == "aft_weibull" - assert captured["beta"] == [0.1, 0.2] - assert captured["shape"] == 1.1 - assert captured["scale"] == 2.2 - - -def test_dataset_aft_ln(monkeypatch): - """aft_ln model should forward beta list and sigma.""" - captured = {} - - def fake_generate(**kwargs): - captured.update(kwargs) - return pd.DataFrame({"time": [1], "status": [1]}) - - monkeypatch.setattr("gen_surv.cli.generate", fake_generate) - dataset(model="aft_ln", n=1, beta=[0.3, 0.4], sigma=1.2, output=None) - assert captured["beta"] == [0.3, 0.4] - assert captured["sigma"] == 1.2 - - -def test_dataset_competing_risks(monkeypatch): - """competing_risks expands betas and passes hazards.""" - captured = {} - - def fake_generate(**kwargs): - captured.update(kwargs) - return pd.DataFrame({"time": [1], "status": [1]}) - - monkeypatch.setattr("gen_surv.cli.generate", fake_generate) - dataset( - model="competing_risks", - n=1, - n_risks=2, - baseline_hazards=[0.1, 0.2], - beta=0.5, - output=None, - ) - assert captured["n_risks"] == 2 - assert captured["baseline_hazards"] == [0.1, 0.2] - assert captured["betas"] == [0.5, 0.5] - - -def test_dataset_mixture_cure(monkeypatch): - """mixture_cure passes cure and baseline parameters.""" - captured = {} - - def fake_generate(**kwargs): - captured.update(kwargs) - return pd.DataFrame({"time": [1], "status": [1]}) - - monkeypatch.setattr("gen_surv.cli.generate", fake_generate) - dataset( - model="mixture_cure", - n=1, - cure_fraction=0.2, - baseline_hazard=0.1, - beta=[0.4], - output=None, - ) - assert captured["cure_fraction"] == 0.2 - assert captured["baseline_hazard"] == 0.1 - assert captured["betas_survival"] == [0.4] - assert captured["betas_cure"] == [0.4] - - -def test_dataset_invalid_model(monkeypatch): - def fake_generate(**kwargs): - raise ValueError("bad model") - - monkeypatch.setattr("gen_surv.cli.generate", fake_generate) - with pytest.raises(ValueError): - dataset(model="nope", n=1, output=None) - - -def test_cli_visualize_basic(monkeypatch, tmp_path): - csv = tmp_path / "data.csv" - pd.DataFrame({"time": [1, 2], "status": [1, 0]}).to_csv(csv, index=False) - - def fake_plot_survival_curve(**kwargs): - import matplotlib.pyplot as plt - - fig, ax = plt.subplots() - ax.plot([0, 1], [1, 0]) - return fig, ax - - monkeypatch.setattr( - 
"gen_surv.visualization.plot_survival_curve", fake_plot_survival_curve - ) - - saved = [] - - def fake_savefig(path, *args, **kwargs): - saved.append(path) - - monkeypatch.setattr("matplotlib.pyplot.savefig", fake_savefig) - - visualize( - str(csv), - time_col="time", - status_col="status", - group_col=None, - output=str(tmp_path / "plot.png"), - ) - assert saved and saved[0].endswith("plot.png") - - -def test_dataset_aft_log_logistic(monkeypatch): - captured = {} - - def fake_generate(**kwargs): - captured.update(kwargs) - return pd.DataFrame({"time": [1], "status": [1]}) - - monkeypatch.setattr("gen_surv.cli.generate", fake_generate) - dataset( - model="aft_log_logistic", - n=1, - beta=[0.1], - shape=1.2, - scale=2.3, - output=None, - ) - assert captured["model"] == "aft_log_logistic" - assert captured["beta"] == [0.1] - assert captured["shape"] == 1.2 - assert captured["scale"] == 2.3 - - -def test_dataset_competing_risks_weibull(monkeypatch): - captured = {} - - def fake_generate(**kwargs): - captured.update(kwargs) - return pd.DataFrame({"time": [1], "status": [1]}) - - monkeypatch.setattr("gen_surv.cli.generate", fake_generate) - dataset( - model="competing_risks_weibull", - n=1, - n_risks=2, - shape_params=[0.7, 1.2], - scale_params=[2.0, 2.0], - beta=0.3, - output=None, - ) - assert captured["n_risks"] == 2 - assert captured["shape_params"] == [0.7, 1.2] - assert captured["scale_params"] == [2.0, 2.0] - assert captured["betas"] == [0.3, 0.3] - - -def test_dataset_piecewise(monkeypatch): - captured = {} - - def fake_generate(**kwargs): - captured.update(kwargs) - return pd.DataFrame({"time": [1], "status": [1]}) - - monkeypatch.setattr("gen_surv.cli.generate", fake_generate) - dataset( - model="piecewise_exponential", - n=1, - breakpoints=[1.0], - hazard_rates=[0.2, 0.3], - beta=[0.4], - output=None, - ) - assert captured["breakpoints"] == [1.0] - assert captured["hazard_rates"] == [0.2, 0.3] - assert captured["betas"] == [0.4] diff --git a/tests/test_cli_integration.py b/tests/test_cli_integration.py deleted file mode 100644 index 5012f1e..0000000 --- a/tests/test_cli_integration.py +++ /dev/null @@ -1,30 +0,0 @@ -import pandas as pd -from typer.testing import CliRunner - -from gen_surv.cli import app - - -def test_dataset_cli_integration(tmp_path): - """Run dataset command end-to-end and verify CSV output.""" - runner = CliRunner() - out_file = tmp_path / "data.csv" - result = runner.invoke( - app, - [ - "dataset", - "cphm", - "--n", - "3", - "--beta", - "0.5", - "--covariate-range", - "1.0", - "-o", - str(out_file), - ], - ) - assert result.exit_code == 0 - assert out_file.exists() - df = pd.read_csv(out_file) - assert len(df) == 3 - assert {"time", "status"}.issubset(df.columns) diff --git a/tests/test_cmm.py b/tests/test_cmm.py index 2b6c4d0..48f2467 100644 --- a/tests/test_cmm.py +++ b/tests/test_cmm.py @@ -1,86 +1,10 @@ -import numpy as np -import pandas as pd - -from gen_surv.cmm import gen_cmm, generate_event_times - - -def test_generate_event_times_reproducible(): - np.random.seed(0) - result = generate_event_times( - z1=1.0, - beta=[0.1, 0.2, 0.3], - rate=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0], - ) - assert np.isclose(result["t12"], 0.7201370350469476) - assert np.isclose(result["t13"], 1.0282691393768246) - assert np.isclose(result["t23"], 0.6839405281667484) - - -def test_gen_cmm_uniform_reproducible(): - df = gen_cmm( - n=5, - model_cens="uniform", - cens_par=1.0, - beta=[0.1, 0.2, 0.3], - covariate_range=2.0, - rate=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0], - seed=42, - ) - expected = 
pd.DataFrame( - { - "id": [1, 2, 3, 4, 5], - "start": [0.0] * 5, - "stop": [ - 0.18915094163423693, - 0.6785349983450479, - 0.046776460564183294, - 0.12811363267554587, - 0.45038631001973155, - ], - "status": [1, 1, 1, 0, 0], - "X0": [ - 1.5479119272347037, - 0.8777564989945617, - 1.7171958398225217, - 1.3947360581187287, - 0.1883555828087116, - ], - "transition": [2.0, 2.0, 2.0, float("nan"), float("nan")], - } - ) - pd.testing.assert_frame_equal(df, expected) - - -def test_gen_cmm_exponential_reproducible(): - df = gen_cmm( - n=5, - model_cens="exponential", - cens_par=1.0, - beta=[0.1, 0.2, 0.3], - covariate_range=2.0, - rate=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0], - seed=42, - ) - expected = pd.DataFrame( - { - "id": [1, 2, 3, 4, 5], - "start": [0.0] * 5, - "stop": [ - 0.18915094163423693, - 0.6785349983450479, - 0.046776460564183294, - 0.07929383504134148, - 0.5750008479681584, - ], - "status": [1, 1, 1, 0, 1], - "X0": [ - 1.5479119272347037, - 0.8777564989945617, - 1.7171958398225217, - 1.3947360581187287, - 0.1883555828087116, - ], - "transition": [2.0, 2.0, 2.0, float("nan"), 1.0], - } - ) - pd.testing.assert_frame_equal(df, expected) +import sys +import os +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +from gen_surv.cmm import gen_cmm + +def test_gen_cmm_shape(): + df = gen_cmm(n=50, model_cens="uniform", cens_par=1.0, beta=[0.1, 0.2, 0.3], + covar=2.0, rate=[0.1, 1.0, 0.2, 1.0, 0.3, 1.0]) + assert df.shape[1] == 6 + assert "transition" in df.columns diff --git a/tests/test_competing_risks.py b/tests/test_competing_risks.py deleted file mode 100644 index d2eb326..0000000 --- a/tests/test_competing_risks.py +++ /dev/null @@ -1,330 +0,0 @@ -"""Tests for Competing Risks models.""" - -import os - -import numpy as np -import pandas as pd -import pytest -from hypothesis import given -from hypothesis import strategies as st - -import gen_surv.competing_risks as cr -from gen_surv.competing_risks import ( - cause_specific_cumulative_incidence, - gen_competing_risks, - gen_competing_risks_weibull, -) -from gen_surv.validation import ( - ChoiceError, - LengthError, - NumericSequenceError, - ParameterError, - PositiveSequenceError, - PositiveValueError, -) - -os.environ.setdefault("MPLBACKEND", "Agg") - - -def test_gen_competing_risks_basic(): - """Test that the competing risks generator runs without errors.""" - df = gen_competing_risks( - n=10, - n_risks=2, - baseline_hazards=[0.5, 0.3], - betas=[[0.8, -0.5], [0.2, 0.7]], - model_cens="uniform", - cens_par=2.0, - seed=42, - ) - assert isinstance(df, pd.DataFrame) - assert not df.empty - assert "time" in df.columns - assert "status" in df.columns - assert "X0" in df.columns - assert "X1" in df.columns - assert set(df["status"].unique()).issubset({0, 1, 2}) - - -def test_gen_competing_risks_weibull_basic(): - """Test that the Weibull competing risks generator runs without errors.""" - df = gen_competing_risks_weibull( - n=10, - n_risks=2, - shape_params=[0.8, 1.5], - scale_params=[2.0, 3.0], - betas=[[0.8, -0.5], [0.2, 0.7]], - model_cens="uniform", - cens_par=2.0, - seed=42, - ) - assert isinstance(df, pd.DataFrame) - assert not df.empty - assert "time" in df.columns - assert "status" in df.columns - assert "X0" in df.columns - assert "X1" in df.columns - assert set(df["status"].unique()).issubset({0, 1, 2}) - - -def test_competing_risks_parameters(): - """Test parameter validation in competing risks model.""" - # Test with invalid number of baseline hazards - with pytest.raises(LengthError): - gen_competing_risks( 
- n=10, - n_risks=3, - baseline_hazards=[0.5, 0.3], # Only 2 provided, but 3 risks - seed=42, - ) - - # Test with invalid number of beta coefficient sets - with pytest.raises(LengthError): - gen_competing_risks( - n=10, - n_risks=2, - betas=[[0.8, -0.5]], # Only 1 set provided, but 2 risks - seed=42, - ) - - # Test with invalid censoring model - with pytest.raises(ChoiceError): - gen_competing_risks(n=10, n_risks=2, model_cens="invalid", seed=42) - - -def test_invalid_covariate_dist(): - with pytest.raises(ChoiceError): - gen_competing_risks(n=5, n_risks=2, covariate_dist="unknown", seed=1) - with pytest.raises(ChoiceError): - gen_competing_risks_weibull(n=5, n_risks=2, covariate_dist="unknown", seed=1) - - -def test_competing_risks_positive_params(): - with pytest.raises(PositiveValueError): - gen_competing_risks(n=5, n_risks=2, cens_par=0.0, seed=0) - with pytest.raises(PositiveValueError): - gen_competing_risks(n=5, n_risks=2, max_time=-1.0, seed=0) - - -def test_competing_risks_invalid_covariate_params(): - with pytest.raises(ParameterError): - gen_competing_risks( - n=5, - n_risks=2, - covariate_dist="normal", - covariate_params={"mean": 0.0}, - seed=1, - ) - with pytest.raises(PositiveValueError): - gen_competing_risks( - n=5, - n_risks=2, - covariate_dist="normal", - covariate_params={"mean": 0.0, "std": -1.0}, - seed=1, - ) - with pytest.raises(ParameterError): - gen_competing_risks_weibull( - n=5, - n_risks=2, - covariate_dist="binary", - covariate_params={"p": 1.5}, - seed=1, - ) - - -def test_competing_risks_invalid_beta_values(): - with pytest.raises(NumericSequenceError): - gen_competing_risks(n=5, n_risks=2, betas=[[0.1, "x"], [0.2, 0.3]], seed=0) - with pytest.raises(NumericSequenceError): - gen_competing_risks_weibull( - n=5, n_risks=2, betas=[[0.1, np.nan], [0.2, 0.3]], seed=0 - ) - - -def test_competing_risks_weibull_parameters(): - """Test parameter validation in Weibull competing risks model.""" - # Test with invalid number of shape parameters - with pytest.raises(LengthError): - gen_competing_risks_weibull( - n=10, - n_risks=3, - shape_params=[0.8, 1.5], # Only 2 provided, but 3 risks - seed=42, - ) - - # Test with invalid number of scale parameters - with pytest.raises(LengthError): - gen_competing_risks_weibull( - n=10, - n_risks=3, - scale_params=[2.0, 3.0], # Only 2 provided, but 3 risks - seed=42, - ) - - -def test_competing_risks_weibull_positive_params(): - with pytest.raises(PositiveSequenceError): - gen_competing_risks_weibull(n=5, n_risks=2, shape_params=[1.0, -1.0], seed=0) - with pytest.raises(PositiveSequenceError): - gen_competing_risks_weibull(n=5, n_risks=2, scale_params=[2.0, 0.0], seed=0) - with pytest.raises(PositiveValueError): - gen_competing_risks_weibull(n=5, n_risks=2, cens_par=-1.0, seed=0) - with pytest.raises(PositiveValueError): - gen_competing_risks_weibull(n=5, n_risks=2, max_time=0.0, seed=0) - - -def test_cause_specific_cumulative_incidence(): - """Test the cause-specific cumulative incidence function.""" - # Generate some data - df = gen_competing_risks(n=50, n_risks=2, baseline_hazards=[0.5, 0.3], seed=42) - - # Calculate CIF for cause 1 - time_points = np.linspace(0, 5, 10) - cif = cause_specific_cumulative_incidence(df, time_points, cause=1) - - assert isinstance(cif, pd.DataFrame) - assert len(cif) == len(time_points) - assert "time" in cif.columns - assert "incidence" in cif.columns - assert (cif["incidence"] >= 0).all() - assert (cif["incidence"] <= 1).all() - assert cif["incidence"].is_monotonic_increasing - - # Test with invalid 
cause - with pytest.raises(ParameterError): - cause_specific_cumulative_incidence(df, time_points, cause=3) - - -def test_cause_specific_cumulative_incidence_handles_ties(): - df = pd.DataFrame( - { - "id": [0, 1, 2, 3], - "time": [1.0, 1.0, 2.0, 2.0], - "status": [1, 2, 1, 0], - } - ) - cif = cause_specific_cumulative_incidence(df, [1.0, 2.0], cause=1) - assert np.allclose(cif["incidence"].to_numpy(), [0.25, 0.5]) - - -def test_cause_specific_cumulative_incidence_bounds(): - df = gen_competing_risks(n=30, n_risks=2, seed=5) - max_time = df["time"].max() - time_points = [-1.0, 0.0, max_time + 1] - cif = cause_specific_cumulative_incidence(df, time_points, cause=1) - assert cif.iloc[0]["incidence"] == 0.0 - expected = cause_specific_cumulative_incidence(df, [max_time], cause=1).iloc[0][ - "incidence" - ] - assert cif.iloc[-1]["incidence"] == expected - - -@given( - n=st.integers(min_value=5, max_value=50), - n_risks=st.integers(min_value=2, max_value=4), - seed=st.integers(min_value=0, max_value=1000), -) -def test_competing_risks_properties(n, n_risks, seed): - """Property-based tests for the competing risks model.""" - df = gen_competing_risks(n=n, n_risks=n_risks, seed=seed) - - # Check basic properties - assert df.shape[0] == n - assert all(col in df.columns for col in ["id", "time", "status"]) - assert (df["time"] >= 0).all() - assert df["status"].isin(list(range(n_risks + 1))).all() # 0 to n_risks - - # Count of each status - status_counts = df["status"].value_counts() - # There should be at least one of each status (including censoring) - # This might occasionally fail due to randomness, so we'll just check that - # we have at least 2 different status values - assert len(status_counts) >= 2 - - -@given( - n=st.integers(min_value=5, max_value=50), - n_risks=st.integers(min_value=2, max_value=4), - seed=st.integers(min_value=0, max_value=1000), -) -def test_competing_risks_weibull_properties(n, n_risks, seed): - """Property-based tests for the Weibull competing risks model.""" - df = gen_competing_risks_weibull(n=n, n_risks=n_risks, seed=seed) - - # Check basic properties - assert df.shape[0] == n - assert all(col in df.columns for col in ["id", "time", "status"]) - assert (df["time"] >= 0).all() - assert df["status"].isin(list(range(n_risks + 1))).all() # 0 to n_risks - - # Count of each status - status_counts = df["status"].value_counts() - # There should be at least 2 different status values - assert len(status_counts) >= 2 - - -def test_gen_competing_risks_forces_event_types(): - df = gen_competing_risks( - n=2, - n_risks=2, - baseline_hazards=[1e-9, 1e-9], - model_cens="uniform", - cens_par=0.1, - seed=0, - ) - assert set(df["status"]) == {1, 2} - - -def test_gen_competing_risks_weibull_forces_event_types(): - df = gen_competing_risks_weibull( - n=2, - n_risks=2, - shape_params=[1, 1], - scale_params=[1e9, 1e9], - model_cens="uniform", - cens_par=0.1, - seed=0, - ) - assert set(df["status"]) == {1, 2} - - -def test_reproducibility(): - """Test that results are reproducible with the same seed.""" - df1 = gen_competing_risks(n=20, n_risks=2, seed=42) - - df2 = gen_competing_risks(n=20, n_risks=2, seed=42) - - pd.testing.assert_frame_equal(df1, df2) - - # Different seeds should produce different results - df3 = gen_competing_risks(n=20, n_risks=2, seed=43) - - with pytest.raises(AssertionError): - pd.testing.assert_frame_equal(df1, df3) - - -def test_competing_risks_summary_basic(): - df = gen_competing_risks(n=10, n_risks=2, seed=1) - summary = cr.competing_risks_summary(df) - 
assert summary["n_subjects"] == 10 - assert summary["n_causes"] == 2 - assert set(summary["events_by_cause"]) <= {1, 2} - assert "time_stats" in summary - - -def test_competing_risks_summary_with_categorical(): - df = gen_competing_risks(n=8, n_risks=2, seed=2) - df["group"] = ["A", "B"] * 4 - summary = cr.competing_risks_summary(df, covariate_cols=["X0", "group"]) - assert summary["covariate_stats"]["group"]["categories"] == 2 - assert "distribution" in summary["covariate_stats"]["group"] - - -def test_plot_cause_specific_hazards_runs(): - plt = pytest.importorskip("matplotlib.pyplot") - df = gen_competing_risks(n=30, n_risks=2, seed=3) - fig, ax = cr.plot_cause_specific_hazards(df, time_points=np.linspace(0, 5, 5)) - assert hasattr(fig, "savefig") - assert len(ax.get_lines()) >= 1 - plt.close(fig) diff --git a/tests/test_cphm.py b/tests/test_cphm.py index f71d04c..05cc652 100644 --- a/tests/test_cphm.py +++ b/tests/test_cphm.py @@ -1,66 +1,13 @@ -""" -Tests for the Cox Proportional Hazards Model (CPHM) generator. -""" - -import pandas as pd -import pytest - +import sys +import os +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from gen_surv.cphm import gen_cphm - def test_gen_cphm_output_shape(): - """Test that the output DataFrame has the expected shape and columns.""" - df = gen_cphm( - n=50, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0 - ) + df = gen_cphm(n=50, model_cens="uniform", cens_par=1.0, beta=0.5, covar=2.0) assert df.shape == (50, 3) - assert list(df.columns) == ["time", "status", "X0"] - + assert list(df.columns) == ["time", "status", "covariate"] def test_gen_cphm_status_range(): - """Test that status values are binary (0 or 1).""" - df = gen_cphm( - n=100, model_cens="exponential", cens_par=0.8, beta=0.3, covariate_range=1.5 - ) + df = gen_cphm(n=100, model_cens="exponential", cens_par=0.8, beta=0.3, covar=1.5) assert df["status"].isin([0, 1]).all() - - -def test_gen_cphm_time_positive(): - """Test that all time values are positive.""" - df = gen_cphm( - n=50, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0 - ) - assert (df["time"] > 0).all() - - -def test_gen_cphm_covariate_range(): - """Test that covariate values are within the specified range.""" - covar_max = 2.5 - df = gen_cphm( - n=100, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=covar_max - ) - assert (df["X0"] >= 0).all() - assert (df["X0"] <= covar_max).all() - - -def test_gen_cphm_seed_reproducibility(): - """Test that setting the same seed produces identical results.""" - df1 = gen_cphm( - n=10, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0, seed=42 - ) - df2 = gen_cphm( - n=10, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0, seed=42 - ) - pd.testing.assert_frame_equal(df1, df2) - - -def test_gen_cphm_different_seeds(): - """Test that different seeds produce different results.""" - df1 = gen_cphm( - n=10, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0, seed=42 - ) - df2 = gen_cphm( - n=10, model_cens="uniform", cens_par=1.0, beta=0.5, covariate_range=2.0, seed=43 - ) - with pytest.raises(AssertionError): - pd.testing.assert_frame_equal(df1, df2) diff --git a/tests/test_export.py b/tests/test_export.py deleted file mode 100644 index 6028c77..0000000 --- a/tests/test_export.py +++ /dev/null @@ -1,80 +0,0 @@ -import pandas as pd -import pyreadr -import pytest - -from gen_surv.export import export_dataset -from gen_surv.validation import ChoiceError - - 
-@pytest.mark.parametrize( - "fmt, reader", - [ - ("csv", pd.read_csv), - ("feather", pd.read_feather), - ("ft", pd.read_feather), - ], -) -def test_export_dataset_formats(fmt, reader, tmp_path): - df = pd.DataFrame({"time": [1.0, 2.0], "status": [1, 0]}) - out = tmp_path / f"data.{fmt}" - export_dataset(df, out) - assert out.exists() - result = reader(out).astype(df.dtypes.to_dict()) - pd.testing.assert_frame_equal(result.reset_index(drop=True), df) - - -def test_export_dataset_json(monkeypatch, tmp_path): - df = pd.DataFrame({"time": [1.0, 2.0], "status": [1, 0]}) - out = tmp_path / "data.json" - - called = {} - - def fake_to_json(self, path, orient="table"): - called["args"] = (path, orient) - with open(path, "w", encoding="utf-8") as f: - f.write("{}") - - monkeypatch.setattr(pd.DataFrame, "to_json", fake_to_json) - export_dataset(df, out) - assert called["args"] == (out, "table") - assert out.exists() - - -def test_export_dataset_rds(monkeypatch, tmp_path): - df = pd.DataFrame({"time": [1.0, 2.0], "status": [1, 0]}) - out = tmp_path / "data.rds" - - captured = {} - - def fake_write_rds(path, data): - captured["path"] = path - captured["data"] = data - open(path, "wb").close() - - monkeypatch.setattr(pyreadr, "write_rds", fake_write_rds) - export_dataset(df, out) - assert out.exists() - pd.testing.assert_frame_equal(captured["data"], df.reset_index(drop=True)) - - -def test_export_dataset_explicit_fmt(monkeypatch, tmp_path): - df = pd.DataFrame({"time": [1.0, 2.0], "status": [1, 0]}) - out = tmp_path / "data.bin" - - called = {} - - def fake_to_json(self, path, orient="table"): - called["args"] = (path, orient) - with open(path, "w", encoding="utf-8") as f: - f.write("{}") - - monkeypatch.setattr(pd.DataFrame, "to_json", fake_to_json) - export_dataset(df, out, fmt="json") - assert called["args"] == (out, "table") - assert out.exists() - - -def test_export_dataset_invalid_format(tmp_path): - df = pd.DataFrame({"time": [1.0, 2.0], "status": [1, 0]}) - with pytest.raises(ChoiceError): - export_dataset(df, tmp_path / "data.xxx", fmt="txt") diff --git a/tests/test_integration_sksurv.py b/tests/test_integration_sksurv.py deleted file mode 100644 index 84697d7..0000000 --- a/tests/test_integration_sksurv.py +++ /dev/null @@ -1,13 +0,0 @@ -import pandas as pd -import pytest - -from gen_surv.integration import to_sksurv - - -def test_to_sksurv(): - # Optional integration test; skipped when scikit-survival is not installed. 
- pytest.importorskip("sksurv.util") - df = pd.DataFrame({"time": [1.0, 2.0], "status": [1, 0]}) - arr = to_sksurv(df) - assert arr.dtype.names == ("status", "time") - assert arr.shape[0] == 2 diff --git a/tests/test_interface.py b/tests/test_interface.py index 8be76b3..ad2dc3c 100644 --- a/tests/test_interface.py +++ b/tests/test_interface.py @@ -1,6 +1,5 @@ -import pytest - from gen_surv import generate +import pytest def test_generate_tdcm_runs(): diff --git a/tests/test_mixture.py b/tests/test_mixture.py deleted file mode 100644 index 2b1a4f9..0000000 --- a/tests/test_mixture.py +++ /dev/null @@ -1,83 +0,0 @@ -import pandas as pd -import pytest - -from gen_surv.mixture import cure_fraction_estimate, gen_mixture_cure - - -def test_gen_mixture_cure_runs(): - df = gen_mixture_cure(n=10, cure_fraction=0.3, seed=42) - assert isinstance(df, pd.DataFrame) - assert len(df) == 10 - assert {"time", "status", "cured"}.issubset(df.columns) - - -def test_cure_fraction_estimate_range(): - df = gen_mixture_cure(n=50, cure_fraction=0.3, seed=0) - est = cure_fraction_estimate(df) - assert 0 <= est <= 1 - - -def test_cure_fraction_estimate_empty_returns_zero(): - df = pd.DataFrame(columns=["time", "status"]) - est = cure_fraction_estimate(df) - assert est == 0.0 - - -def test_gen_mixture_cure_invalid_inputs(): - with pytest.raises(ValueError): - gen_mixture_cure(n=5, cure_fraction=1.5) - with pytest.raises(ValueError): - gen_mixture_cure(n=5, cure_fraction=0.2, baseline_hazard=0) - with pytest.raises(ValueError): - gen_mixture_cure(n=5, cure_fraction=0.2, covariate_dist="bad") - with pytest.raises(ValueError): - gen_mixture_cure(n=5, cure_fraction=0.2, model_cens="bad") - with pytest.raises(ValueError): - gen_mixture_cure( - n=5, - cure_fraction=0.2, - betas_survival=[0.1, 0.2], - betas_cure=[0.1], - ) - - -def test_gen_mixture_cure_max_time_cap(): - df = gen_mixture_cure( - n=50, - cure_fraction=0.3, - max_time=5.0, - model_cens="exponential", - cens_par=1.0, - seed=123, - ) - assert (df["time"] <= 5.0).all() - - -def test_cure_fraction_estimate_close_to_true(): - df = gen_mixture_cure(n=200, cure_fraction=0.4, seed=1) - est = cure_fraction_estimate(df) - assert pytest.approx(0.4, abs=0.15) == est - - -def test_gen_mixture_cure_covariate_distributions(): - for dist in ["uniform", "binary"]: - df = gen_mixture_cure(n=20, cure_fraction=0.3, covariate_dist=dist, seed=2) - assert {"time", "status", "cured", "X0", "X1"}.issubset(df.columns) - - -def test_gen_mixture_cure_no_max_time_allows_long_times(): - df = gen_mixture_cure( - n=100, - cure_fraction=0.5, - max_time=None, - model_cens="uniform", - cens_par=20.0, - seed=3, - ) - assert df["time"].max() > 10 - - -def test_cure_fraction_estimate_small_sample(): - df = gen_mixture_cure(n=3, cure_fraction=0.2, seed=4) - est = cure_fraction_estimate(df) - assert 0 <= est <= 1 diff --git a/tests/test_piecewise.py b/tests/test_piecewise.py deleted file mode 100644 index 61b75d3..0000000 --- a/tests/test_piecewise.py +++ /dev/null @@ -1,104 +0,0 @@ -import pandas as pd -import pytest - -from gen_surv.piecewise import gen_piecewise_exponential - - -def test_gen_piecewise_exponential_runs(): - df = gen_piecewise_exponential( - n=10, breakpoints=[1.0], hazard_rates=[0.5, 1.0], seed=42 - ) - assert isinstance(df, pd.DataFrame) - assert len(df) == 10 - assert {"time", "status"}.issubset(df.columns) - - -def test_piecewise_invalid_lengths(): - with pytest.raises(ValueError): - gen_piecewise_exponential( - n=5, breakpoints=[1.0, 2.0], hazard_rates=[0.5], seed=42 - ) - - 
-def test_piecewise_invalid_hazard_and_breakpoints(): - with pytest.raises(ValueError): - gen_piecewise_exponential( - n=5, - breakpoints=[2.0, 1.0], - hazard_rates=[0.5, 1.0, 1.5], - seed=42, - ) - with pytest.raises(ValueError): - gen_piecewise_exponential( - n=5, - breakpoints=[1.0], - hazard_rates=[0.5, -1.0], - seed=42, - ) - - -def test_piecewise_covariate_distributions(): - for dist, params in [ - ("uniform", {"low": 0.0, "high": 1.0}), - ("binary", {"p": 0.7}), - ]: - df = gen_piecewise_exponential( - n=5, - breakpoints=[1.0], - hazard_rates=[0.2, 0.4], - covariate_dist=dist, - covariate_params=params, - seed=1, - ) - assert len(df) == 5 - assert {"X0", "X1"}.issubset(df.columns) - - -def test_piecewise_custom_betas_reproducible(): - df1 = gen_piecewise_exponential( - n=5, - breakpoints=[1.0], - hazard_rates=[0.1, 0.2], - betas=[0.5, -0.2], - seed=2, - ) - df2 = gen_piecewise_exponential( - n=5, - breakpoints=[1.0], - hazard_rates=[0.1, 0.2], - betas=[0.5, -0.2], - seed=2, - ) - pd.testing.assert_frame_equal(df1, df2) - - -def test_piecewise_invalid_covariate_dist(): - with pytest.raises(ValueError): - gen_piecewise_exponential( - n=5, - breakpoints=[1.0], - hazard_rates=[0.5, 1.0], - covariate_dist="unknown", - seed=1, - ) - - -def test_piecewise_invalid_censoring_model(): - with pytest.raises(ValueError): - gen_piecewise_exponential( - n=5, - breakpoints=[1.0], - hazard_rates=[0.5, 1.0], - model_cens="bad", - seed=1, - ) - - -def test_piecewise_negative_breakpoint(): - with pytest.raises(ValueError): - gen_piecewise_exponential( - n=5, - breakpoints=[-1.0], - hazard_rates=[0.5, 1.0], - seed=1, - ) diff --git a/tests/test_piecewise_functions.py b/tests/test_piecewise_functions.py deleted file mode 100644 index b1de35a..0000000 --- a/tests/test_piecewise_functions.py +++ /dev/null @@ -1,52 +0,0 @@ -import numpy as np - -from gen_surv.piecewise import piecewise_hazard_function, piecewise_survival_function - - -def test_piecewise_hazard_function_scalar_and_array(): - breakpoints = [1.0, 2.0] - hazard_rates = [0.5, 1.0, 1.5] - # Scalar values - assert piecewise_hazard_function(0.5, breakpoints, hazard_rates) == 0.5 - assert piecewise_hazard_function(1.5, breakpoints, hazard_rates) == 1.0 - assert piecewise_hazard_function(3.0, breakpoints, hazard_rates) == 1.5 - # Array values - arr = np.array([0.5, 1.5, 3.0]) - np.testing.assert_allclose( - piecewise_hazard_function(arr, breakpoints, hazard_rates), - np.array([0.5, 1.0, 1.5]), - ) - - -def test_piecewise_hazard_function_negative_time(): - """Hazard should be zero for negative times.""" - breakpoints = [1.0, 2.0] - hazard_rates = [0.5, 1.0, 1.5] - assert piecewise_hazard_function(-1.0, breakpoints, hazard_rates) == 0 - np.testing.assert_array_equal( - piecewise_hazard_function(np.array([-0.5, -2.0]), breakpoints, hazard_rates), - np.array([0.0, 0.0]), - ) - - -def test_piecewise_survival_function(): - breakpoints = [1.0, 2.0] - hazard_rates = [0.5, 1.0, 1.5] - # Known survival probabilities - expected = np.exp(-np.array([0.0, 0.25, 1.0, 3.0])) - times = np.array([0.0, 0.5, 1.5, 3.0]) - np.testing.assert_allclose( - piecewise_survival_function(times, breakpoints, hazard_rates), - expected, - ) - - -def test_piecewise_survival_function_scalar_and_negative(): - breakpoints = [1.0, 2.0] - hazard_rates = [0.5, 1.0, 1.5] - # Scalar output should be a float - val = piecewise_survival_function(1.5, breakpoints, hazard_rates) - assert isinstance(val, float) - assert np.isclose(val, np.exp(-1.0)) - # Negative times return survival of 1 - assert 
piecewise_survival_function(-2.0, breakpoints, hazard_rates) == 1 diff --git a/tests/test_sklearn_adapter.py b/tests/test_sklearn_adapter.py deleted file mode 100644 index f4b8c4a..0000000 --- a/tests/test_sklearn_adapter.py +++ /dev/null @@ -1,47 +0,0 @@ -import pytest - -from gen_surv.sklearn_adapter import GenSurvDataGenerator -from gen_surv.validation import ChoiceError - - -def test_sklearn_generator_dataframe(): - gen = GenSurvDataGenerator( - "cphm", - n=4, - beta=0.2, - covariate_range=1.0, - model_cens="uniform", - cens_par=1.0, - ) - df = gen.fit_transform() - assert len(df) == 4 - assert {"time", "status"}.issubset(df.columns) - - -def test_sklearn_generator_dict(): - gen = GenSurvDataGenerator( - "cphm", - return_type="dict", - n=3, - beta=0.5, - covariate_range=1.0, - model_cens="uniform", - cens_par=1.0, - ) - data = gen.transform() - assert isinstance(data, dict) - assert set(data.keys()) >= {"time", "status"} - assert len(data["time"]) == 3 - - -def test_sklearn_generator_invalid_return_type(): - with pytest.raises(ChoiceError): - GenSurvDataGenerator( - "cphm", - return_type="bad", - n=1, - beta=0.5, - covariate_range=1.0, - model_cens="uniform", - cens_par=1.0, - ) diff --git a/tests/test_summary.py b/tests/test_summary.py deleted file mode 100644 index cf63caf..0000000 --- a/tests/test_summary.py +++ /dev/null @@ -1,16 +0,0 @@ -from gen_surv import generate -from gen_surv.summary import summarize_survival_dataset - - -def test_summarize_survival_dataset_basic(): - df = generate( - model="cphm", - n=20, - model_cens="uniform", - cens_par=1.0, - beta=0.5, - covariate_range=2.0, - ) - summary = summarize_survival_dataset(df, verbose=False) - assert isinstance(summary, dict) - assert "dataset_info" in summary diff --git a/tests/test_summary_extra.py b/tests/test_summary_extra.py deleted file mode 100644 index 380b818..0000000 --- a/tests/test_summary_extra.py +++ /dev/null @@ -1,104 +0,0 @@ -import pandas as pd -import pytest - -from gen_surv import generate -from gen_surv.summary import ( - _print_summary, - check_survival_data_quality, - compare_survival_datasets, -) - - -def test_check_survival_data_quality_fix_issues(): - df = pd.DataFrame( - { - "time": [1.0, -0.5, None, 1.0], - "status": [1, 2, 0, 1], - "id": [1, 2, 3, 1], - } - ) - fixed, issues = check_survival_data_quality( - df, - id_col="id", - max_time=2.0, - fix_issues=True, - ) - assert issues["modifications"]["rows_dropped"] == 2 - assert issues["modifications"]["values_fixed"] == 1 - assert len(fixed) == 2 - - -def test_check_survival_data_quality_no_fix(): - """Issues should be reported but data left unchanged when fix_issues=False.""" - df = pd.DataFrame({"time": [-1.0, 2.0], "status": [3, 1]}) - checked, issues = check_survival_data_quality(df, max_time=1.0, fix_issues=False) - # Data is returned unmodified - pd.testing.assert_frame_equal(df, checked) - assert issues["invalid_values"]["negative_time"] == 1 - assert issues["invalid_values"]["excessive_time"] == 1 - assert issues["invalid_values"]["invalid_status"] == 1 - - -def test_compare_survival_datasets_basic(): - ds1 = generate( - model="cphm", - n=5, - model_cens="uniform", - cens_par=1.0, - beta=0.5, - covariate_range=1.0, - ) - ds2 = generate( - model="cphm", - n=5, - model_cens="uniform", - cens_par=1.0, - beta=1.0, - covariate_range=1.0, - ) - comparison = compare_survival_datasets({"A": ds1, "B": ds2}) - assert set(["A", "B"]).issubset(comparison.columns) - assert "n_subjects" in comparison.index - - -def 
test_compare_survival_datasets_with_covariates_and_empty_error(): - ds = generate( - model="cphm", - n=3, - model_cens="uniform", - cens_par=1.0, - beta=0.5, - covariate_range=1.0, - ) - comparison = compare_survival_datasets({"only": ds}, covariate_cols=["X0"]) - assert "only" in comparison.columns - assert "X0_mean" in comparison.index - with pytest.raises(ValueError): - compare_survival_datasets({}) - - -def test_check_survival_data_quality_min_and_max(): - df = pd.DataFrame({"time": [-1.0, 3.0], "status": [1, 1]}) - fixed, issues = check_survival_data_quality( - df, min_time=0.0, max_time=2.0, fix_issues=True - ) - assert (fixed["time"] <= 2.0).all() - assert issues["modifications"]["values_fixed"] > 0 - - -def test_print_summary_with_issues(capsys): - summary = { - "dataset_info": {"n_subjects": 2, "n_unique_ids": 2, "n_covariates": 0}, - "event_info": {"n_events": 1, "n_censored": 1, "event_rate": 0.5}, - "time_info": {"min": 0.0, "max": 2.0, "mean": 1.0, "median": 1.0}, - "data_quality": { - "missing_time": 0, - "missing_status": 0, - "negative_time": 1, - "invalid_status": 0, - }, - "covariates": {}, - } - _print_summary(summary, "time", "status", None, []) - out = capsys.readouterr().out - assert "Issues detected" in out diff --git a/tests/test_summary_more.py b/tests/test_summary_more.py deleted file mode 100644 index 8894237..0000000 --- a/tests/test_summary_more.py +++ /dev/null @@ -1,60 +0,0 @@ -import pandas as pd -import pytest - -from gen_surv.summary import ( - _print_summary, - check_survival_data_quality, - summarize_survival_dataset, -) -from gen_surv.validation import ParameterError - - -def test_summarize_survival_dataset_errors(): - df = pd.DataFrame({"time": [1, 2], "status": [1, 0]}) - # Missing time column - with pytest.raises(ParameterError): - summarize_survival_dataset(df.drop(columns=["time"])) - # Missing ID column when specified - with pytest.raises(ParameterError): - summarize_survival_dataset(df, id_col="id") - # Missing covariate columns - with pytest.raises(ParameterError): - summarize_survival_dataset(df, covariate_cols=["bad"]) - - -def test_summarize_survival_dataset_verbose_output(capsys): - df = pd.DataFrame( - { - "time": [1.0, 2.0, 3.0], - "status": [1, 0, 1], - "id": [1, 2, 3], - "age": [30, 40, 50], - "group": ["A", "B", "A"], - } - ) - summary = summarize_survival_dataset( - df, id_col="id", covariate_cols=["age", "group"] - ) - _print_summary(summary, "time", "status", "id", ["age", "group"]) - captured = capsys.readouterr().out - assert "SURVIVAL DATASET SUMMARY" in captured - assert "age:" in captured - assert "Categorical" in captured - - -def test_check_survival_data_quality_duplicates_and_fix(): - df = pd.DataFrame( - { - "time": [1.0, -1.0, 2.0, 1.0], - "status": [1, 1, 0, 1], - "id": [1, 1, 2, 1], - } - ) - checked, issues = check_survival_data_quality(df, id_col="id", fix_issues=False) - assert issues["duplicates"]["duplicate_rows"] == 1 - assert issues["duplicates"]["duplicate_ids"] == 2 - fixed, issues_fixed = check_survival_data_quality( - df, id_col="id", max_time=2.0, fix_issues=True - ) - assert len(fixed) < len(df) - assert issues_fixed["modifications"]["rows_dropped"] > 0 diff --git a/tests/test_tdcm.py b/tests/test_tdcm.py index b393a3b..507b51f 100644 --- a/tests/test_tdcm.py +++ b/tests/test_tdcm.py @@ -1,16 +1,10 @@ +import sys +import os +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from gen_surv.tdcm import gen_tdcm - def test_gen_tdcm_shape(): - df = gen_tdcm( - n=50, - dist="weibull", 
- corr=0.5, - dist_par=[1, 2, 1, 2], - model_cens="uniform", - cens_par=1.0, - beta=[0.1, 0.2, 0.3], - lam=1.0, - ) + df = gen_tdcm(n=50, dist="weibull", corr=0.5, dist_par=[1, 2, 1, 2], + model_cens="uniform", cens_par=1.0, beta=[0.1, 0.2, 0.3], lam=1.0) assert df.shape[1] == 6 assert "tdcov" in df.columns diff --git a/tests/test_thmm.py b/tests/test_thmm.py index e91d2a1..b53b197 100644 --- a/tests/test_thmm.py +++ b/tests/test_thmm.py @@ -1,14 +1,11 @@ -from gen_surv.thmm import gen_thmm +import sys +import os +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +from gen_surv.thmm import gen_thmm def test_gen_thmm_shape(): - df = gen_thmm( - n=50, - model_cens="uniform", - cens_par=1.0, - beta=[0.1, 0.2, 0.3], - covariate_range=2.0, - rate=[0.5, 0.6, 0.7], - ) + df = gen_thmm(n=50, model_cens="uniform", cens_par=1.0, + beta=[0.1, 0.2, 0.3], covar=2.0, rate=[0.5, 0.6, 0.7]) assert df.shape[1] == 4 assert set(df["state"].unique()).issubset({1, 2, 3}) diff --git a/tests/test_validate.py b/tests/test_validate.py index 4827036..ffbbad0 100644 --- a/tests/test_validate.py +++ b/tests/test_validate.py @@ -1,16 +1,5 @@ -import numpy as np import pytest - -import gen_surv.validation as v -from gen_surv.validation import ( - ChoiceError, - ParameterError, - PositiveIntegerError, - PositiveValueError, - ensure_censoring_model, - ensure_positive, - ensure_positive_int, -) +import gen_surv.validate as v def test_validate_gen_cphm_inputs_valid(): @@ -19,7 +8,7 @@ def test_validate_gen_cphm_inputs_valid(): @pytest.mark.parametrize( - "n, model_cens, cens_par, covariate_range", + "n, model_cens, cens_par, covar", [ (0, "uniform", 0.5, 1.0), (1, "bad", 0.5, 1.0), @@ -27,10 +16,10 @@ def test_validate_gen_cphm_inputs_valid(): (1, "uniform", 0.5, -1.0), ], ) -def test_validate_gen_cphm_inputs_invalid(n, model_cens, cens_par, covariate_range): +def test_validate_gen_cphm_inputs_invalid(n, model_cens, cens_par, covar): """Invalid parameter combinations should raise ValueError.""" with pytest.raises(ValueError): - v.validate_gen_cphm_inputs(n, model_cens, cens_par, covariate_range) + v.validate_gen_cphm_inputs(n, model_cens, cens_par, covar) def test_validate_dg_biv_inputs_invalid(): @@ -47,36 +36,11 @@ def test_validate_gen_cmm_inputs_invalid_beta_length(): "uniform", 0.5, [0.1, 0.2], - covariate_range=1.0, + covar=1.0, rate=[0.1] * 6, ) -@pytest.mark.parametrize( - "n, model_cens, cens_par, cov_range, rate", - [ - (0, "uniform", 0.5, 1.0, [0.1] * 6), - (1, "bad", 0.5, 1.0, [0.1] * 6), - (1, "uniform", 0.0, 1.0, [0.1] * 6), - (1, "uniform", 0.5, 0.0, [0.1] * 6), - (1, "uniform", 0.5, 1.0, [0.1] * 3), - ], -) -def test_validate_gen_cmm_inputs_other_invalid( - n, model_cens, cens_par, cov_range, rate -): - with pytest.raises(ValueError): - v.validate_gen_cmm_inputs( - n, model_cens, cens_par, [0.1, 0.2, 0.3], cov_range, rate - ) - - -def test_validate_gen_cmm_inputs_valid(): - v.validate_gen_cmm_inputs( - 1, "uniform", 1.0, [0.1, 0.2, 0.3], covariate_range=1.0, rate=[0.1] * 6 - ) - - def test_validate_gen_tdcm_inputs_invalid_lambda(): """Lambda <= 0 should raise a ValueError.""" with pytest.raises(ValueError): @@ -92,44 +56,6 @@ def test_validate_gen_tdcm_inputs_invalid_lambda(): ) -@pytest.mark.parametrize( - "dist,corr,dist_par", - [ - ("bad", 0.5, [1, 2]), - ("weibull", 0.0, [1, 2, 3, 4]), - ("weibull", 0.5, [1, 2, -1, 2]), - ("weibull", 0.5, [1, 2, 3]), - ("exponential", 2.0, [1, 1]), - ("exponential", 0.5, [1]), - ], -) -def test_validate_gen_tdcm_inputs_invalid_dist(dist, 
corr, dist_par): - with pytest.raises(ValueError): - v.validate_gen_tdcm_inputs( - 1, - dist, - corr, - dist_par, - "uniform", - 1.0, - beta=[0.1, 0.2, 0.3], - lam=1.0, - ) - - -def test_validate_gen_tdcm_inputs_valid(): - v.validate_gen_tdcm_inputs( - 1, - "weibull", - 0.5, - [1, 1, 1, 1], - "uniform", - 1.0, - beta=[0.1, 0.2, 0.3], - lam=1.0, - ) - - def test_validate_gen_aft_log_normal_inputs_valid(): """Valid parameters should not raise an error for AFT log-normal.""" v.validate_gen_aft_log_normal_inputs( @@ -141,132 +67,6 @@ def test_validate_gen_aft_log_normal_inputs_valid(): ) -@pytest.mark.parametrize( - "n,beta,sigma,model_cens,cens_par", - [ - (0, [0.1], 1.0, "uniform", 1.0), - (1, "bad", 1.0, "uniform", 1.0), - (1, [0.1], 0.0, "uniform", 1.0), - (1, [0.1], 1.0, "bad", 1.0), - (1, [0.1], 1.0, "uniform", 0.0), - ], -) -def test_validate_gen_aft_log_normal_inputs_invalid( - n, beta, sigma, model_cens, cens_par -): - with pytest.raises(ValueError): - v.validate_gen_aft_log_normal_inputs(n, beta, sigma, model_cens, cens_par) - - def test_validate_dg_biv_inputs_valid_weibull(): """Valid parameters for a Weibull distribution should pass.""" v.validate_dg_biv_inputs(5, "weibull", 0.1, [1.0, 1.0, 1.0, 1.0]) - - -def test_validate_dg_biv_inputs_invalid_corr_and_params(): - with pytest.raises(ValueError): - v.validate_dg_biv_inputs(1, "exponential", -2.0, [1.0, 1.0]) - with pytest.raises(ValueError): - v.validate_dg_biv_inputs(1, "exponential", 0.5, [1.0]) - with pytest.raises(ValueError): - v.validate_dg_biv_inputs(1, "weibull", 0.5, [1.0, 1.0]) - - -def test_validate_gen_aft_weibull_inputs_and_log_logistic(): - with pytest.raises(ValueError): - v.validate_gen_aft_weibull_inputs(0, [0.1], 1.0, 1.0, "uniform", 1.0) - with pytest.raises(ValueError): - v.validate_gen_aft_log_logistic_inputs(1, [0.1], -1.0, 1.0, "uniform", 1.0) - - -@pytest.mark.parametrize( - "shape,scale", - [(-1.0, 1.0), (1.0, -1.0)], -) -def test_validate_gen_aft_weibull_invalid_params(shape, scale): - with pytest.raises(ValueError): - v.validate_gen_aft_weibull_inputs(1, [0.1], shape, scale, "uniform", 1.0) - - -def test_validate_gen_aft_weibull_valid(): - v.validate_gen_aft_weibull_inputs(1, [0.1], 1.0, 1.0, "uniform", 1.0) - - -def test_validate_gen_aft_log_logistic_valid(): - v.validate_gen_aft_log_logistic_inputs(1, [0.1], 1.0, 1.0, "uniform", 1.0) - - -def test_positive_sequence_nan_inf(): - with pytest.raises(v.PositiveSequenceError): - v.ensure_positive_sequence([1.0, float("nan")], "x") - with pytest.raises(v.PositiveSequenceError): - v.ensure_positive_sequence([1.0, float("inf")], "x") - - -def test_numeric_sequence_rejects_bool(): - with pytest.raises(v.NumericSequenceError): - v.ensure_numeric_sequence([1, True], "x") - - -def test_validate_competing_risks_inputs(): - with pytest.raises(ValueError): - v.validate_competing_risks_inputs(1, 2, [0.1], None, "uniform", 1.0) - v.validate_competing_risks_inputs(1, 1, [0.5], [[0.1]], "uniform", 0.5) - - -@pytest.mark.parametrize( - "n,model_cens,cens_par,beta,cov_range,rate", - [ - (0, "uniform", 1.0, [0.1, 0.2, 0.3], 1.0, [0.1, 0.2, 0.3]), - (1, "bad", 1.0, [0.1, 0.2, 0.3], 1.0, [0.1, 0.2, 0.3]), - (1, "uniform", 0.0, [0.1, 0.2, 0.3], 1.0, [0.1, 0.2, 0.3]), - (1, "uniform", 1.0, [0.1, 0.2], 1.0, [0.1, 0.2, 0.3]), - (1, "uniform", 1.0, [0.1, 0.2, 0.3], 0.0, [0.1, 0.2, 0.3]), - (1, "uniform", 1.0, [0.1, 0.2, 0.3], 1.0, [0.1]), - ], -) -def test_validate_gen_thmm_inputs_invalid( - n, model_cens, cens_par, beta, cov_range, rate -): - with pytest.raises(ValueError): - 
v.validate_gen_thmm_inputs(n, model_cens, cens_par, beta, cov_range, rate) - - -def test_validate_gen_thmm_inputs_valid(): - v.validate_gen_thmm_inputs(1, "uniform", 1.0, [0.1, 0.2, 0.3], 1.0, [0.1, 0.2, 0.3]) - - -def test_positive_integer_error(): - with pytest.raises(PositiveIntegerError): - ensure_positive_int(-1, "n") - - -def test_ensure_positive_int_accepts_numpy_and_rejects_bool(): - ensure_positive_int(np.int64(5), "n") - with pytest.raises(PositiveIntegerError): - ensure_positive_int(True, "n") - - -def test_ensure_positive_accepts_numpy_and_rejects_bool(): - ensure_positive(np.float64(0.1), "val") - with pytest.raises(PositiveValueError): - ensure_positive(True, "val") - - -def test_censoring_model_choice_error(): - with pytest.raises(ChoiceError): - ensure_censoring_model("bad") - - -def test_parameter_error_from_validator(): - with pytest.raises(ParameterError): - v.validate_gen_tdcm_inputs( - 1, - "weibull", - 0.0, - [1, 2, 3, 4], - "uniform", - 1.0, - beta=[0.1, 0.2, 0.3], - lam=1.0, - ) diff --git a/tests/test_version.py b/tests/test_version.py index aff52a9..6fb4a02 100644 --- a/tests/test_version.py +++ b/tests/test_version.py @@ -1,5 +1,4 @@ from importlib.metadata import version - from gen_surv import __version__ diff --git a/tests/test_visualization.py b/tests/test_visualization.py deleted file mode 100644 index b3ac616..0000000 --- a/tests/test_visualization.py +++ /dev/null @@ -1,173 +0,0 @@ -import pandas as pd -import pytest -import typer - -from gen_surv import generate -from gen_surv.cli import visualize -from gen_surv.visualization import ( - describe_survival, - plot_covariate_effect, - plot_hazard_comparison, - plot_survival_curve, -) - - -def test_plot_survival_curve_runs(): - df = generate( - model="cphm", - n=10, - model_cens="uniform", - cens_par=1.0, - beta=0.5, - covariate_range=2.0, - ) - fig, ax = plot_survival_curve(df) - assert fig is not None - assert ax is not None - - -def test_plot_hazard_comparison_runs(): - df1 = generate( - model="cphm", - n=5, - model_cens="uniform", - cens_par=1.0, - beta=0.5, - covariate_range=1.0, - ) - df2 = generate( - model="aft_weibull", - n=5, - beta=[0.5], - shape=1.5, - scale=2.0, - model_cens="uniform", - cens_par=1.0, - ) - models = {"cphm": df1, "aft_weibull": df2} - fig, ax = plot_hazard_comparison(models) - assert fig is not None - assert ax is not None - - -def test_plot_covariate_effect_runs(): - df = generate( - model="cphm", - n=10, - model_cens="uniform", - cens_par=1.0, - beta=0.5, - covariate_range=2.0, - ) - fig, ax = plot_covariate_effect(df, covariate_col="X0", n_groups=2) - assert fig is not None - assert ax is not None - - -def test_describe_survival_summary(): - df = generate( - model="cphm", - n=10, - model_cens="uniform", - cens_par=1.0, - beta=0.5, - covariate_range=2.0, - ) - summary = describe_survival(df) - expected_metrics = [ - "Total Observations", - "Number of Events", - "Number Censored", - "Event Rate", - "Median Survival Time", - "Min Time", - "Max Time", - "Mean Time", - ] - assert list(summary["Metric"]) == expected_metrics - assert summary.shape[0] == len(expected_metrics) - - -def test_cli_visualize(tmp_path, capsys): - df = pd.DataFrame({"time": [1, 2, 3], "status": [1, 0, 1]}) - csv_path = tmp_path / "d.csv" - df.to_csv(csv_path, index=False) - out_file = tmp_path / "out.png" - visualize( - str(csv_path), - time_col="time", - status_col="status", - group_col=None, - output=str(out_file), - ) - assert out_file.exists() - captured = capsys.readouterr() - assert "Plot saved to" in 
captured.out - - -def test_cli_visualize_missing_column(tmp_path, capsys): - df = pd.DataFrame({"time": [1, 2], "event": [1, 0]}) - csv_path = tmp_path / "bad.csv" - df.to_csv(csv_path, index=False) - with pytest.raises(typer.Exit): - visualize( - str(csv_path), - time_col="time", - status_col="status", - group_col=None, - ) - captured = capsys.readouterr() - assert "Status column 'status' not found in data" in captured.out - - -def test_cli_visualize_missing_time(tmp_path, capsys): - df = pd.DataFrame({"t": [1, 2], "status": [1, 0]}) - path = tmp_path / "d.csv" - df.to_csv(path, index=False) - with pytest.raises(typer.Exit): - visualize(str(path), time_col="time", status_col="status") - captured = capsys.readouterr() - assert "Time column 'time' not found in data" in captured.out - - -def test_cli_visualize_missing_group(tmp_path, capsys): - df = pd.DataFrame({"time": [1], "status": [1], "x": [0]}) - path = tmp_path / "d2.csv" - df.to_csv(path, index=False) - with pytest.raises(typer.Exit): - visualize(str(path), time_col="time", status_col="status", group_col="group") - captured = capsys.readouterr() - assert "Group column 'group' not found in data" in captured.out - - -def test_cli_visualize_import_error(monkeypatch, tmp_path, capsys): - """visualize exits when matplotlib is missing.""" - import builtins - - real_import = builtins.__import__ - - def fake_import(name, *args, **kwargs): - if name.startswith("matplotlib"): # simulate missing dependency - raise ImportError("no matplot") - return real_import(name, *args, **kwargs) - - monkeypatch.setattr(builtins, "__import__", fake_import) - csv_path = tmp_path / "d.csv" - pd.DataFrame({"time": [1], "status": [1]}).to_csv(csv_path, index=False) - with pytest.raises(typer.Exit): - visualize(str(csv_path)) - captured = capsys.readouterr() - assert "Visualization requires matplotlib" in captured.out - - -def test_cli_visualize_read_error(monkeypatch, tmp_path, capsys): - """visualize handles CSV read failures gracefully.""" - monkeypatch.setattr( - "pandas.read_csv", lambda *a, **k: (_ for _ in ()).throw(Exception("boom")) - ) - csv_path = tmp_path / "x.csv" - csv_path.write_text("time,status\n1,1\n") - with pytest.raises(typer.Exit): - visualize(str(csv_path)) - captured = capsys.readouterr() - assert "Error loading CSV file" in captured.out