
Use JAX 0.10 CUDA 13 for GPU installs #5428

Merged
yonromai merged 2 commits into main from agent/20260504-fix-5427 on May 6, 2026
Conversation

@yonromai (Contributor) commented May 5, 2026

Summary

Switch Marin and Levanter GPU extras from JAX CUDA 12 to JAX 0.10 CUDA 13. CPU, TPU, and vLLM stay on JAX/JAXlib 0.9.2 while tpu-inference still pins the older stack.

Part of #5427

Why JAX 0.10

The CUDA 13 migration reproduced an H100x8 profiler crash with JAX 0.9.2. JAX 0.10 passed the same raw profiler repro, the stock Levanter profiler window, and the full H100x8 CoreWeave canary.

This GPU/TPU JAX version split is expected to be temporary: the planned tpu-inference unpin should let the TPU paths converge on the newer JAX stack later. The conflicts block in the root workspace pyproject.toml encodes that temporary split for the resolver; removing any of its entries made uv lock fail on an unsatisfiable JAX 0.10 GPU + JAX 0.9 CPU/TPU/vLLM solve. Levanter now mirrors Marin's <0.11 core JAX cap so that non-locked installs do not drift onto an untested JAX 0.11 line.
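
As a rough, hypothetical sketch of that conflicts block (the package and extra names follow this PR's description; the exact entries in the real root pyproject.toml may differ):

```toml
# Hypothetical sketch only -- entries and names are assumptions, not the real file.
[tool.uv]
conflicts = [
    # JAX 0.10 / CUDA 13 GPU extras cannot co-resolve with the JAX 0.9.2 pins.
    [
        { package = "marin-levanter", extra = "gpu" },
        { package = "marin", extra = "tpu" },
    ],
    [
        { package = "marin", extra = "gpu" },
        { package = "marin", extra = "vllm" },
    ],
]
```

Each inner list tells uv's resolver to fork rather than solve those extras together, which is what keeps the JAX 0.10 GPU world and the JAX 0.9.2 CPU/TPU/vLLM world in a single lockfile.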

Torch CUDA

Torch remains on PyTorch's CUDA 12.8 Linux GPU wheels for now. Pinned PyTorch CUDA 13 wheels do exist, but torch==2.11.0+cu130 pulls in cuda-toolkit[cublas]==13.0.2, which pins nvidia-cublas==13.1.0.3.* and conflicts with the B200 guard nvidia-cublas>=13.2.0.9.

Safety Level

Strong for the default CoreWeave H100 training/profiler path: the full H100x8 canary passed with the default profiler window.

TPU code has CI coverage: levanter-tpu-tests passed on TPU hardware with --extra tpu, JAX_PLATFORMS=tpu,cpu, PJRT_DEVICE=TPU, and the JAX 0.9.2 TPU stack.

Bounded for GH200/B200: direct JAX CUDA 13 device smokes passed on both rows, but GH200/B200 training/profiler smokes were not run.

Latest GitHub PR checks on the current head are green except cw-ci-test. That CoreWeave Iris smoke failure repeats the earlier Iris controller port-forward timeout in test_cancel_job_releases_resources; it is not a JAX CUDA runtime failure.

Exact validation commands, run IDs, and residual risk

Validated commits:

  • Current PR head: 7b53485cd.
  • CUDA 13 runtime/live canary commit before final cleanup commits: d01f0ec8b.
  • Post-canary runtime code delta: a one-file Pallas cleanup that makes the JAX 0.10 manual-axis helper private and NamedSharding-specific. Later commits changed comments/docs, Torch source-map rationale, root conflict wording, and Levanter's core JAX upper bound metadata.

Local and CI validation:

  • uv lock: passed after the Levanter upper-bound change; uv.lock changed only the Levanter jax specifier metadata.
  • uv lock --check: passed.
  • uv run --package marin-levanter pytest lib/levanter/tests/kernels/test_pallas_autotune_utils.py -q: passed, 5 tests, on the same runtime code as current head.
  • ./infra/pre-commit.py --fix lib/levanter/src/levanter/kernels/pallas/autotune_utils.py: passed, on the same runtime code as current head.
  • ./infra/pre-commit.py --all-files --fix: passed after the Levanter upper-bound commit.
  • Commit hook during Align Levanter JAX upper bound: passed.
  • GitHub PR check rollup for 7b53485cd: 35 successful check runs plus ReadTheDocs success; 4 skipped; cw-ci-test failed. Passing checks include levanter-tpu-tests, cloud-smoke-test, marin-integration, marin-lint, marin-unit, levanter-unit, and docs.
  • Latest CoreWeave Iris smoke run 25406183008 for 7b53485cd: cw-ci-test failed in tests/integration/iris/test_iris_integration.py::test_cancel_job_releases_resources after the local Iris controller port-forward began returning Connection refused while polling a follow-up job state.
  • Previous CoreWeave Iris smoke run 25402431989 for 96e3bb07a: failed the same test with the same controller port-forward symptom. Run 25403714756 for d589513b5 was cancelled before jobs were created.

Dependency/export validation after the Levanter upper-bound commit:

  • Marin gpu export resolved jax==0.10.0, jaxlib==0.10.0, jax-cuda13-*==0.10.0, nvidia-cublas==13.4.1.1, nvidia-cuda-runtime==13.0.96, nvidia-nccl-cu13==2.28.9, torch==2.11.0+cu128, and torchvision==0.26.0+cu128 on Linux.
  • Marin tpu export resolved jax==0.9.2, jaxlib==0.9.2, libtpu==0.0.38, and CPU Torch wheels.
  • Marin cpu and vllm exports stayed on jax==0.9.2 / jaxlib==0.9.2 and did not pull the JAX CUDA/NVIDIA runtime.
  • Levanter gpu export resolved jax==0.10.0, jaxlib==0.10.0, CUDA 13 JAX packages, nvidia-cublas==13.4.1.1, and nvidia-nccl-cu13==2.28.9.
  • Levanter tpu export resolved jax==0.9.2, jaxlib==0.9.2, and libtpu==0.0.38.
  • PyTorch CUDA 13 check: printf '%s\n' 'torch==2.11.0+cu130' 'nvidia-cublas>=13.2.0.9' | uv pip compile - --python-platform x86_64-manylinux_2_28 --index-url https://download.pytorch.org/whl/cu130 --extra-index-url https://pypi.org/simple --index-strategy unsafe-best-match failed because torch==2.11.0+cu130 depends on cuda-toolkit[cublas]==13.0.2, which pins nvidia-cublas==13.1.0.3.

Bounded live JAX CUDA 13 device smokes:

  • H100 command: uv run iris --cluster=coreweave-ci job run --job-name c13-jax010-smoke-h100-d01f0ec8 --enable-extra-resources --cpu=4 --memory=16G --disk=32G --gpu=H100x1 --extra=gpu -- python -c <nvidia-smi+jax+matmul probe> (a hypothetical reconstruction of the probe follows this list).
  • H100 result: job /romain/c13-jax010-smoke-h100-d01f0ec8 passed. Driver 595.45.04, CUDA 13.2, JAX/JAXlib/JAX CUDA13 packages 0.10.0, GPU backend, 8x8 matmul sum 512.0, no CUDA/cuBLAS/NCCL warnings in captured logs.
  • GH200 command: uv run iris --cluster=coreweave-rno2a job run --job-name c13-jax010-smoke-gh200-d01f0ec8 --enable-extra-resources --cpu=4 --memory=16G --disk=32G --gpu=H200x1 --extra=gpu -- python -c <nvidia-smi+jax+matmul probe>.
  • GH200 result: job /romain/c13-jax010-smoke-gh200-d01f0ec8 passed. Driver 595.45.04, CUDA 13.2, JAX/JAXlib/JAX CUDA13 packages 0.10.0, GPU backend, 8x8 matmul sum 512.0, no CUDA/cuBLAS/NCCL warnings in captured logs.
  • B200 command: uv run iris --cluster=coreweave-usw09b job run --job-name c13-jax010-smoke-b200-d01f0ec8 --enable-extra-resources --cpu=4 --memory=16G --disk=32G --gpu=B200x1 --extra=gpu -- python -c <nvidia-smi+jax+matmul probe>.
  • B200 result: job /romain/c13-jax010-smoke-b200-d01f0ec8 passed. Driver 595.45.04, CUDA 13.2, JAX/JAXlib/JAX CUDA13 packages 0.10.0, GPU backend, 8x8 matmul sum 512.0, no CUDA/cuBLAS/NCCL warnings in captured logs.
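
The probe payload is elided in the PR body; the following is a hypothetical reconstruction consistent with the reported outputs (driver/CUDA versions from nvidia-smi, GPU backend, and an 8x8 all-ones matmul whose entries are all 8.0, so the sum is 8 × 64 = 512.0):

```python
# Hypothetical reconstruction of the elided <nvidia-smi+jax+matmul probe>;
# the actual probe script is not shown in the PR body.
import subprocess

import jax
import jax.numpy as jnp

# Driver and CUDA runtime versions, as quoted in the smoke results.
print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)

print("jax", jax.__version__, "backend", jax.default_backend())
assert jax.default_backend() == "gpu"

# 8x8 all-ones matmul: every entry is 8.0, so the sum is 8 * 64 = 512.0.
x = jnp.ones((8, 8))
print("matmul sum", jnp.sum(x @ x))  # expected: 512.0
```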

Profiler repro and H100x8 canary validation:

  • JAX 0.9.2 CUDA 13 repro: /romain/c13p-raw-s10n1-g; raw jax.profiler.start_trace(..., create_perfetto_trace=False) around training step 10 failed with CUDA_ERROR_LAUNCH_FAILED after PROFILE_STOP_AFTER (the shape of this raw window is sketched after this list).
  • JAX 0.10.0 raw profiler check: /romain/c13p-jax010-raw-s10n1-a; passed 35/35 steps with raw profiler around training step 10 and printed GRUG_SYNTHETIC_OK.
  • JAX 0.10.0 stock Levanter profiler check: /romain/c13p-jax010-stock-s5n25-a; passed 35/35 steps with start_step=5, num_steps=25, wrote perfetto_trace.json.gz, and printed GRUG_SYNTHETIC_OK.
  • Full CoreWeave GPU canary command: gh workflow run marin-canary-ferry-cw.yaml --repo marin-community/marin --ref agent/20260504-fix-5427 -f multi_host=false.
  • GitHub Actions run: 25394978124, passed on d01f0ec8b.
  • Iris parent: /runner/iris-run-job-20260505-184155, succeeded.
  • Iris child: /runner/iris-run-job-20260505-184155/grug-train-canary-gpu-25394978124-1, succeeded.
  • W&B run: https://wandb.ai/marin-community/marin/runs/canary-gpu-25394978124-1.
  • Canary metrics: steps completed 99 >= 40, final loss 6.5252 <= 8.0.
  • Canary profiler path: passed. Levanter started/stopped the default profiler window; JAX loaded trace.json.gz; Perfetto conversion wrote perfetto_trace.json.gz; W&B profiler artifact upload passed. Profile artifact: marin-community/marin/jax-profile-step-5-30:v13.
  • Profiler caveat: profile summary warned that the trace likely hit the 1,000,000 complete-event export cap, so detailed profile performance statistics should be treated cautiously.
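
For orientation, here is a minimal sketch of the raw profiler window used in the repro items above, assuming a plain step loop; only the jax.profiler calls and the step-10 window mirror the PR body, the rest is illustrative:

```python
# Minimal sketch of the raw profiler window around training step 10.
# The training step below is a placeholder, not the real harness.
import jax
import jax.numpy as jnp


def run_train_step():
    # Stand-in for the actual training step.
    x = jnp.ones((8, 8))
    jax.block_until_ready(x @ x)


for step in range(35):
    if step == 10:
        # Raw trace, no Perfetto conversion, per the JAX 0.9.2 repro setup.
        jax.profiler.start_trace("/tmp/jax-trace", create_perfetto_trace=False)
    run_train_step()
    if step == 10:
        jax.profiler.stop_trace()
```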

Residual risk:

  • GH200 and B200 training/profiler smokes were not run.
  • The full live canary was not rerun after the Pallas cleanup/squash; later changes do not alter resolved packages or runtime code.
  • The CoreWeave Iris smoke cw-ci-test is red on the current head due to the repeated Iris controller port-forward failure described above.

@yonromai yonromai added the agent-generated (Created by automation/agent) label on May 5, 2026
@yonromai yonromai force-pushed the agent/20260504-fix-5427 branch from f78ab18 to c938b3d on May 5, 2026 at 01:23
@yonromai yonromai changed the title from "Move GPU JAX runtime to CUDA 13" to "Use JAX 0.10 CUDA 13 for GPU installs" on May 5, 2026
@yonromai yonromai force-pushed the agent/20260504-fix-5427 branch 2 times, most recently from ff8a967 to 96e3bb0 on May 5, 2026 at 21:08
Switch Marin and Levanter GPU extras from JAX CUDA 12 to JAX 0.10 CUDA 13. JAX 0.9.2 reproduced an H100x8 CUDA 13 profiler crash; JAX 0.10 passed the repros and H100x8 canary. CPU, TPU, and vLLM stay on JAX 0.9.2 until tpu-inference can unpin JAX.

Part of #5427
@yonromai yonromai force-pushed the agent/20260504-fix-5427 branch from 96e3bb0 to d589513 on May 5, 2026 at 21:37
@yonromai yonromai marked this pull request as ready for review May 5, 2026 21:52
@yonromai yonromai requested review from dlwh and rjpower May 5, 2026 21:52
@yonromai (Contributor, Author) commented May 5, 2026

@claude Code review please.

@claude (bot, Contributor) commented May 5, 2026

Claude finished @yonromai's task in 3m 16s.


Code review

🤖 Reviewed d589513b5 against origin/main. The dependency split + Pallas helper change look reasonable and the validation evidence in the PR body is unusually thorough. Suggestions below are mostly polish, plus one consistency item worth addressing before merge.

What I think is well done

  • Migration scope: GPU extras flip to jax[cuda13]==0.10.0 while CPU/TPU/vLLM stay on jax==0.9.2. The ordered, package-scoped conflicts in the root pyproject.toml cleanly express this split: marin-levanter[gpu] vs each non-GPU consumer, and marin[gpu] vs the TPU-pinned levanter/fray paths. I cross-checked the workspace and the only jax==0.9.2/libtpu pins live in lib/levanter (tpu), lib/marin (tpu/cpu/vllm), and lib/fray (fray_tpu_test) — every one of those is covered.
  • nvidia-cublas>=13.2.0.9 and nvidia-nccl-cu13>=2.28.3 guards: the comments explain why each guard exists (the B200 cuBLAS warning, the CW H100 NCCL all-to-all), and the cuBLAS guard is what forces Torch to stay on cu128. Calling that out in the body and in the source-map comments is the right level of breadcrumb.
  • Pallas helper: the new _named_sharding_uses_manual_axes is single-leading-underscore private and NamedSharding-only as advertised. Checking mesh.axis_types for AxisType.Manual is a clean replacement for the old HLO-only path that no longer fires for shard_map tracers under JAX 0.10. Existing tests still cover the externally observable behavior (test_shape_dtype_struct_for_benchmark_drops_manual_sharding_from_shard_map_tracer). A plausible shape of the helper is sketched after this list.
  • Docs: removing the manual CUDA toolkit / cuDNN install from local-gpu.md matches the new wheel-bundled CUDA 13 stack; the driver ≥580 requirement is consistent with CUDA 13.x runtime needs.
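
For reference, a sketch of what such a helper plausibly looks like, assuming the public jax.sharding.AxisType enum (the actual implementation in autotune_utils.py may differ):

```python
# Hypothetical sketch of the private helper described above; the real code
# in lib/levanter/src/levanter/kernels/pallas/autotune_utils.py may differ.
from jax.sharding import AxisType, NamedSharding


def _named_sharding_uses_manual_axes(sharding: NamedSharding) -> bool:
    # Any Manual mesh axis means the global sharding cannot be trusted for
    # benchmark lowering (the shard_map tracer case), so callers drop it.
    return any(t == AxisType.Manual for t in sharding.mesh.axis_types)
```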

Suggestions (non-blocking unless flagged)

  1. Inconsistent core JAX upper bound (worth fixing). lib/marin/pyproject.toml#L29 widens to jax>=0.9.2,<0.11, but lib/levanter/pyproject.toml#L39 is still jax>=0.9.2 (no upper bound). Asymmetric upper bounds invite a future fork in non-locked installs where levanter pulls JAX 0.11 while marin caps at 0.10. Recommend mirroring <0.11 on levanter; a sketch of the suggested line follows this list.

  2. Duplicate jax/jaxlib/libtpu 0.9.2 pin block. The exact line ["jax==0.9.2", "jaxlib==0.9.2", "libtpu==0.0.38", ...] (or close to it) now lives in three workspace pyprojects (lib/marin, lib/levanter, lib/fray). While this is the temporary state, consider whether centralising via root override-dependencies or a comment pointer keeps drift from sneaking in next time the line moves. Not urgent.

  3. Comment on the widened core marin pin. The jax>=0.9.2,<0.11 widening (vs the old ==0.9.2) makes the resolved JAX version depend on which extra is selected. That's the design here, but it's the kind of thing that surprises future readers. A one-line # Floor at 0.9.2 (cpu/tpu/vllm); gpu extra pulls 0.10.0; cap at <0.11. near the dep would save future-you a git blame round-trip.

  4. vllm extra now explicitly pins jaxlib==0.9.2. Good defensive pin given the looser core. Confirm this was intentional and exported correctly to the vllm lockfile fork (the body says it was, but worth keeping in your validation checklist on rebases).

  5. Removed pjit.pjit_check_aval_sharding assertion in test_pallas_autotune_utils.py. Removing it is justified because the JAX 0.10 internal API may not raise the same way. The remaining assertions in that test still cover the observable property (manual aval ⇒ struct without sharding). Optional: replace the removed check with a lightweight public-API canary (e.g. assert local_x.aval.sharding is not None) to keep the same tracer-shape regression signal alive without depending on jax._src.pjit.

  6. _named_sharding_uses_manual_axes only checks AxisType.Manual. That's correct — Auto/Explicit axes don't constitute "manual". A partially-manual mesh (some axes Manual, some Explicit) returns True, which is the right call (any manual axis means we cannot trust the global sharding for benchmark lowering).

  7. Docs nit (coreweave.md): the new paragraph reads well; if you want, add a one-liner telling operators to run nvidia-smi first when triaging install failures on older nodes (CUDA 13 driver gating is the most common surprise). Optional.

  8. Lockfile churn (~3k lines): expected for the dual-JAX fork, but please keep uv lock --check on the rebase discipline list — silent forks are easy to introduce when adding/removing extras until the temporary conflicts block goes away.
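
Illustrating suggestions 1 and 3, a hedged sketch of the suggested Levanter dependency line (the surrounding pyproject context is assumed, not copied from the repo):

```toml
# Hypothetical sketch for lib/levanter/pyproject.toml; context is assumed.
[project]
dependencies = [
    # Floor at 0.9.2 (cpu/tpu/vllm); gpu extra pulls 0.10.0; cap at <0.11.
    "jax>=0.9.2,<0.11",
]
```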

Risk recap (matches your body, no new findings)

  • H100x8 canary + 3 device smokes (H100/GH200/B200) all green; no GH200/B200 training/profiler smoke yet — accepted residual risk.
  • The most recent CW Iris smoke run was cancelled before creating jobs, so the prior cw-ci-test failure (an Iris controller port-forward, not a JAX runtime issue) is not formally superseded. Worth a follow-up rerun, but not a JAX-CUDA-13 blocker.

Bottom line

Looks good. The only ask I'd push back on is the asymmetric jax>=0.9.2 core pin in levanter (#1). Everything else is optional polish.

@yonromai yonromai merged commit e984831 into main May 6, 2026
40 of 43 checks passed
@yonromai yonromai deleted the agent/20260504-fix-5427 branch May 6, 2026 01:15