 For full functionality, install optional dependencies:

 ```bash
-# For CuPy sparse solver support (GPU acceleration)
-pip install cupy-cuda12x  # Replace with your CUDA version
+# For CuPy sparse solver support (GPU acceleration, requires CUDA 12.x)
+pip install torchsparsegradutils[cupy]

 # For JAX sparse solver support
-pip install "jax[cpu]"  # CPU version
-pip install "jax[cuda12]"  # GPU version (replace with your CUDA version)
+pip install torchsparsegradutils[jax]
+
+# Install all optional dependencies
+pip install torchsparsegradutils[all]

 # For benchmarking and testing
 pip install scipy matplotlib pandas tqdm pytest
 ```

+> **Note:** The CuPy extra installs `cupy-cuda12x>=13.0`. If you are using a different CUDA version, install the appropriate CuPy package manually (e.g. `pip install cupy-cuda11x`).
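
After installing one of the extras above, it can be handy to confirm which optional backends are actually importable. The following is a minimal sketch using only the standard library; the helper name `available_backends` and the mapping of extras to module names are assumptions for illustration and are not part of `torchsparsegradutils`.

```python
import importlib.util

# Assumption: each extra is checked via its top-level importable module name.
OPTIONAL_BACKENDS = {"cupy": "cupy", "jax": "jax"}


def available_backends(candidates=None):
    """Return {extra_name: bool}; True when the module can be located.

    find_spec only locates the module, it does not import it, so this
    check does not initialise CUDA or any other runtime.
    """
    if candidates is None:
        candidates = OPTIONAL_BACKENDS
    return {
        extra: importlib.util.find_spec(module) is not None
        for extra, module in candidates.items()
    }


if __name__ == "__main__":
    for extra, found in available_backends().items():
        status = "found" if found else "missing"
        print(f"{extra}: {status}")
```

A module being found only means the package is installed; a CuPy build that does not match the local CUDA version can still fail when it is first used.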
-1. The dense PyTorch solver ``torch.linalg.solve`` fails due to out-of-memory (OOM) errors before the foward pass due to failure of creating a dense tensor which would occupy 57GB of CUDA memory.
+1. The dense PyTorch solver ``torch.linalg.solve`` fails due to out-of-memory (OOM) errors before the forward pass due to failure of creating a dense tensor which would occupy 57GB of CUDA memory.
 2. ``torch.sparse_csr`` with ``float32`` and ``int32`` indices is the most memory efficient format for both forward and backward passes.
 3. Similar to ``tsgu.sparse_mm``, the ``int32`` indices for ``torch.sparse_coo`` format uses marginally less memory than ``int64`` despite ``A.indices()`` returning ``int64`` indices.
 4. All CuPy and JAX solvers use the same amount of memory on the forward and backward pass.
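
The memory observations in the list above follow from simple storage arithmetic, sketched below. The helper names and the matrix size and sparsity are hypothetical values chosen for illustration; they are not the dimensions used in the benchmark, and the estimate covers raw storage only, ignoring allocator overhead and solver workspace.

```python
def dense_bytes(n_rows, n_cols, value_bytes=4):
    """Bytes needed to store a dense matrix (float32 by default)."""
    return n_rows * n_cols * value_bytes


def csr_bytes(n_rows, nnz, value_bytes=4, index_bytes=4):
    """Bytes needed for CSR storage: values, column indices, row pointers."""
    values = nnz * value_bytes
    col_indices = nnz * index_bytes
    row_pointers = (n_rows + 1) * index_bytes
    return values + col_indices + row_pointers


# Hypothetical problem size: 100,000 x 100,000 with 5 non-zeros per row.
n = 100_000
nnz = 5 * n

dense = dense_bytes(n, n)                     # float32 dense matrix
csr_int32 = csr_bytes(n, nnz, index_bytes=4)  # float32 values, int32 indices
csr_int64 = csr_bytes(n, nnz, index_bytes=8)  # float32 values, int64 indices

print(f"dense float32:  {dense / 1e9:.1f} GB")
print(f"CSR with int32: {csr_int32 / 1e6:.1f} MB")
print(f"CSR with int64: {csr_int64 / 1e6:.1f} MB")
```

Dense storage grows with the square of the matrix dimension while CSR storage grows with the number of non-zeros, which is why the dense solver runs out of memory first. Halving the index width shrinks only the index arrays, so the saving from `int32` is real but bounded.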