Use JAX 0.10 CUDA 13 for GPU installs #5428
Switch Marin and Levanter GPU extras from JAX CUDA 12 to JAX 0.10 CUDA 13. JAX 0.9.2 reproduced an H100x8 CUDA 13 profiler crash; JAX 0.10 passed the repros and H100x8 canary. CPU, TPU, and vLLM stay on JAX 0.9.2 until tpu-inference can unpin JAX. Part of #5427
Summary
Switch Marin and Levanter GPU extras from JAX CUDA 12 to JAX 0.10 CUDA 13. CPU, TPU, and vLLM stay on JAX/JAXlib 0.9.2 while tpu-inference still pins the older stack.
Part of #5427
Why JAX 0.10
The CUDA 13 migration reproduced an H100x8 profiler crash with JAX 0.9.2. JAX 0.10 passed the same raw profiler repro, the stock Levanter profiler window, and the full H100x8 CoreWeave canary.
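For reference, a minimal sketch of what a "raw profiler window" around one training step looks like. This is not the actual repro script; the `train_step` function, matrix size, and `/tmp/jax-trace` path are illustrative stand-ins, and only the `jax.profiler.start_trace(..., create_perfetto_trace=False)` / `stop_trace()` pattern mirrors the repro described in the validation section below.

```python
# Illustrative raw profiler window repro (not the exact script used here):
# trace one jitted training step with Perfetto conversion disabled.
import jax
import jax.numpy as jnp


@jax.jit
def train_step(x):
    # Stand-in for a real training step; any nontrivial jitted compute works.
    return jnp.tanh(x @ x.T).sum()


x = jnp.ones((1024, 1024))
for step in range(35):
    if step == 10:
        jax.profiler.start_trace("/tmp/jax-trace", create_perfetto_trace=False)
    loss = train_step(x)
    loss.block_until_ready()
    if step == 10:
        jax.profiler.stop_trace()
print("done, final loss:", float(loss))
```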
This GPU/TPU JAX version split is expected to be temporary. The planned tpu-inference unpin should let TPU paths converge on the newer JAX stack later. The root workspace `conflicts` block is the resolver expression of that temporary split; removing each entry made `uv lock` fail on an incompatible JAX 0.10 GPU + JAX 0.9 CPU/TPU/vLLM solve. Levanter now mirrors Marin's `<0.11` core JAX cap so non-locked installs do not drift to an untested JAX 0.11 line.
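A quick way to spot drift past that cap in a non-locked GPU install is a version check like the sketch below. The `<0.11` bound and the JAX 0.10 target come from this PR; the check itself is illustrative and not part of the change.

```python
# Illustrative drift check (not part of this PR): confirm a non-locked GPU
# install stayed on the intended JAX 0.10.x line, below the <0.11 cap.
import jax

parts = jax.__version__.split(".")
major, minor = int(parts[0]), int(parts[1])
assert (major, minor) == (0, 10), f"expected a JAX 0.10.x GPU install, got {jax.__version__}"
print("jax", jax.__version__, "default backend:", jax.default_backend())
```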
Torch CUDA
Torch remains on PyTorch's CUDA 12.8 Linux GPU wheels for now. The exact pinned PyTorch CUDA 13 wheels exist, but `torch==2.11.0+cu130` pulls `cuda-toolkit[cublas]==13.0.2`, which pins `nvidia-cublas==13.1.0.3.*` and conflicts with the B200 guard `nvidia-cublas>=13.2.0.9`.
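A corresponding sanity check on an installed GPU environment is sketched below. `torch.version.cuda` is the standard way to read the CUDA build Torch was compiled against; the expected `12.8` value comes from this PR, and the assertion is illustrative rather than part of the change.

```python
# Illustrative check (not part of this PR): confirm the GPU extra kept Torch on
# the CUDA 12.8 wheel line rather than drifting to a cu130 build.
import torch

print("torch", torch.__version__, "built for CUDA", torch.version.cuda)
assert torch.version.cuda and torch.version.cuda.startswith("12.8"), (
    f"expected a +cu128 Torch build, got CUDA {torch.version.cuda}"
)
```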
Safety Level
- Strong for the default CoreWeave H100 training/profiler path: the full H100x8 canary passed with the default profiler window.
- TPU code has CI coverage: `levanter-tpu-tests` passed on TPU hardware with `--extra tpu`, `JAX_PLATFORMS=tpu,cpu`, `PJRT_DEVICE=TPU`, and the JAX 0.9.2 TPU stack.
- Bounded for GH200/B200: direct JAX CUDA 13 device smokes passed on both rows, but GH200/B200 training/profiler smokes were not run.
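The environment variables named in the TPU bullet above select the backend before JAX initializes. A minimal sketch of verifying that selection is below; the env var names and the JAX 0.9.2 TPU stack are from the CI job description, while the check itself is illustrative.

```python
# Illustrative backend check mirroring the levanter-tpu-tests environment.
# The env vars must be set before importing jax; the assertion is not part of
# this PR, only a sketch of how the backend selection can be verified.
import os

os.environ.setdefault("JAX_PLATFORMS", "tpu,cpu")
os.environ.setdefault("PJRT_DEVICE", "TPU")

import jax

print("jax", jax.__version__, "backend", jax.default_backend(), "devices", jax.device_count())
assert jax.default_backend() == "tpu", "expected the TPU backend on TPU hardware"
```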
Latest GitHub PR checks on the current head are green except `cw-ci-test`. That CoreWeave Iris smoke failure repeats the earlier Iris controller port-forward timeout in `test_cancel_job_releases_resources`; it is not a JAX CUDA runtime failure.
Exact validation commands, run IDs, and residual risk
Validated commits:
- `7b53485cd`.
- `d01f0ec8b`.
- NamedSharding-specific. Later commits changed comments/docs, Torch source-map rationale, root conflict wording, and Levanter's core JAX upper bound metadata.
Local and CI validation:
- `uv lock`: passed after the Levanter upper-bound change; `uv.lock` changed only the Levanter `jax` specifier metadata.
- `uv lock --check`: passed.
- `uv run --package marin-levanter pytest lib/levanter/tests/kernels/test_pallas_autotune_utils.py -q`: passed, 5 tests, on the same runtime code as current head.
- `./infra/pre-commit.py --fix lib/levanter/src/levanter/kernels/pallas/autotune_utils.py`: passed, on the same runtime code as current head.
- `./infra/pre-commit.py --all-files --fix`: passed after the Levanter upper-bound commit.
- Align Levanter JAX upper bound: passed.
- `7b53485cd`: 35 successful check runs plus ReadTheDocs success; 4 skipped; `cw-ci-test` failed. Passing checks include `levanter-tpu-tests`, `cloud-smoke-test`, `marin-integration`, `marin-lint`, `marin-unit`, `levanter-unit`, and docs.
- `25406183008` for `7b53485cd`: `cw-ci-test` failed in `tests/integration/iris/test_iris_integration.py::test_cancel_job_releases_resources` after the local Iris controller port-forward began returning `Connection refused` while polling a follow-up job state.
- `25402431989` for `96e3bb07a`: failed the same test with the same controller port-forward symptom. Run `25403714756` for `d589513b5` was cancelled before jobs were created.
Dependency/export validation after the Levanter upper-bound commit:
- `gpu` export resolved `jax==0.10.0`, `jaxlib==0.10.0`, `jax-cuda13-*==0.10.0`, `nvidia-cublas==13.4.1.1`, `nvidia-cuda-runtime==13.0.96`, `nvidia-nccl-cu13==2.28.9`, `torch==2.11.0+cu128`, and `torchvision==0.26.0+cu128` on Linux.
- `tpu` export resolved `jax==0.9.2`, `jaxlib==0.9.2`, `libtpu==0.0.38`, and CPU Torch wheels.
- `cpu` and `vllm` exports stayed on `jax==0.9.2`/`jaxlib==0.9.2` and did not pull the JAX CUDA/NVIDIA runtime.
- `gpu` export resolved `jax==0.10.0`, `jaxlib==0.10.0`, CUDA 13 JAX packages, `nvidia-cublas==13.4.1.1`, and `nvidia-nccl-cu13==2.28.9`.
- `tpu` export resolved `jax==0.9.2`, `jaxlib==0.9.2`, and `libtpu==0.0.38`.
- `printf '%s\n' 'torch==2.11.0+cu130' 'nvidia-cublas>=13.2.0.9' | uv pip compile - --python-platform x86_64-manylinux_2_28 --index-url https://download.pytorch.org/whl/cu130 --extra-index-url https://pypi.org/simple --index-strategy unsafe-best-match` failed because `torch==2.11.0+cu130` depends on `cuda-toolkit[cublas]==13.0.2`, which pins `nvidia-cublas==13.1.0.3`.
Bounded live JAX CUDA 13 device smokes:
- `uv run iris --cluster=coreweave-ci job run --job-name c13-jax010-smoke-h100-d01f0ec8 --enable-extra-resources --cpu=4 --memory=16G --disk=32G --gpu=H100x1 --extra=gpu -- python -c <nvidia-smi+jax+matmul probe>` (a sketch of this probe follows the list). `/romain/c13-jax010-smoke-h100-d01f0ec8` passed. Driver `595.45.04`, CUDA `13.2`, JAX/JAXlib/JAX CUDA 13 packages `0.10.0`, GPU backend, 8x8 matmul sum `512.0`, no CUDA/cuBLAS/NCCL warnings in captured logs.
- `uv run iris --cluster=coreweave-rno2a job run --job-name c13-jax010-smoke-gh200-d01f0ec8 --enable-extra-resources --cpu=4 --memory=16G --disk=32G --gpu=H200x1 --extra=gpu -- python -c <nvidia-smi+jax+matmul probe>`. `/romain/c13-jax010-smoke-gh200-d01f0ec8` passed. Driver `595.45.04`, CUDA `13.2`, JAX/JAXlib/JAX CUDA 13 packages `0.10.0`, GPU backend, 8x8 matmul sum `512.0`, no CUDA/cuBLAS/NCCL warnings in captured logs.
- `uv run iris --cluster=coreweave-usw09b job run --job-name c13-jax010-smoke-b200-d01f0ec8 --enable-extra-resources --cpu=4 --memory=16G --disk=32G --gpu=B200x1 --extra=gpu -- python -c <nvidia-smi+jax+matmul probe>`. `/romain/c13-jax010-smoke-b200-d01f0ec8` passed. Driver `595.45.04`, CUDA `13.2`, JAX/JAXlib/JAX CUDA 13 packages `0.10.0`, GPU backend, 8x8 matmul sum `512.0`, no CUDA/cuBLAS/NCCL warnings in captured logs.
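The `<nvidia-smi+jax+matmul probe>` placeholder is not reproduced verbatim in this description. The sketch below shows a plausible shape for it, producing the outputs reported above: driver and CUDA versions from `nvidia-smi`, JAX/JAXlib versions, the selected backend, and an 8x8 all-ones matmul whose sum is 8 * 64 = 512.0. The exact script may differ.

```python
# Plausible sketch of the inline GPU smoke probe (not the exact script used):
# print driver/CUDA info, JAX package versions, the backend, and a matmul checksum.
import subprocess
from importlib.metadata import version

import jax
import jax.numpy as jnp

subprocess.run(["nvidia-smi"], check=True)  # driver and CUDA runtime versions
print("jax", version("jax"), "jaxlib", version("jaxlib"))
print("backend", jax.default_backend(), "devices", jax.devices())

x = jnp.ones((8, 8))
print("matmul sum", float((x @ x).sum()))  # expected 512.0
```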
Profiler repro and H100x8 canary validation:
- `/romain/c13p-raw-s10n1-g`; raw `jax.profiler.start_trace(..., create_perfetto_trace=False)` around training step 10 failed with `CUDA_ERROR_LAUNCH_FAILED` after `PROFILE_STOP_AFTER`.
- `/romain/c13p-jax010-raw-s10n1-a`; passed 35/35 steps with the raw profiler around training step 10 and printed `GRUG_SYNTHETIC_OK`.
- `/romain/c13p-jax010-stock-s5n25-a`; passed 35/35 steps with `start_step=5`, `num_steps=25`, wrote `perfetto_trace.json.gz`, and printed `GRUG_SYNTHETIC_OK`.
- `gh workflow run marin-canary-ferry-cw.yaml --repo marin-community/marin --ref agent/20260504-fix-5427 -f multi_host=false`. `25394978124`, passed on `d01f0ec8b`. `/runner/iris-run-job-20260505-184155` succeeded. `/runner/iris-run-job-20260505-184155/grug-train-canary-gpu-25394978124-1` succeeded. https://wandb.ai/marin-community/marin/runs/canary-gpu-25394978124-1
- 99 >= 40, final loss 6.5252 <= 8.0.
- `trace.json.gz`; Perfetto conversion wrote `perfetto_trace.json.gz`; W&B profiler artifact upload passed. Profile artifact: `marin-community/marin/jax-profile-step-5-30:v13`.
Residual risk:
`cw-ci-test` is red on the current head due to the repeated Iris controller port-forward failure described above.