
Commit 35fa858

Tree for 0.1.0 · 0 parents


70 files changed: +10707 -0 lines changed

.gitignore

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
# Temporary and binary files
*~
*.py[cod]
*.so
*.cfg
!.isort.cfg
!setup.cfg
*.orig
*.log
*.pot
__pycache__/*
.cache/*
.*.swp
**/.ipynb_checkpoints/*
*/.ipynb_checkpoints/*
**/lightning_logs/*
**/checkpoints/*
.DS_Store

# Project files
.ropeproject
.project
.pydevproject
.settings
.idea
.vscode
tags

# Package files
*.egg
*.eggs/
.installed.cfg
*.egg-info

# Unittest and coverage
htmlcov/*
.coverage
.coverage.*
.tox
junit*.xml
coverage.xml
.pytest_cache/

# Build and docs folder/files
build/*
dist/*
sdist/*
docs/api/*
docs/_rst/*
docs/_build/*
docs/_autosummary/*
cover/*
MANIFEST

# Per-project virtualenvs
.venv*/
.conda*/
.python-version


# Data
*.gz
*-ubyte
MNIST/**

LICENSE.txt

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
The MIT License (MIT)

Copyright (c) 2023 Novartis Pharmaceuticals Corporation

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

README.md

Lines changed: 202 additions & 0 deletions
@@ -0,0 +1,202 @@
# Survival analysis made easy

`torchsurv` is a statistical package that serves as a companion tool for survival modeling with PyTorch. Its core functionalities include the computation of common survival models' log-likelihoods and predictive performance metrics. `torchsurv` requires minimal input specifications and does not impose parametric forms, allowing you to rapidly define, evaluate, and optimize deep survival models.

## TL;DR

Our idea is to **keep things simple**. You are free to use any model architecture you want! Our code has a 100% PyTorch backend and behaves like any other functions (losses or metrics) you may already be familiar with.

Our functions are designed to support you, not to make you jump through hoops. Here's pseudo-code illustrating how easy it is to use `torchsurv` to fit and evaluate a Cox proportional hazards model:

```python
from torchsurv.loss import cox
from torchsurv.metrics.cindex import ConcordanceIndex

# Pseudo training loop
for data in dataloader:
    x, event, time = data
    estimate = model(x)  # shape = torch.Size([64, 1]), if batch size is 64
    loss = cox.neg_partial_log_likelihood(estimate, event, time)
    loss.backward()  # native torch backend

# You can check model performance using our evaluation metrics, e.g., the concordance index with
cindex = ConcordanceIndex()
cindex(estimate, event, time)
cindex.p_value(method="noether", alternative="two_sided")

# You can even compare the metrics between two models (e.g., vs. model B)
cindex.compare(cindexB)
```

## Installation

First, install the package:

```bash
pip install torchsurv
```

or, for a local (editable) installation from the package root:

```bash
pip install -e .
```

If you use Conda, you can install the requirements into a conda environment using the `environment.yml` file included in the `dev` subfolder of the source repository, for example as sketched below.
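
A minimal sketch of such a setup (the environment name `torchsurv` is taken from `dev/environment.yml` in this commit; adjust the path if you run it from elsewhere):

```bash
# create and activate the development environment defined in dev/environment.yml
conda env create -f dev/environment.yml
conda activate torchsurv
```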

## Getting started

We recommend starting with the [introductory guide](notebooks/introduction), where you'll find an overview of the package's functionalities.

## Usage

### Survival data

We simulate a random batch of 64 subjects. Each subject is associated with a binary event status (= `True` if the event occurred), a time-to-event or censoring time, and 16 covariates.

```python
>>> import torch
>>> _ = torch.manual_seed(52)
>>> n = 64
>>> x = torch.randn((n, 16))
>>> event = torch.randint(low=0, high=2, size=(n,)).bool()
>>> time = torch.randint(low=1, high=100, size=(n,)).float()
```

### Cox proportional hazards model

The user is expected to have defined a model that outputs the estimated *log relative hazard* for each subject. For illustrative purposes, we define a simple linear model that generates a linear combination of the covariates.

```python
>>> from torch import nn
>>> model_cox = nn.Sequential(nn.Linear(16, 1))
>>> log_hz = model_cox(x)
>>> print(log_hz.shape)
torch.Size([64, 1])
```

Given the estimated log relative hazard and the survival data, we calculate the current loss for the batch with:

```python
>>> from torchsurv.loss.cox import neg_partial_log_likelihood
>>> loss = neg_partial_log_likelihood(log_hz, event, time)
>>> print(loss)
tensor(4.1723, grad_fn=<MeanBackward0>)
```

We obtain the concordance index for this batch with:

```python
>>> from torchsurv.metrics.cindex import ConcordanceIndex
>>> with torch.no_grad():
...     log_hz = model_cox(x)
>>> cindex = ConcordanceIndex()
>>> print(cindex(log_hz, event, time))
tensor(0.4872)
```
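
As shown in the TL;DR above, the concordance index also comes with statistical inference; a minimal sketch using the same call (output omitted here):

```python
>>> # Sketch: p-value for the concordance index, using the API shown in the TL;DR above
>>> p_value = cindex.p_value(method="noether", alternative="two_sided")
```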

We obtain the Area Under the Receiver Operating Characteristic Curve (AUC) at a new time $t = 50$ for this batch with:

```python
>>> from torchsurv.metrics.auc import Auc
>>> new_time = torch.tensor(50.)
>>> auc = Auc()
>>> print(auc(log_hz, event, time, new_time=new_time))
tensor([0.4737])
```

### Weibull accelerated failure time (AFT) model

The user is expected to have defined a model that outputs, for each subject, the estimated *log scale* and optionally the *log shape* of the Weibull distribution that the event density follows. If the model has a single output, `torchsurv` assumes that the shape is equal to 1, so the event density reduces to an exponential distribution solely parametrized by the scale.
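
For that single-output (exponential) case, a minimal sketch (assuming, as described above, that `neg_log_likelihood` accepts a one-column input and treats the shape as 1):

```python
>>> # Sketch: single-output model, interpreted as log scale only (shape assumed to be 1)
>>> from torch import nn
>>> from torchsurv.loss.weibull import neg_log_likelihood
>>> model_exp = nn.Sequential(nn.Linear(16, 1))
>>> log_scale = model_exp(x)
>>> print(log_scale.shape)
torch.Size([64, 1])
>>> loss_exp = neg_log_likelihood(log_scale, event, time)
```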

For illustrative purposes, we define a simple linear model that estimates two linear combinations of the covariates (the log scale and log shape parameters).

```python
>>> from torch import nn
>>> model = nn.Sequential(nn.Linear(16, 2))
>>> log_params = model(x)
>>> print(log_params.shape)
torch.Size([64, 2])
```

Given the estimated log scale and log shape and the survival data, we calculate the current loss for the batch with:

```python
>>> from torchsurv.loss.weibull import neg_log_likelihood
>>> loss = neg_log_likelihood(log_params, event, time)
>>> print(loss)
tensor(731.9636, grad_fn=<MeanBackward0>)
```

To evaluate the predictive performance of the model, we calculate the subject-specific log hazard and survival function, evaluated at all observed times, with:

```python
>>> from torchsurv.loss.weibull import log_hazard
>>> from torchsurv.loss.weibull import survival_function
>>> with torch.no_grad():
...     log_params = model(x)
>>> log_hz = log_hazard(log_params, time)
>>> print(log_hz.shape)
torch.Size([64, 64])
>>> surv = survival_function(log_params, time)
>>> print(surv.shape)
torch.Size([64, 64])
```

We obtain the concordance index for this batch with:

```python
>>> from torchsurv.metrics.cindex import ConcordanceIndex
>>> cindex = ConcordanceIndex()
>>> print(cindex(log_hz, event, time))
tensor(0.4062)
```
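
As noted in the TL;DR, two `ConcordanceIndex` objects can also be compared. A minimal sketch, assuming the Cox and Weibull evaluations above had been kept in two separate objects (the names `cindex_cox` and `cindex_weibull` are hypothetical):

```python
# Hypothetical objects holding the Cox and Weibull concordance-index evaluations above
# (in the snippets above, both were simply named `cindex`)
comparison = cindex_cox.compare(cindex_weibull)
```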

We obtain the AUC at a new time $t = 50$ for this batch with:

```python
>>> from torchsurv.metrics.auc import Auc
>>> new_time = torch.tensor(50.)
>>> log_hz_t = log_hazard(log_params, time=new_time)
>>> auc = Auc()
>>> print(auc(log_hz_t, event, time, new_time=new_time))
tensor([0.3509])
```

We obtain the integrated Brier score with:

```python
>>> from torchsurv.metrics.brier_score import BrierScore
>>> brier_score = BrierScore()
>>> bs = brier_score(surv, event, time)
>>> print(brier_score.integral())
tensor(0.4447)
```

## Contributing

We value contributions from the community to enhance and improve this project. If you'd like to contribute, please consider the following:

1. Create issues: If you encounter bugs, have feature requests, or want to suggest improvements, please create an issue in the GitHub repository. Make sure to provide detailed information about the problem (including code to reproduce it) or about the enhancement you're proposing.

2. Fork and pull requests: If you're willing to address an existing issue or contribute a new feature, fork the repository, create a new branch, make your changes, and then submit a pull request. Please ensure your code follows our coding conventions and includes tests for any new functionality.

By contributing to this project, you agree to license your contributions under the same license as this project.


## Contact

If you have any questions, suggestions, or feedback, feel free to reach out to [us](AUTHORS).

## Cite

If you use this project in academic work or publications, we would appreciate a citation using the following BibTeX entry:

```bibtex
@misc{projectname,
  title={Project Name},
  author={Your Name},
  year={2024},
  publisher={GitHub},
  howpublished={\url{https://github.com/your_username/your_repo}},
}
```

dev/build-docs.sh

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
#!/bin/bash
# build & preview the documentation locally

DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"

set -e
cd "${DIR}/../docs"
make html
cd _build/html

if [ "$1" == "serve" ]; then
    python -m http.server --bind 127.0.0.1 8000
fi

dev/codeqc.sh

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
#!/bin/bash

DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"

set -e

cd "${DIR}/.."

if [ "$1" == "check" ]; then
    CHECK="--check"
else
    CHECK=""
fi

export PYTHONPATH=${DIR}/../src

isort --profile black ${CHECK} src tests
black ${CHECK} src tests
# pylint src tests

dev/environment.yml

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
name: torchsurv
channels:
  - conda-forge
  - pytorch
dependencies:
  - python=3.10
  - build
  - pep517
  - numpy
  - pandas
  - pytorch
  - torchvision
  - lightning
  - loguru
  - scikit-survival
  - lifelines
  - tox
  - black
  - pylint
  - mypy
  - sphinx
  - sphinx-book-theme
  - sphinxcontrib-bibtex
  - myst-parser
  - nbsphinx
  - ipython

dev/run-doctests.sh

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
#!/bin/bash
# Run all doctests

DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"

set -e

cd "${DIR}/.."

export PYTHONPATH="${DIR}/../src"

files=$(find src -type f -name "*.py" -exec grep -l 'if __name__ == "__main__"' {} +)

for f in $files; do
    module=$(echo $f | sed 's/\//./g' | sed 's/\.py//g')
    echo "Running $f -> $module"
    python -m $module
done

dev/run-unittests.sh

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
#!/bin/bash
# Run all unit tests

DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"

set -e

cd "${DIR}/.."

export PYTHONPATH="${DIR}/../src"

python -m unittest discover -s tests $@

docs/AUTHORS.md

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
Contributors
============

* Thibaud Coroller <[email protected]>
* Mélodie Monod <[email protected]>
* Peter Krusche <[email protected]>
* Qian Cao <[email protected]>
