Commit ff8a967

Use JAX 0.10 CUDA 13 for GPU installs
Switch Marin and Levanter GPU extras from JAX CUDA 12 to JAX 0.10 CUDA 13. JAX 0.9.2 reproduced an H100x8 CUDA 13 profiler crash; JAX 0.10 passed the repros and H100x8 canary. CPU, TPU, and vLLM stay on JAX 0.9.2 until tpu-inference can unpin JAX. Part of #5427
1 parent a92b2a6 commit ff8a967
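The version split in the message can be read as a small pin matrix. A minimal sketch of that matrix, with versions taken from the commit message; the `expected_jax_version` helper is illustrative, not an API in this repo:

```python
# JAX pin matrix this commit establishes, per the commit message.
# The helper name `expected_jax_version` is illustrative only.
PINS = {
    "gpu": ("jax[cuda13]", "0.10.0"),  # H100x8 CUDA 13 profiler crash fixed in 0.10
    "tpu": ("jax", "0.9.2"),           # held back until tpu-inference can unpin JAX
    "vllm": ("jax", "0.9.2"),
    "cpu": ("jax", "0.9.2"),
}

def expected_jax_version(extra: str) -> str:
    """Return the JAX version the commit pins for a given extra."""
    return PINS[extra][1]

assert expected_jax_version("gpu") == "0.10.0"
assert expected_jax_version("tpu") == "0.9.2"
```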

8 files changed: 1944 additions & 1184 deletions

docs/tutorials/local-gpu.md

Lines changed: 12 additions & 28 deletions
@@ -8,48 +8,32 @@ Similar steps will let you run Marin on a cloud GPU environment under Iris (the
 
 Make sure you've followed the [installation guide](installation.md) to do the basic installation.
 
-In addition to the prerequisites from the basic installation, we have GPU-specific dependencies:
+In addition to the prerequisites from the basic installation, we have one GPU-specific system dependency:
 
-- CUDA Toolkit (version 12.1 or higher)
-- cuDNN (version 9.1 or higher)
+- NVIDIA driver 580 or newer
 
 We assume you are running Ubuntu 24.04.
 
-## CUDA installation
+## NVIDIA driver and runtime
 
-Install CUDA 12.9.0:
+Install an NVIDIA driver that supports CUDA 13. Verify that the driver is at least 580 and that
+`nvidia-smi` reports CUDA 13.x:
 
 ```bash
-wget https://developer.download.nvidia.com/compute/cuda/12.9.0/local_installers/cuda_12.9.0_575.51.03_linux.run
-sudo sh cuda_12.9.0_575.51.03_linux.run
+nvidia-smi
 ```
 
-Install cuDNN 9.9.0 (Instructions from [NVIDIA's cuDNN download page](https://developer.nvidia.com/cudnn-downloads?target_os=Linux&target_arch=x86_64&Distribution=Ubuntu&target_version=24.04&target_type=deb_local)):
-
-```bash
-wget https://developer.download.nvidia.com/compute/cudnn/9.10.0/local_installers/cudnn-local-repo-ubuntu2404-9.10.0_1.0-1_amd64.deb
-sudo dpkg -i cudnn-local-repo-ubuntu2404-9.10.0_1.0-1_amd64.deb
-sudo cp /var/cudnn-local-repo-ubuntu2404-9.10.0/cudnn-*-keyring.gpg /usr/share/keyrings/
-sudo apt-get update
-sudo apt-get -y install cudnn
-sudo apt-get -y install cudnn-cuda-12
-sudo apt-get -y install nvidia-cuda-toolkit
-```
-
-Verify your setup by checking the CUDA version:
-
-```bash
-nvcc --version
-```
-
-Marin uses [JAX](https://jax.readthedocs.io/en/latest/index.html) as a core library.
-Install Python dependencies for CUDA 12.x via uv:
+Marin uses [JAX](https://docs.jax.dev/en/latest/index.html) as a core library. The `gpu`
+extra installs the CUDA 13 JAX runtime, including CUDA, cuDNN, and NCCL Python wheels:
 
 ```bash
 uv sync --extra=gpu
 ```
 
-See [JAX's installation guide](https://jax.readthedocs.io/en/latest/installation.html) for more options.
+If you install a local CUDA toolkit for custom kernels, use CUDA 13 and keep older CUDA libraries
+out of `LD_LIBRARY_PATH` so they do not override the JAX wheel libraries.
+
+See [JAX's installation guide](https://docs.jax.dev/en/latest/installation.html) for more options.
 
 !!! tip
     If you are using a DGX Spark or similar machine with unified memory, you may need to dramatically reduce the memory that XLA preallocates for itself. You can do this by setting the `XLA_PYTHON_CLIENT_MEM_FRACTION` variable, to something like 0.5:
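The driver floor in the rewritten docs can be checked programmatically. A minimal sketch of the "driver 580 or newer" test, assuming the first dotted component of the `nvidia-smi` driver version is the major number; the sample version strings are illustrative:

```python
def driver_supports_cuda13(version: str, minimum_major: int = 580) -> bool:
    """Return True if an nvidia-smi driver version string meets the floor."""
    return int(version.split(".")[0]) >= minimum_major

# Illustrative version strings, not taken from a real machine.
assert driver_supports_cuda13("580.65.06")
assert driver_supports_cuda13("590.12.01")
assert not driver_supports_cuda13("550.144.03")
```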

lib/iris/docs/coreweave.md

Lines changed: 5 additions & 0 deletions
@@ -210,6 +210,11 @@ Do not change the GH200 row to `GH200x1`: the RNO2A pool currently accepts
 Before the full GPU canary, run one tiny direct JAX job for each row. It should
 prove `nvidia-smi`, GPU-backed JAX, and a tiny matmul.
 
+Marin's `gpu` extra installs the JAX CUDA 13 wheel stack from PyPI. CoreWeave
+GPU nodes must expose NVIDIA driver 580 or newer; `nvidia-smi` should report
+CUDA 13.x. CPU, TPU, and vLLM jobs use separate extras and should not carry the
+CUDA 13 JAX runtime.
+
 ### KubernetesProvider Operations
 
 On CoreWeave, there are no persistent worker daemons. The controller dispatches
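The "tiny direct JAX job" the added docs describe could look like the following sketch: it asserts a visible device and runs a tiny matmul. This is an illustration, not a script from the repo, and it also runs on CPU when no GPU is present.

```python
import jax
import jax.numpy as jnp

# Prove JAX sees at least one device (GPU on a healthy CoreWeave node).
devices = jax.devices()
assert devices, "no JAX devices visible"

# Tiny matmul canary: every entry of x @ x is 8.0, and there are 64 entries.
x = jnp.ones((8, 8))
result = float((x @ x).sum())
assert result == 512.0
print(devices[0].platform, result)
```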

lib/levanter/pyproject.toml

Lines changed: 5 additions & 4 deletions
@@ -80,10 +80,11 @@ Homepage = "https://github.com/stanford-crfm/levanter"
 
 [project.optional-dependencies]
 gpu = [
-    "jax[cuda12]>=0.9.2",
-    # JAX 0.9.2 all-to-all fails on CW H100s with NCCL 2.27.x. Keep this floor
-    # until the JAX CUDA deps or the top-level lock exclude the bad NCCL line.
-    "nvidia-nccl-cu12>=2.28.3; sys_platform == 'linux'",
+    "jax[cuda13]==0.10.0",
+    # B200 emits a cuBLAS warning with older CUDA 13 cuBLAS builds.
+    "nvidia-cublas>=13.2.0.9; sys_platform == 'linux'",
+    # Preserve the CoreWeave H100 all-to-all guard under CUDA 13.
+    "nvidia-nccl-cu13>=2.28.3; sys_platform == 'linux'",
 ]
 tpu = ["jax==0.9.2", "jaxlib==0.9.2", "libtpu==0.0.38"]
 torch_test = [
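The `nvidia-nccl-cu13>=2.28.3` floor above carries forward the CoreWeave H100 all-to-all guard. A minimal sketch of the version comparison such a floor enforces; this is a pure illustration, not a uv or packaging API:

```python
def nccl_meets_floor(version: str, floor: tuple = (2, 28, 3)) -> bool:
    """True if a dotted NCCL version string is at least the pinned floor."""
    return tuple(int(part) for part in version.split(".")) >= floor

assert not nccl_meets_floor("2.27.5")  # the NCCL line with the bad all-to-all
assert nccl_meets_floor("2.28.3")
assert nccl_meets_floor("2.28.9")
```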

lib/levanter/src/levanter/kernels/pallas/autotune_utils.py

Lines changed: 8 additions & 1 deletion
@@ -9,7 +9,7 @@
 import jax
 from jax import core as jax_core
 from jax._src import mesh as mesh_lib
-from jax.sharding import NamedSharding
+from jax.sharding import AxisType, NamedSharding
 
 
 _AUTOTUNE_THREAD_POOL = ThreadPoolExecutor(max_workers=1, thread_name_prefix="pallas_autotune")
@@ -53,8 +53,15 @@ def hlo_sharding_of(value: jax.Array):
     return None
 
 
+def _named_sharding_uses_manual_axes(sharding: NamedSharding) -> bool:
+    return any(axis_type is AxisType.Manual for axis_type in sharding.mesh.axis_types)
+
+
 def value_uses_manual_sharding(value: jax.Array) -> bool:
     """Detect shard_map-local tracer values that carry manual sharding."""
+    sharding = sharding_of(value)
+    if isinstance(sharding, NamedSharding) and _named_sharding_uses_manual_axes(sharding):
+        return True
     hlo_sharding = hlo_sharding_of(value)
     return hlo_sharding is not None and hlo_sharding.is_manual()
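The new `_named_sharding_uses_manual_axes` helper reduces to an `any(...)` over the mesh's axis types. A self-contained sketch of that predicate, using an enum stand-in for `jax.sharding.AxisType` so it runs without JAX; the real check inspects `sharding.mesh.axis_types`:

```python
from enum import Enum

class AxisType(Enum):
    """Illustrative stand-in for jax.sharding.AxisType."""
    Auto = 0
    Explicit = 1
    Manual = 2

def uses_manual_axes(axis_types) -> bool:
    """Mirror of the commit's check: any mesh axis marked Manual."""
    return any(axis_type is AxisType.Manual for axis_type in axis_types)

assert uses_manual_axes((AxisType.Auto, AxisType.Manual))
assert not uses_manual_axes((AxisType.Auto, AxisType.Explicit))
```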

lib/levanter/tests/kernels/test_pallas_autotune_utils.py

Lines changed: 0 additions & 3 deletions
@@ -4,7 +4,6 @@
 import threading
 
 import jax
-from jax._src import pjit
 import jax.numpy as jnp
 import numpy as np
 import pytest
@@ -85,8 +84,6 @@ def test_shape_dtype_struct_for_benchmark_drops_manual_sharding_from_shard_map_t
     def _capture(local_x):
         seen_manual.append(autotune_utils.value_uses_manual_sharding(local_x))
         seen_shapes.append(local_x.shape)
-        with pytest.raises(AssertionError):
-            pjit.pjit_check_aval_sharding([local_x.aval.sharding], [local_x.aval], ["x"], "arg", False)
         seen_structs.append(autotune_utils.shape_dtype_struct_for_benchmark(local_x))
         return local_x

lib/marin/pyproject.toml

Lines changed: 10 additions & 15 deletions
@@ -26,7 +26,7 @@ dependencies = [
     "google-cloud-storage",
     "google-cloud-storage-transfer",
     "marin-haliax",
-    "jax==0.9.2",
+    "jax>=0.9.2,<0.11",
     "jaxopt>=0.8.3",
     "marin-levanter[serve]",
     "lxml[html_clean]",
@@ -119,10 +119,6 @@ conflicts = [
         { extra = "vllm" },
         { extra = "cpu" },
     ],
-    [
-        { extra = "vllm" },
-        { extra = "cuda12" },
-    ],
     [
         # The vllm extra ships vllm-tpu only, so it must use CPU/torch_xla
         # torch rather than the cu128-pinned torch from the gpu extra.
@@ -135,12 +131,11 @@ conflicts = [
 [project.optional-dependencies]
 
 gpu = [
-    "jax[cuda12]==0.9.2",
-    # JAX 0.9.2 all-to-all fails on CW H100s with NCCL 2.27.x. This can be
-    # removed once the resolved GPU stack no longer admits NCCL <2.28.3.
-    "nvidia-nccl-cu12>=2.28.3; sys_platform == 'linux'",
-    # torch 2.10.0+cu128 pins nvidia-nccl-cu12==2.27.5, which reintroduces the
-    # bad all-to-all stack above. torch 2.11.0+cu128 resolves NCCL 2.28.9.
+    "jax[cuda13]==0.10.0",
+    # B200 emits a cuBLAS warning with older CUDA 13 cuBLAS builds.
+    "nvidia-cublas>=13.2.0.9; sys_platform == 'linux'",
+    # Preserve the CoreWeave H100 all-to-all guard under CUDA 13.
+    "nvidia-nccl-cu13>=2.28.3; sys_platform == 'linux'",
     "torch==2.11.0",
     "torchvision==0.26.0",
 ]
@@ -188,6 +183,8 @@ vizier = [
 ]
 
 vllm = [
+    "jax==0.9.2",
+    "jaxlib==0.9.2",
     "vllm-tpu==0.18.0",
     "tpu-inference==0.18.0",
     "triton==3.6.0; platform_system == 'Linux' and platform_machine == 'x86_64'",
@@ -219,13 +216,11 @@ torchvision = [
     { index = "pytorch-cpu", extra = "cpu" },
     { index = "pytorch-cpu", extra = "tpu" },
     { index = "pytorch-cpu", extra = "vllm" },
-    # The GPU extra pins a plain torchvision version so non-Linux platforms can
-    # use PyPI wheels. Only Linux GPU installs should route to PyTorch's cu128
-    # index for the matching CUDA wheel.
+    # The GPU extra uses PyTorch cu128 wheels; JAX CUDA 13 packages come from PyPI.
     { index = "pytorch-cu128", extra = "gpu", marker = "sys_platform == 'linux'" },
 ]
 resiliparse = { index = "marin-resiliparse" }
-# Use CUDA PyTorch for --extra=gpu on Linux, CPU PyTorch for TPU/CPU/vLLM builds.
+# Use PyTorch CUDA 12.8 wheels for --extra=gpu on Linux, CPU PyTorch for TPU/CPU/vLLM builds.
 torch = [
     { index = "pytorch-cu128", extra = "gpu", marker = "sys_platform == 'linux'" },
     { index = "pytorch-cpu", extra = "cpu" },
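The torch index routing above sends Linux `gpu` installs to the cu128 index and everything else to the CPU index. A minimal model of that routing rule; the function name is illustrative, not part of uv:

```python
def torch_index(extra: str, sys_platform: str) -> str:
    """Model the index selection in [tool.uv.sources]: GPU extra on Linux
    resolves from pytorch-cu128, all other builds from pytorch-cpu."""
    if extra == "gpu" and sys_platform == "linux":
        return "pytorch-cu128"
    return "pytorch-cpu"

assert torch_index("gpu", "linux") == "pytorch-cu128"
assert torch_index("gpu", "darwin") == "pytorch-cpu"
assert torch_index("tpu", "linux") == "pytorch-cpu"
```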

pyproject.toml

Lines changed: 26 additions & 0 deletions
@@ -37,6 +37,32 @@ override-dependencies = [
     "datasets>=3.1.0,<5.0.0",
     "equinox>=0.11.10", # Override vizier's pin for modern JAX compatibility
 ]
+conflicts = [
+    [
+        { package = "marin-levanter", extra = "gpu" },
+        { package = "marin", extra = "cpu" },
+    ],
+    [
+        { package = "marin-levanter", extra = "gpu" },
+        { package = "marin", extra = "tpu" },
+    ],
+    [
+        { package = "marin-levanter", extra = "gpu" },
+        { package = "marin", extra = "vllm" },
+    ],
+    [
+        { package = "marin", extra = "gpu" },
+        { package = "marin-levanter", extra = "tpu" },
+    ],
+    [
+        { package = "marin", extra = "gpu" },
+        { package = "marin-fray", group = "fray_tpu_test" },
+    ],
+    [
+        { package = "marin-levanter", extra = "gpu" },
+        { package = "marin-fray", group = "fray_tpu_test" },
+    ],
+]
 
 [tool.uv.workspace]
 members = [
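The conflict table added here tells uv that certain cross-package extra combinations may not be co-installed. A minimal sketch of what a declared pair means, modeling each conflict as a two-element set that must not be fully contained in the selected extras; the helper name is illustrative:

```python
# Each conflict is a pair of (package, extra-or-group) selections that must
# not both be active at once. Pairs copied from the conflicts table above.
CONFLICTS = [
    {("marin-levanter", "gpu"), ("marin", "cpu")},
    {("marin-levanter", "gpu"), ("marin", "tpu")},
    {("marin-levanter", "gpu"), ("marin", "vllm")},
    {("marin", "gpu"), ("marin-levanter", "tpu")},
    {("marin", "gpu"), ("marin-fray", "fray_tpu_test")},
    {("marin-levanter", "gpu"), ("marin-fray", "fray_tpu_test")},
]

def selection_ok(selected):
    """True if no declared conflict pair is a subset of the selection."""
    return not any(pair <= selected for pair in CONFLICTS)

assert selection_ok({("marin-levanter", "gpu"), ("marin", "gpu")})
assert not selection_ok({("marin-levanter", "gpu"), ("marin", "cpu")})
```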
