You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Marin uses [JAX](https://jax.readthedocs.io/en/latest/index.html) as a core library.
46
-
Install Python dependencies for CUDA 12.x via uv:
26
+
Marin uses [JAX](https://docs.jax.dev/en/latest/index.html) as a core library. The `gpu`
27
+
extra installs the CUDA 13 JAX runtime, including CUDA, cuDNN, and NCCL Python wheels:
47
28
48
29
```bash
49
30
uv sync --extra=gpu
50
31
```
51
32
52
-
See [JAX's installation guide](https://jax.readthedocs.io/en/latest/installation.html) for more options.
33
+
If you install a local CUDA toolkit for custom kernels, use CUDA 13 and keep older CUDA libraries
34
+
out of `LD_LIBRARY_PATH` so they do not override the JAX wheel libraries.
35
+
36
+
See [JAX's installation guide](https://docs.jax.dev/en/latest/installation.html) for more options.
53
37
54
38
!!! tip
55
39
If you are using a DGX Spark or similar machine with unified memory, you may need to dramatically reduce the memory that XLA preallocates for itself. You can do this by setting the `XLA_PYTHON_CLIENT_MEM_FRACTION` variable, to something like 0.5:
0 commit comments