Commit 5a59722

Merge pull request #79 from cai4cai/joss-revisions-main
Joss revisions main
2 parents 996c019 + 9362e65 commit 5a59722

13 files changed

Lines changed: 338 additions & 225 deletions

README.md

Lines changed: 9 additions & 5 deletions
@@ -91,17 +91,21 @@ pip install git+https://github.com/cai4cai/torchsparsegradutils
 For full functionality, install optional dependencies:

 ```bash
-# For CuPy sparse solver support (GPU acceleration)
-pip install cupy-cuda12x # Replace with your CUDA version
+# For CuPy sparse solver support (GPU acceleration, requires CUDA 12.x)
+pip install torchsparsegradutils[cupy]

 # For JAX sparse solver support
-pip install "jax[cpu]" # CPU version
-pip install "jax[cuda12]" # GPU version (replace with your CUDA version)
+pip install torchsparsegradutils[jax]
+
+# Install all optional dependencies
+pip install torchsparsegradutils[all]

 # For benchmarking and testing
 pip install scipy matplotlib pandas tqdm pytest
 ```

+> **Note:** The CuPy extra installs `cupy-cuda12x>=13.0`. If you are using a different CUDA version, install the appropriate CuPy package manually (e.g. `pip install cupy-cuda11x`).
+
 ### Requirements

 - **Python**: ≥ 3.10
@@ -118,7 +122,7 @@ Our comprehensive benchmark suite demonstrates significant performance improveme

 ![Sparse Triangular Solve Suite Performance (int32/float32 COO)](torchsparsegradutils/benchmarks/benchmark_visualizations/triangular_solve_suitesparse_performance_int32_float32_coo.png)

-![Sparse Genertic Solve Suite Performance (int32/float32 COO)](torchsparsegradutils/benchmarks/benchmark_visualizations/sparse_solve_suite_performance_int32_float32_coo.png)
+![Sparse Generic Solve Suite Performance (int32/float32 COO)](torchsparsegradutils/benchmarks/benchmark_visualizations/sparse_solve_suite_performance_int32_float32_coo.png)

 ## 🚀 Quick Start

docs/source/benchmarks.rst

Lines changed: 1 addition & 1 deletion
@@ -341,7 +341,7 @@ Matrix: ``Rothberg/cfd2`` (123,440 × 123,440, nnz 3,085,406). Right‑hand side

 **Conclusions:**

-1. The dense PyTorch solver ``torch.linalg.solve`` fails due to out-of-memory (OOM) errors before the foward pass due to failure of creating a dense tensor which would occupy 57GB of CUDA memory.
+1. The dense PyTorch solver ``torch.linalg.solve`` fails due to out-of-memory (OOM) errors before the forward pass due to failure of creating a dense tensor which would occupy 57GB of CUDA memory.
 2. ``torch.sparse_csr`` with ``float32`` and ``int32`` indices is the most memory efficient format for both forward and backward passes.
 3. Similar to ``tsgu.sparse_mm``, the ``int32`` indices for ``torch.sparse_coo`` format uses marginally less memory than ``int64`` despite ``A.indices()`` returning ``int64`` indices.
 4. All CuPy and JAX solvers use the same amount of memory on the forward and backward pass.
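
The 57GB figure in conclusion 1 can be sanity-checked from the matrix dimensions quoted in the hunk header. The sketch below assumes the dense tensor would be stored as ``float32`` (4 bytes per element), which the hunk does not state explicitly; ``dense_bytes`` is a hypothetical helper written for this check, not part of the library.

```python
def dense_bytes(n_rows: int, n_cols: int, itemsize: int = 4) -> int:
    """Bytes needed to store a dense n_rows x n_cols array.

    itemsize=4 assumes float32 storage (an assumption, not stated in the diff).
    """
    return n_rows * n_cols * itemsize


n = 123_440  # Rothberg/cfd2 is 123,440 x 123,440
total = dense_bytes(n, n)

# Report in both decimal gigabytes and binary gibibytes.
print(f"{total / 1e9:.1f} GB, {total / 2**30:.1f} GiB")
# prints: 60.9 GB, 56.8 GiB
```

Under that assumption, the quoted 57GB matches the binary (GiB) value rounded to the nearest integer.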

docs/source/installation.rst

Lines changed: 17 additions & 8 deletions
@@ -19,11 +19,20 @@ For additional functionality, you can install optional dependencies:

 .. code-block:: bash

-    # Install with JAX and CuPy support
-    pip install torchsparsegradutils[extras]
+    # Install with CuPy support (GPU acceleration, requires CUDA 12.x)
+    pip install torchsparsegradutils[cupy]

-    # Or install them separately
-    pip install jax cupy
+    # Install with JAX support
+    pip install torchsparsegradutils[jax]
+
+    # Install all optional dependencies
+    pip install torchsparsegradutils[all]
+
+.. note::
+
+    The CuPy extra installs ``cupy-cuda12x>=13.0``. If you are using a different
+    CUDA version, install the appropriate CuPy package manually
+    (e.g. ``pip install cupy-cuda11x``).

 Requirements
 ------------
@@ -37,8 +46,8 @@ Core Requirements
 Optional Requirements
 ~~~~~~~~~~~~~~~~~~~~~

-- JAX (for JAX backend integration)
-- CuPy (for CuPy backend integration)
+- JAX (for JAX backend integration): ``pip install torchsparsegradutils[jax]``
+- CuPy >= 13.0 (for CuPy backend integration): ``pip install torchsparsegradutils[cupy]``

 Development Installation
 ------------------------
@@ -55,7 +64,7 @@ To install development dependencies:

 .. code-block:: bash

-    pip install -e .[extras]
+    pip install -e .[all]
     pip install -r requirements-ci.txt

 Verification
@@ -90,7 +99,7 @@ You can also use the package in a Docker container. Here's a simple Dockerfile:

     FROM pytorch/pytorch:latest

-    RUN pip install torchsparsegradutils[extras]
+    RUN pip install torchsparsegradutils[all]

     # Your application code
     COPY . /app
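
The hunks above replace the single ``extras`` group with separate ``cupy``, ``jax`` and ``all`` extras. One way to check which optional backends are importable after installation is sketched below. It assumes the extras provide the top-level modules ``cupy`` and ``jax``; ``available_backends`` is a hypothetical helper and not part of torchsparsegradutils.

```python
import importlib.util


def available_backends(names=("cupy", "jax")):
    """Map each top-level module name to True if it can be found for import.

    find_spec only locates the module; it does not import it, so this
    does not initialise CUDA or any other runtime.
    """
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, found in available_backends().items():
        print(f"{name}: {'installed' if found else 'missing'}")
```

A module that is found may still fail at import time, for example when the installed CuPy wheel does not match the local CUDA version, so this is a first check rather than a full verification.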

paper/jats/paper.jats

Lines changed: 112 additions & 79 deletions
Large diffs are not rendered by default.

paper/paper.bib

Lines changed: 1 addition & 2 deletions
@@ -59,14 +59,13 @@ @inproceedings{gpytorch
 booktitle = {Advances in Neural Information Processing Systems},
 volume = {31},
 year = {2018},
-doi = {10.5555/3327757.3327857},
 url = {https://arxiv.org/abs/1809.11165}
 }

 @misc{flaport2020sparse,
 title = {Solving sparse linear systems in PyTorch},
 author = {Laporte, Floris},
 year = {2020},
-howpublished = {\url{https://blog.flaport.net/solving-sparse-linear-systems-in-pytorch.html}},
+url = {https://blog.flaport.net/solving-sparse-linear-systems-in-pytorch.html},
 note = {Accessed: 2025-08-22}
 }
