
Commit c938b3d

Move GPU JAX runtime to CUDA 13

Parent: 592d930
6 files changed, 603 additions and 478 deletions

docs/tutorials/local-gpu.md

Lines changed: 12 additions & 28 deletions

@@ -8,48 +8,32 @@ Similar steps will let you run Marin on a cloud GPU environment under Iris (the
 
 Make sure you've followed the [installation guide](installation.md) to do the basic installation.
 
-In addition to the prerequisites from the basic installation, we have GPU-specific dependencies:
+In addition to the prerequisites from the basic installation, we have one GPU-specific system dependency:
 
-- CUDA Toolkit (version 12.1 or higher)
-- cuDNN (version 9.1 or higher)
+- NVIDIA driver 580 or newer
 
 We assume you are running Ubuntu 24.04.
 
-## CUDA installation
+## NVIDIA driver and runtime
 
-Install CUDA 12.9.0:
+Install an NVIDIA driver that supports CUDA 13. Verify that the driver is at least 580 and that
+`nvidia-smi` reports CUDA 13.x:
 
 ```bash
-wget https://developer.download.nvidia.com/compute/cuda/12.9.0/local_installers/cuda_12.9.0_575.51.03_linux.run
-sudo sh cuda_12.9.0_575.51.03_linux.run
+nvidia-smi
 ```
 
-Install cuDNN 9.9.0 (Instructions from [NVIDIA's cuDNN download page](https://developer.nvidia.com/cudnn-downloads?target_os=Linux&target_arch=x86_64&Distribution=Ubuntu&target_version=24.04&target_type=deb_local)):
-
-```bash
-wget https://developer.download.nvidia.com/compute/cudnn/9.10.0/local_installers/cudnn-local-repo-ubuntu2404-9.10.0_1.0-1_amd64.deb
-sudo dpkg -i cudnn-local-repo-ubuntu2404-9.10.0_1.0-1_amd64.deb
-sudo cp /var/cudnn-local-repo-ubuntu2404-9.10.0/cudnn-*-keyring.gpg /usr/share/keyrings/
-sudo apt-get update
-sudo apt-get -y install cudnn
-sudo apt-get -y install cudnn-cuda-12
-sudo apt-get -y install nvidia-cuda-toolkit
-```
-
-Verify your setup by checking the CUDA version:
-
-```bash
-nvcc --version
-```
-
-Marin uses [JAX](https://jax.readthedocs.io/en/latest/index.html) as a core library.
-Install Python dependencies for CUDA 12.x via uv:
+Marin uses [JAX](https://docs.jax.dev/en/latest/index.html) as a core library. The `gpu`
+extra installs the CUDA 13 JAX runtime, including CUDA, cuDNN, and NCCL Python wheels:
 
 ```bash
 uv sync --extra=gpu
 ```
 
-See [JAX's installation guide](https://jax.readthedocs.io/en/latest/installation.html) for more options.
+If you install a local CUDA toolkit for custom kernels, use CUDA 13 and keep older CUDA libraries
+out of `LD_LIBRARY_PATH` so they do not override the JAX wheel libraries.
+
+See [JAX's installation guide](https://docs.jax.dev/en/latest/installation.html) for more options.
 
 !!! tip
     If you are using a DGX Spark or similar machine with unified memory, you may need to dramatically reduce the memory that XLA preallocates for itself. You can do this by setting the `XLA_PYTHON_CLIENT_MEM_FRACTION` variable to something like 0.5:
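Aside: a minimal sketch of applying that tip before launching a job. The 0.5 value is the doc's own suggestion; the `python -c` probe is illustrative and not part of this commit:

```bash
# Cap XLA's GPU preallocation at half of memory, per the tip above,
# then confirm JAX still brings up its CUDA backend.
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.5
python -c "import jax; print(jax.devices())"
```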

lib/iris/OPS.md

Lines changed: 4 additions & 0 deletions

@@ -255,6 +255,10 @@ State dir: `gs://marin-us-central2/iris/<cluster>/state/` — contains `bundles/
 ## CoreWeave (GPU) Operations
 
 Use `lib/iris/examples/coreweave-*.yaml` for CoreWeave scale group configurations.
+Marin's `gpu` extra installs the JAX CUDA 13 wheel stack from PyPI. CoreWeave
+GPU nodes must expose NVIDIA driver 580 or newer; `nvidia-smi` should report
+CUDA 13.x. CPU, TPU, and vLLM jobs use separate extras and should not carry the
+CUDA 13 JAX runtime.
 
 ### Connecting
 
lib/levanter/pyproject.toml

Lines changed: 5 additions & 4 deletions

@@ -80,10 +80,11 @@ Homepage = "https://github.com/stanford-crfm/levanter"
 
 [project.optional-dependencies]
 gpu = [
-    "jax[cuda12]>=0.9.2",
-    # JAX 0.9.2 all-to-all fails on CW H100s with NCCL 2.27.x. Keep this floor
-    # until the JAX CUDA deps or the top-level lock exclude the bad NCCL line.
-    "nvidia-nccl-cu12>=2.28.3; sys_platform == 'linux'",
+    "jax[cuda13]>=0.9.2",
+    # B200 emits a cuBLAS warning with older CUDA 13 cuBLAS builds.
+    "nvidia-cublas>=13.2.0.9; sys_platform == 'linux'",
+    # Preserve the CoreWeave H100 all-to-all guard under CUDA 13.
+    "nvidia-nccl-cu13>=2.28.3; sys_platform == 'linux'",
 ]
 tpu = ["jax==0.9.2", "jaxlib==0.9.2", "libtpu==0.0.38"]
 torch_test = [
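To see what these floors actually resolve to, a sketch using `uv export`, the same mechanism the new test below relies on:

```bash
# List the resolved CUDA 13 wheels for the levanter GPU extra.
uv export --package marin-levanter --extra gpu --frozen --no-hashes \
  | grep -E '^(jax-cuda13|nvidia-(cublas|nccl-cu13))'
```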

lib/marin/pyproject.toml

Lines changed: 7 additions & 14 deletions

@@ -119,10 +119,6 @@ conflicts = [
         { extra = "vllm" },
         { extra = "cpu" },
     ],
-    [
-        { extra = "vllm" },
-        { extra = "cuda12" },
-    ],
     [
         # The vllm extra ships vllm-tpu only, so it must use CPU/torch_xla
         # torch rather than the cu128-pinned torch from the gpu extra.
@@ -135,12 +131,11 @@ conflicts = [
 [project.optional-dependencies]
 
 gpu = [
-    "jax[cuda12]==0.9.2",
-    # JAX 0.9.2 all-to-all fails on CW H100s with NCCL 2.27.x. This can be
-    # removed once the resolved GPU stack no longer admits NCCL <2.28.3.
-    "nvidia-nccl-cu12>=2.28.3; sys_platform == 'linux'",
-    # torch 2.10.0+cu128 pins nvidia-nccl-cu12==2.27.5, which reintroduces the
-    # bad all-to-all stack above. torch 2.11.0+cu128 resolves NCCL 2.28.9.
+    "jax[cuda13]==0.9.2",
+    # B200 emits a cuBLAS warning with older CUDA 13 cuBLAS builds.
+    "nvidia-cublas>=13.2.0.9; sys_platform == 'linux'",
+    # Preserve the CoreWeave H100 all-to-all guard under CUDA 13.
+    "nvidia-nccl-cu13>=2.28.3; sys_platform == 'linux'",
     "torch==2.11.0",
     "torchvision==0.26.0",
 ]
@@ -219,13 +214,11 @@ torchvision = [
     { index = "pytorch-cpu", extra = "cpu" },
     { index = "pytorch-cpu", extra = "tpu" },
     { index = "pytorch-cpu", extra = "vllm" },
-    # The GPU extra pins a plain torchvision version so non-Linux platforms can
-    # use PyPI wheels. Only Linux GPU installs should route to PyTorch's cu128
-    # index for the matching CUDA wheel.
+    # The GPU extra uses PyTorch cu128 wheels; JAX CUDA 13 packages come from PyPI.
     { index = "pytorch-cu128", extra = "gpu", marker = "sys_platform == 'linux'" },
 ]
 resiliparse = { index = "marin-resiliparse" }
-# Use CUDA PyTorch for --extra=gpu on Linux, CPU PyTorch for TPU/CPU/vLLM builds.
+# Use PyTorch CUDA 12.8 wheels for --extra=gpu on Linux, CPU PyTorch for TPU/CPU/vLLM builds.
 torch = [
     { index = "pytorch-cu128", extra = "gpu", marker = "sys_platform == 'linux'" },
     { index = "pytorch-cpu", extra = "cpu" },

tests/test_dependency_extras.py

Lines changed: 74 additions & 0 deletions

@@ -0,0 +1,74 @@
+# Copyright The Marin Authors
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+import re
+import subprocess
+from pathlib import Path
+
+import pytest
+
+REPO_ROOT = Path(__file__).resolve().parents[1]
+REQUIREMENT = re.compile(r"^([A-Za-z0-9_.-]+)==([^ ;]+)")
+CUDA_RUNTIME_PREFIXES = ("jax-cuda", "nvidia-")
+
+
+def export_packages(package: str, extra: str) -> dict[str, str]:
+    result = subprocess.run(
+        [
+            "uv",
+            "export",
+            "--package",
+            package,
+            "--extra",
+            extra,
+            "--no-dev",
+            "--frozen",
+            "--no-emit-project",
+            "--no-emit-workspace",
+            "--no-header",
+            "--no-annotate",
+            "--no-hashes",
+        ],
+        cwd=REPO_ROOT,
+        check=True,
+        capture_output=True,
+        text=True,
+        env={**os.environ, "UV_NO_PROGRESS": "1"},
+    )
+    packages = {}
+    for line in result.stdout.splitlines():
+        match = REQUIREMENT.match(line)
+        if match:
+            packages[match.group(1).lower().replace("_", "-")] = match.group(2)
+    return packages
+
+
+def cuda_runtime_packages(packages: dict[str, str]) -> list[str]:
+    return sorted(name for name in packages if name.startswith(CUDA_RUNTIME_PREFIXES))
+
+
+@pytest.mark.parametrize("package", ["marin", "marin-levanter"])
+def test_gpu_extra_exports_cuda13_jax_runtime(package: str):
+    """The GPU extra is the resolver boundary; this catches accidental reverts to JAX CUDA 12."""
+    packages = export_packages(package, "gpu")
+
+    assert "jax-cuda13-plugin" in packages
+    assert "jax-cuda13-pjrt" in packages
+    assert "jax-cuda12-plugin" not in packages
+    assert "jax-cuda12-pjrt" not in packages
+
+
+@pytest.mark.parametrize("extra", ["cpu", "tpu", "vllm"])
+def test_non_gpu_extras_do_not_export_cuda_runtime(extra: str):
+    """CPU/TPU/vLLM jobs should not inherit JAX CUDA or NVIDIA runtime wheels."""
+    packages = export_packages("marin", extra)
+
+    assert cuda_runtime_packages(packages) == []
+
+
+def test_levanter_tpu_extra_does_not_export_cuda_runtime():
+    """Levanter TPU jobs should not inherit JAX CUDA or NVIDIA runtime wheels."""
+    packages = export_packages("marin-levanter", "tpu")
+
+    assert cuda_runtime_packages(packages) == []
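A plausible local invocation of these guards (each test shells out to `uv export` against the committed lock file, so no GPU is needed):

```bash
uv run pytest tests/test_dependency_extras.py -v
```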
