
Commit b85633e

Make fused CE default to XLA custom VJP on TPU (#2951)
## Summary

- make fused CE default implementation always prefer `xla`, even when `pallas_tpu` is importable
- switch `linear_softmax_cross_entropy_loss_xla` to use the streaming custom VJP path unconditionally
- remove v4-gate-only test and keep custom-VJP gradient parity coverage
- document v4/v5p benchmarking results and backend policy update in `.agents/projects/linear_ce_loss.md`

## Validation

- `uv run --package levanter --group test pytest lib/levanter/tests/kernels/test_pallas_fused_cross_entropy_loss.py`
  - result: `14 passed, 3 skipped`
- `./infra/pre-commit.py --all-files`
  - all checks pass except the known pyrefly hidden-worktree exclude issue (no actionable type errors emitted)

## Notes

- `pallas_tpu` remains available via explicit `implementation='pallas_tpu'` selection.
1 parent 0c4db8f commit b85633e

File tree

4 files changed: +480 −10 lines changed


.agents/projects/linear_ce_loss.md

Lines changed: 253 additions & 0 deletions
@@ -431,3 +431,256 @@ Notes:
- Relative: pallas is `~1.17x` faster on fwd and `~1.13x` faster on bwd.
- Note:
  - Loss values matched closely between backends in these runs.

### 2026-02-21: v4-8 VMEM pressure triage + XLA custom VJP direction
- Environment:
  - `scripts/ray/dev_tpu.py --config infra/marin-us-central2.yaml --tpu-type v4-8`
  - TPU: `TPU v4` (4 local devices)
- Core finding on VMEM failures:
  - At large failing configs (e.g. `B=65536,H=512,V=128256`, `b/h/v=1024/128/1024`), `fwd_only` already fails with a scoped VMEM OOM (`~41.11M / 16M` bytes), so the forward pass is the first cliff.
  - Backward-only can fail at even higher VMEM (`~47.05M / 16M`), but this is secondary at the tested boundaries.
- Boundary probing:
  - `B=512,H=512,v_block=768`: forward fails while direct backward can pass.
  - `B=512,H=512,v_block=640`: both pass.
- Pallas speed check versus XLA on v4:
  - For `B=512,H=512,V=128256` forward-only, the best Pallas config observed was around `~110k tok/s`.
  - XLA forward-only on the same shape was around `~572k tok/s` (about `5x` faster).
  - For value+grad on the same shape, Pallas reached `~49k tok/s` vs XLA streaming's `~122k tok/s`.
- Conclusion:
  - On v4, the Pallas path is constrained by VMEM and is not competitive at these tested settings.
  - We should favor an XLA streaming path with a custom backward on v4.

#### XLA streaming custom-VJP prototype result (v4-8)
- Prototype behavior:
  - forward uses the existing streaming CE (`linear_softmax_cross_entropy_loss_streaming`)
  - backward manually streams over vocab blocks to avoid full autodiff materialization.
- Measured on `B=512,H=512,V=128256`:
  - builtin `xla` (`v_block=32768`) value+grad: `~121.6k tok/s` (`~0.00421s`)
  - custom streaming VJP (`v_block=32768`): `~212.5k tok/s` (`~0.00241s`)
  - custom streaming VJP (`v_block=8192`): `~161k tok/s`
- Prototype correctness spot-check (same env):
  - `loss_builtin == loss_custom` exactly in the sampled run.
  - gradient deltas:
    - `gx_max_abs = 4.8828125e-04`
    - `gw_max_abs = 5.9604645e-08`
    - `gx_rel = 2.2317406e-03`
    - `gw_rel = 9.1245504e-08`

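The prototype's two-piece shape (streaming forward, manually streamed backward) can be sketched with `jax.custom_vjp`. This is a simplified illustration under stated assumptions (hypothetical names, no soft-cap, no masking, `V` divisible by `v_block`), not the repo implementation:

```python
import numpy as np
import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.custom_vjp, nondiff_argnums=(3,))
def linear_ce_streaming(x, w, labels, v_block):
    loss, _ = _streaming_fwd(x, w, labels, v_block)
    return loss

def _streaming_fwd(x, w, labels, v_block):
    B, V = x.shape[0], w.shape[1]
    lse = jnp.full((B,), -jnp.inf)      # running logsumexp over vocab blocks
    label_logit = jnp.zeros((B,))
    for start in range(0, V, v_block):
        logits = x @ w[:, start:start + v_block]                  # (B, v_block)
        lse = jnp.logaddexp(lse, jax.nn.logsumexp(logits, axis=-1))
        in_block = (labels >= start) & (labels < start + v_block)
        idx = jnp.clip(labels - start, 0, v_block - 1)[:, None]
        picked = jnp.take_along_axis(logits, idx, axis=1)[:, 0]
        label_logit = label_logit + jnp.where(in_block, picked, 0.0)
    loss = jnp.mean(lse - label_logit)
    return loss, (x, w, labels, lse)

def _streaming_bwd(v_block, res, g):
    x, w, labels, lse = res
    B, V = x.shape[0], w.shape[1]
    d = g / B                           # cotangent per row of the mean loss
    dx = jnp.zeros_like(x)
    dw_blocks = []
    for start in range(0, V, v_block):
        w_blk = w[:, start:start + v_block]
        prob = jnp.exp(x @ w_blk - lse[:, None])                  # block softmax
        one_hot = jax.nn.one_hot(labels - start, v_block)         # all-zero outside block
        delta = d * (prob - one_hot)                              # d loss / d logits
        dx = dx + delta @ w_blk.T
        dw_blocks.append(x.T @ delta)
    dlabels = np.zeros(labels.shape, dtype=jax.dtypes.float0)     # int arg: float0 cotangent
    return dx, jnp.concatenate(dw_blocks, axis=1), dlabels

linear_ce_streaming.defvjp(_streaming_fwd, _streaming_bwd)
```

Recomputing the block logits in the backward (rather than saving them) is what keeps residual memory at `O(B)` instead of `O(B*V)`, which is the point of the manual streaming path.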
#### Repo changes (this branch)
- `lib/levanter/src/levanter/kernels/pallas/fused_cross_entropy_loss/xla.py`
  - Added `_use_v4_custom_xla_vjp()` gate: enable the custom VJP only on TPU v4.
  - Added `_linear_softmax_cross_entropy_loss_streaming_custom_vjp(...)` with a manual streaming backward:
    - computes blockwise `delta = (dL + dLSE)*prob - dL*one_hot`
    - applies the soft-cap derivative when enabled
    - accumulates `dx` and writes `dw` blockwise.
  - `linear_softmax_cross_entropy_loss_xla(...)` now dispatches to the custom VJP on v4; other backends keep existing behavior.
- `lib/levanter/tests/kernels/test_pallas_fused_cross_entropy_loss.py`
  - Added `test_v4_custom_xla_vjp_gate`
  - Added `test_xla_streaming_custom_vjp_grad_matches_streaming_autodiff`
- Local test run (`-k 'xla or custom_vjp or gate'`): `3 passed, 1 skipped`.

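The `delta` formula follows from the chain rule when both the per-row loss and its `lse` are outputs with cotangents `dL` and `dLSE`: since `d lse/d logits = prob` and `d loss/d logits = prob - one_hot`, the combined logits cotangent is `(dL + dLSE)*prob - dL*one_hot`. A tiny dense check against `jax.vjp` (illustrative shapes, not repo code):

```python
import numpy as np
import jax
import jax.numpy as jnp

def per_row(logits, labels):
    # per-row CE loss plus the row logsumexp, mirroring a (loss, lse) output pair
    lse = jax.nn.logsumexp(logits, axis=-1)
    loss = lse - jnp.take_along_axis(logits, labels[:, None], axis=1)[:, 0]
    return loss, lse

logits = jax.random.normal(jax.random.PRNGKey(0), (3, 5))
labels = jnp.array([1, 4, 2])
dL = jnp.array([0.3, -1.0, 2.0])      # arbitrary output cotangents
dLSE = jnp.array([0.5, 0.0, -0.7])

_, vjp_fn = jax.vjp(lambda l: per_row(l, labels), logits)
(auto_grad,) = vjp_fn((dL, dLSE))

prob = jax.nn.softmax(logits, axis=-1)
one_hot = jax.nn.one_hot(labels, 5)
manual = (dL + dLSE)[:, None] * prob - dL[:, None] * one_hot
assert np.allclose(auto_grad, manual, atol=1e-6)
```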
#### In-repo TPU validation after patch (v4-8)
- API bench (`linear_softmax_cross_entropy_loss_xla` through the fused loss API):
  - shape `B=512,H=512,V=128256` (`batch=1,pos=512` in the bench script)
  - fwd `steady_time_s=0.00090334` (`~566.8k tok/s`)
  - bwd `bwd_steady_time_s=0.00243601` (`~210.2k tok/s`)
- Direct head-to-head on the same TPU run (`v_block=32768`, value+grad):
  - API path (now v4 custom VJP): `0.002495s` (`~205.2k tok/s`)
  - baseline streaming autodiff (`linear_softmax_cross_entropy_loss_streaming`): `0.004175s` (`~122.6k tok/s`)
  - speedup: `~1.67x` for backward-inclusive step time.
- Correctness spot-check on TPU after integration:
  - `_use_v4_custom_xla_vjp()` returned `True` on `TPU v4`.
  - max abs diff versus the streaming-autodiff gradient at sample shape (`B=128,H=128,V=4096`):
    - `gx_max_abs = 1.220703125e-04`
    - `gw_max_abs = 1.220703125e-04`

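For reference, the quoted throughputs are just tokens per step (`B=512` at `batch=1,pos=512`) divided by steady step time:

```python
B = 512  # tokens per step at batch=1, pos=512
fwd_tok_s = B / 0.00090334   # ≈ 566.8k tok/s
bwd_tok_s = B / 0.00243601   # ≈ 210.2k tok/s
print(f"{fwd_tok_s:,.0f} fwd tok/s, {bwd_tok_s:,.0f} bwd tok/s")
```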
### 2026-02-21: v5p-8 sanity run (us-central1 dev TPU)
- Environment:
  - `scripts/ray/dev_tpu.py --config infra/marin-us-central1.yaml --tpu-type v5p-8`
  - TPU: `tpu v5` (`4` local devices on this slice)
- API bench (`implementation=xla`, same shape as the v4 check):
  - shape `B=512,H=512,V=128256` (`batch=1,pos=512`)
  - fwd `steady_time_s=0.000687716` (`~744.5k tok/s`)
  - bwd `bwd_steady_time_s=0.00283721` (`~180.5k tok/s`)
- Direct API-vs-baseline check (`v_block=32768`, value+grad):
  - `_use_v4_custom_xla_vjp()` returned `False` (expected on v5p/v5).
  - API (`linear_softmax_cross_entropy_loss_xla`): `0.00290257s` (`~176.4k tok/s`)
  - baseline (`linear_softmax_cross_entropy_loss_streaming` autodiff): `0.00289034s` (`~177.1k tok/s`)
  - the delta is negligible (`~0.4%`), confirming the v4-only gate preserves v5 behavior.

### 2026-02-21: What XLA is actually implementing (and why tile size is hard to extract)
- HLO inspection on v4 (`linear_softmax_cross_entropy_loss_xla`, `B=512,H=512,V=128256`) shows:
  - an explicit while-loop over vocab blocks with trip count `4` (`131072 padded vocab / 32768 block`).
  - a per-iteration `dynamic-slice` of `w` with `dynamic_slice_sizes={512,32768}`.
  - one block GEMM-equivalent op per iteration in the unoptimized HLO:
    - `dot(Arg_1.9, dynamic_slice.1)` in `closed_call.7`.
  - masked logits (`where` with `-inf`), per-block `reduce_max` / `exp` / `reduce_sum`, `logaddexp` accumulation, and a label-logit gather.
- In the optimized TPU HLO dump, that dot is canonicalized into a convolution-form op:
  - `convolution(...), dim_labels=bf_io->bf` with metadata tracing back to `dot_general`.
  - This is why searching for `dot(` in late dumps is often misleading.
- What is visible as “tiling”:
  - layout annotations such as `bf16[512,131072]{1,0:T(8,128)(2,1)}` and `f32[512,32768]{1,0:T(8,128)}`.
  - These are layout/packing tiles (memory-layout tiling), not a direct “MXU kernel tile size” parameter.
- What is **not** directly exposed:
  - the backend-selected microkernel tile/schedule/unrolling used by TPU codegen/libtpu for the matmul-like op.
  - There is no stable single field in emitted HLO that says “the matmul tile size is X by Y”.
- Practical conclusion:
  - We can reliably recover **algorithmic blocking** (`v_block_size=32768`) and loop structure from HLO.
  - We can see **layout tile annotations** (`T(8,128)` etc.).
  - We generally cannot recover a single definitive backend GEMM micro-tile from user-facing HLO text alone.

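The two HLO views referenced here can be pulled from any jitted function via JAX's AOT API; a generic sketch (the loss function below is a stand-in, not the repo's kernel):

```python
import jax
import jax.numpy as jnp

def toy_loss(x, w):
    # stand-in for the real fused CE; any jitted function works the same way
    return jnp.mean(jax.nn.logsumexp(x @ w, axis=-1))

x = jnp.ones((8, 4))
w = jnp.ones((4, 16))

lowered = jax.jit(toy_loss).lower(x, w)
stablehlo_text = lowered.as_text()            # pre-optimization StableHLO
optimized_hlo = lowered.compile().as_text()   # backend-optimized HLO (TPU HLO on a TPU VM)
print(optimized_hlo.splitlines()[0])
```

Full per-pass dumps (where the convolution-form op above was observed) come from `XLA_FLAGS=--xla_dump_to=<dir>` instead.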
### 2026-02-21: Forced custom-VJP trial on v5p-8
- Goal:
  - Evaluate enabling the new custom VJP on v5p (currently gated off in code) by directly calling `_linear_softmax_cross_entropy_loss_streaming_custom_vjp(...)`.
- Environment:
  - `marin-us-central1` dev TPU `v5p-8`, device kind reported as `TPU v5`.
  - Gate check: `_use_v4_custom_xla_vjp() == False` (expected).
- Comparison setup:
  - same value+grad benchmark, `dtype=float32`, `v_block_size=32768`.
  - compared:
    - `api_xla` (`linear_softmax_cross_entropy_loss_xla`)
    - `custom_vjp` (forced private custom-vjp call)
    - `stream_autodiff` (`linear_softmax_cross_entropy_loss_streaming` with AD)
- Results:
  - Shape `B=512,H=512,V=128256`:
    - `api_xla`: `0.00290897s` (`~176.0k tok/s`)
    - `custom_vjp`: `0.00208500s` (`~245.6k tok/s`) **(+39.5% vs api_xla)**
    - `stream_autodiff`: `0.00315458s` (`~162.3k tok/s`)
  - Shape `B=8192,H=4096,V=128256`:
    - `api_xla`: `0.09183749s` (`~89.2k tok/s`)
    - `custom_vjp`: `0.09866696s` (`~83.0k tok/s`) **(-6.9% vs api_xla)**
    - `stream_autodiff`: `0.09211473s` (`~88.9k tok/s`)
- Interpretation:
  - On v5p, the forced custom VJP is **shape-dependent**: faster at the smaller shape, slower at the larger `H=4096` shape.
  - This supports keeping the default v4-only gate for now unless we add shape-based gating/autotune.
- Correctness spot-check on v5p (small shape `B=512,H=512,V=128256`):
  - max abs grad diff (api vs forced custom):
    - `gx_max_abs = 6.103515625e-05` (`gx_rel = 5.78e-03`)
    - `gw_max_abs = 3.0517578125e-05` (`gw_rel = 3.18e-03`)

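If shape-based gating were added later, the dispatch could be as simple as the following sketch; the function name and cutoff are purely illustrative (picked between the measured `H=512` win and `H=4096` loss), not tuned or repo code:

```python
def pick_xla_backward_path(hidden_size: int) -> str:
    """Illustrative shape-based gate: route small-H shapes to the custom VJP
    and large-H shapes to plain streaming autodiff."""
    H_CUTOFF = 2048  # hypothetical; sits between the measured H=512 and H=4096 points
    return "custom_vjp" if hidden_size <= H_CUTOFF else "stream_autodiff"

print(pick_xla_backward_path(512), pick_xla_backward_path(4096))
```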
### 2026-02-21: v5p question - streaming custom VJP vs pallas
- Direct value+grad head-to-head on `v5p-8`:
  - shape `B=512,H=512,V=128256`:
    - `pallas_tpu` (infer): `0.00337185s` (`~151.8k tok/s`)
    - `streaming_custom_vjp`: `0.00169963s` (`~301.2k tok/s`)
    - result: streaming custom VJP is about `1.98x` faster.
  - shape `B=8192,H=4096,V=128256`:
    - `pallas_tpu` failed with a scoped VMEM OOM in the JVP path (`39.04M / 16.00M`).
    - `streaming_custom_vjp` succeeded at `0.09830s` (`~83.3k tok/s`).

### 2026-02-21: XLA default switched to custom VJP
- Code change:
  - `linear_softmax_cross_entropy_loss_xla(...)` now unconditionally dispatches to `_linear_softmax_cross_entropy_loss_streaming_custom_vjp(...)`.
  - Removed the v4-only gate from active dispatch.
- Tests:
  - Removed the gate-specific test and kept the custom-VJP grad parity test.
  - `pytest -k 'xla or custom_vjp'` in `lib/levanter/tests/kernels/test_pallas_fused_cross_entropy_loss.py`:
    - `2 passed, 1 skipped`.
- v5p sanity after the change (`B=512,H=512,V=128256`):
  - `api_xla`: `0.00168887s` (`~303.2k tok/s`)
  - forced custom-vjp call: `0.00169305s` (`~302.4k tok/s`)
  - confirms the API now uses the same path.

### 2026-02-21: Default backend policy update
- Changed the fused CE API default implementation order to always prefer `xla`, even when `pallas_tpu` is importable.
  - file: `lib/levanter/src/levanter/kernels/pallas/fused_cross_entropy_loss/api.py`
- `pallas_tpu` remains available when explicitly requested via `implementation='pallas_tpu'`.
- Validation:
  - `uv run --package levanter --group test pytest lib/levanter/tests/kernels/test_pallas_fused_cross_entropy_loss.py`
  - result: `14 passed, 3 skipped`.

### 2026-02-21: Follow-up v5p check (where can pallas still win?)
- Additional large-shape head-to-head (`v5p-8`, value+grad):
  - shape `B=32768,H=4096,V=128256`, `v_block=32768`
  - `pallas_tpu`: failed with a scoped VMEM OOM (`39.04M / 16.00M`)
  - `streaming_custom_vjp`: succeeded at `0.39032s` (`~83.95k tok/s`)
- Combined with earlier same-session results:
  - `B=512,H=512,V=128256`: custom-vjp `~301k tok/s` vs pallas `~152k tok/s`
  - `B=8192,H=4096,V=128256`: pallas OOM, custom-vjp `~83k tok/s`
- Practical takeaway in this env:
  - With the current scoped VMEM limits, pallas is not competitive for these tested v5p backward-inclusive workloads.

### 2026-02-21: v5p rerun with higher scoped VMEM limit
- Reran with:
  - `LIBTPU_INIT_ARGS=--xla_tpu_scoped_vmem_limit_kib=50000`
  - same value+grad benchmark, `v_block=32768`.
- Results:
  - `B=512,H=512,V=128256`:
    - `pallas_tpu`: `0.00336478s` (`~152.2k tok/s`)
    - `streaming_custom_vjp`: `0.00189534s` (`~270.1k tok/s`)
  - `B=8192,H=4096,V=128256`:
    - `pallas_tpu`: `0.116103s` (`~70.6k tok/s`)
    - `streaming_custom_vjp`: `0.0923914s` (`~88.7k tok/s`)
  - `B=32768,H=4096,V=128256`:
    - `pallas_tpu`: `0.436443s` (`~75.1k tok/s`)
    - `streaming_custom_vjp`: `0.365443s` (`~89.7k tok/s`)
- Conclusion:
  - Raising scoped VMEM lets pallas run the large shapes again, but the streaming custom VJP remains faster on all tested v5p backward-inclusive shapes.

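The limit is raised per-process via the libtpu init flag from the source log above; a sketch of the invocation (the benchmark script name is illustrative):

```shell
# 50000 KiB ≈ 48.8 MiB scoped VMEM, vs the ~16 MiB ceiling implied by the OOMs above
export LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=50000'
echo "$LIBTPU_INIT_ARGS"
# then run the value+grad benchmark in the same shell, e.g.:
#   python bench_fused_ce.py --v-block 32768   # hypothetical script name
```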
### 2026-02-21: Tokamax vs xla(custom-vjp) vs pallas on v5e-8/v6e-8 (eu-west4)
- Request:
  - compare the Tokamax kernel vs our new default `xla` path (streaming custom VJP) and our `pallas_tpu` kernel.
- target TPUs:
  - `v5e-8` in `europe-west4-b` (`infra/marin-eu-west4.yaml`)
  - `v6e-8` in `europe-west4-a` (`infra/marin-eu-west4-a.yaml`)
- shape used for all runs: `B=8192, H=4096, V=128256`.

#### Infra notes
- `v5e-8` allocation had intermittent autoscaler/preemption churn:
  - initial attempts timed out waiting for actor start.
  - one successful allocation was later terminated (`ActorDiedError`, node SIGTERM) and had to be reacquired.
- `v6e-8` allocation was stable in this session.

#### Tokamax install/runtime notes
- A dedicated Tokamax env was used on each TPU VM:
  - `uv venv .venv_tokamax --python 3.11`
  - `uv pip install tokamax`
  - `uv pip install 'jax[tpu]==0.9.0' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html`
- This produced:
  - `jax==0.9.0`, `jaxlib==0.9.0`, `libtpu==0.0.34` for Tokamax runs.
- Levanter xla/pallas runs stayed on the project-locked env (`jax==0.8.0`, `jaxlib==0.8.0`, `libtpu==0.0.24`).

#### Dtype compatibility findings (Tokamax `mosaic_tpu`)
- `bf16` failed on both `v5e` and `v6e` with a Pallas verifier error:
  - `'tpu.matmul' op Expected matmul acc to be 32-bit`
- `float32` runs were successful for Tokamax on both `v5e` and `v6e`.
- Per follow-up request, the comparison was done in a shared working dtype (`float32`).

#### Float32 comparison (value+grad, `B=8192,H=4096,V=128256`)
- v5e-8 (`europe-west4-b`):
  - `xla` (custom-vjp default):
    - fwd: `128,056 tok/s`
    - bwd: `36,612 tok/s`
    - combined (harmonic): `28,472 tok/s`
  - `pallas_tpu` (`block-sizes=infer`):
    - fwd: `128,223 tok/s`
    - bwd: `25,737 tok/s`
    - combined: `21,435 tok/s`
  - Tokamax `mosaic_tpu`:
    - fwd: `11,036 tok/s`
    - bwd: `22,999 tok/s`
    - combined: `7,458 tok/s`
- v6e-8 (`europe-west4-a`):
  - `xla` (custom-vjp default):
    - fwd: `259,456 tok/s`
    - bwd: `86,501 tok/s`
    - combined: `64,873 tok/s`
  - `pallas_tpu` (`block-sizes=infer`):
    - fwd: `243,238 tok/s`
    - bwd: `53,753 tok/s`
    - combined: `44,024 tok/s`
  - Tokamax `mosaic_tpu`:
    - fwd: `11,451 tok/s`
    - bwd: `76,094 tok/s`
    - combined: `9,953 tok/s`

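The "combined (harmonic)" rows are the harmonic combination of fwd and bwd throughput (one training step pays one forward plus one backward); reproducing the v5e `xla` row:

```python
def combined_tok_s(fwd: float, bwd: float) -> float:
    # time per token for a full step is 1/fwd + 1/bwd, so throughput is the inverse
    return 1.0 / (1.0 / fwd + 1.0 / bwd)

print(round(combined_tok_s(128_056, 36_612)))  # v5e xla row above -> 28472
```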
#### Extra bf16 context (our kernels)
- On both TPUs, our bf16 `xla`/`pallas` runs completed; `xla` remained ahead on combined throughput.
- Tokamax bf16 remained blocked by the verifier error above.

#### Bottom line
- In the only shared working dtype (`float32`), our `xla` custom-vjp path is clearly fastest on combined throughput on both `v5e-8` and `v6e-8`.
- `pallas_tpu` remains competitive on forward but trails on backward, so its combined throughput is below `xla`'s.
- Tokamax `mosaic_tpu` is not competitive in this setup and currently cannot run bf16 on these TPUs due to the matmul-accumulator verification failure.

lib/levanter/src/levanter/kernels/pallas/fused_cross_entropy_loss/api.py

Lines changed: 0 additions & 1 deletion
@@ -33,7 +33,6 @@
     from .pallas_tpu import PallasUnsupportedError, linear_softmax_cross_entropy_loss_pallas

     IMPLEMENTATIONS["pallas_tpu"] = linear_softmax_cross_entropy_loss_pallas
-    _DEFAULT_IMPLEMENTATION = ("pallas_tpu",) + _DEFAULT_IMPLEMENTATION
 except ImportError:
     PallasUnsupportedError = NotImplementedError  # type: ignore[assignment]
