- Relative: pallas is `~1.17x` faster on fwd and `~1.13x` faster on bwd.
- Note:
  - Loss values matched closely between backends in these runs.

### 2026-02-21: v4-8 VMEM pressure triage + XLA custom VJP direction
- Environment:
  - `scripts/ray/dev_tpu.py --config infra/marin-us-central2.yaml --tpu-type v4-8`
  - TPU: `TPU v4` (4 local devices)
- Core finding on VMEM failures:
  - At large failing configs (e.g. `B=65536,H=512,V=128256`, `b/h/v=1024/128/1024`), `fwd_only` already fails with scoped VMEM OOM (`~41.11M / 16M` bytes), so forward is the first cliff.
  - Backward-only can fail at even higher VMEM (`~47.05M / 16M`), but this is secondary at the tested boundaries.
- Boundary probing:
  - `B=512,H=512,v_block=768`: forward fails while direct backward can pass.
  - `B=512,H=512,v_block=640`: both pass.
- Pallas speed check versus XLA on v4:
  - For `B=512,H=512,V=128256` forward-only, the best Pallas config observed was around `~110k tok/s`.
  - XLA forward-only on the same shape was around `~572k tok/s` (about `5x` faster).
  - For value+grad on the same shape, Pallas reached `~49k tok/s` vs XLA streaming at `~122k tok/s`.
- Conclusion:
  - On v4, the Pallas path is VMEM-constrained and not competitive at these tested settings.
  - We should favor an XLA streaming path with a custom backward on v4.
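The tok/s figures throughout this log are batch tokens divided by steady-state step time. A minimal sketch of that measurement (the helper names are illustrative, not from the actual bench script; a JAX callable should block on its result before the timer stops):

```python
import time

def steady_time_s(fn, warmup=3, iters=10):
    """Median wall-clock seconds per call after warmup.

    For JAX, `fn` should call jax.block_until_ready(...) on its output so
    async dispatch does not make the step look artificially fast.
    """
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - t0)
    return sorted(samples)[len(samples) // 2]

def tok_per_s(batch_tokens, step_s):
    return batch_tokens / step_s

# Sanity check against a logged data point: 512 tokens at ~0.00421s -> ~121.6k tok/s.
print(round(tok_per_s(512, 0.00421)))  # 121615
```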

#### XLA streaming custom-VJP prototype result (v4-8)
- Prototype behavior:
  - forward uses the existing streaming CE (`linear_softmax_cross_entropy_loss_streaming`)
  - backward manually streams over vocab blocks to avoid full autodiff materialization.
- Measured on `B=512,H=512,V=128256`:
  - builtin `xla` (`v_block=32768`) value+grad: `~121.6k tok/s` (`~0.00421s`)
  - custom streaming VJP (`v_block=32768`): `~212.5k tok/s` (`~0.00241s`)
  - custom streaming VJP (`v_block=8192`): `~161k tok/s`
- Prototype correctness spot-check (same env):
  - `loss_builtin == loss_custom` exactly in the sampled run.
  - gradient deltas:
    - `gx_max_abs = 4.8828125e-04`
    - `gw_max_abs = 5.9604645e-08`
    - `gx_rel = 2.2317406e-03`
    - `gw_rel = 9.1245504e-08`

#### Repo changes (this branch)
- `lib/levanter/src/levanter/kernels/pallas/fused_cross_entropy_loss/xla.py`
  - Added `_use_v4_custom_xla_vjp()` gate: enable the custom VJP only on TPU v4.
  - Added `_linear_softmax_cross_entropy_loss_streaming_custom_vjp(...)` with a manual streaming backward:
    - computes blockwise `delta = (dL + dLSE)*prob - dL*one_hot`
    - applies the soft-cap derivative when enabled
    - accumulates `dx` and writes `dw` blockwise.
  - `linear_softmax_cross_entropy_loss_xla(...)` now dispatches to the custom VJP on v4; other backends keep existing behavior.
- `lib/levanter/tests/kernels/test_pallas_fused_cross_entropy_loss.py`
  - Added `test_v4_custom_xla_vjp_gate`
  - Added `test_xla_streaming_custom_vjp_grad_matches_streaming_autodiff`
- Local test run (`-k 'xla or custom_vjp or gate'`): `3 passed, 1 skipped`.
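The manual streaming backward can be restated in toy NumPy form (shapes shrunk for illustration; the real code is a JAX `custom_vjp` in `xla.py` that recomputes `prob` per block from the saved logsumexp rather than materializing it, and also applies the soft-cap derivative, omitted here):

```python
import numpy as np

rng = np.random.default_rng(0)
B, H, V, v_block = 4, 8, 12, 4          # toy shapes; real runs use e.g. v_block=32768
x = rng.standard_normal((B, H)).astype(np.float32)
w = rng.standard_normal((H, V)).astype(np.float32)
y = rng.integers(0, V, size=B)
dL = np.ones(B, dtype=np.float32)       # cotangent of the per-example loss
dLSE = np.zeros(B, dtype=np.float32)    # cotangent of the returned logsumexp

logits = x @ w
prob = np.exp(logits - logits.max(-1, keepdims=True))
prob /= prob.sum(-1, keepdims=True)
one_hot = np.eye(V, dtype=np.float32)[y]

# Reference (fully materialized) backward.
delta_full = (dL + dLSE)[:, None] * prob - dL[:, None] * one_hot
dx_ref, dw_ref = delta_full @ w.T, x.T @ delta_full

# Streaming backward: one vocab block at a time, accumulating dx and
# writing each dw block exactly once -- no (B, V) delta ever materialized.
dx = np.zeros_like(x)
dw = np.zeros_like(w)
for v0 in range(0, V, v_block):
    sl = slice(v0, v0 + v_block)
    delta = (dL + dLSE)[:, None] * prob[:, sl] - dL[:, None] * one_hot[:, sl]
    dx += delta @ w[:, sl].T
    dw[:, sl] = x.T @ delta

assert np.allclose(dx, dx_ref, atol=1e-5)
assert np.allclose(dw, dw_ref, atol=1e-5)
```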

#### In-repo TPU validation after patch (v4-8)
- API bench (`linear_softmax_cross_entropy_loss_xla` through the fused loss API):
  - shape `B=512,H=512,V=128256` (`batch=1,pos=512` in the bench script)
  - fwd `steady_time_s=0.00090334` (`~566.8k tok/s`)
  - bwd `bwd_steady_time_s=0.00243601` (`~210.2k tok/s`)
- Direct head-to-head on the same TPU run (`v_block=32768`, value+grad):
  - API path (now v4 custom VJP): `0.002495s` (`~205.2k tok/s`)
  - baseline streaming autodiff (`linear_softmax_cross_entropy_loss_streaming`): `0.004175s` (`~122.6k tok/s`)
  - speedup: `~1.67x` for backward-inclusive step time.
- Correctness spot-check on TPU after integration:
  - `_use_v4_custom_xla_vjp()` returned `True` on `TPU v4`.
  - max abs diff versus the streaming-autodiff gradient at sample shape (`B=128,H=128,V=4096`):
    - `gx_max_abs = 1.220703125e-04`
    - `gw_max_abs = 1.220703125e-04`
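The `*_max_abs` / `*_rel` figures in these spot-checks follow from a small helper like the one below (the exact relative metric used by the check scripts is not recorded in this log, so the sketch assumes max-abs diff normalized by the reference gradient's max-abs value):

```python
import numpy as np

def grad_diff_stats(g_ref, g_new):
    """Max-abs and relative gradient deltas, as reported in the spot-checks."""
    g_ref = np.asarray(g_ref, dtype=np.float64)
    g_new = np.asarray(g_new, dtype=np.float64)
    # max_abs: largest elementwise gradient difference
    max_abs = float(np.abs(g_ref - g_new).max())
    # rel: that difference normalized by the reference gradient's scale
    rel = max_abs / max(float(np.abs(g_ref).max()), 1e-30)
    return max_abs, rel
```

This would be applied twice per check, once to `dx` (giving `gx_*`) and once to `dw` (giving `gw_*`).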

### 2026-02-21: v5p-8 sanity run (us-central1 dev TPU)
- Environment:
  - `scripts/ray/dev_tpu.py --config infra/marin-us-central1.yaml --tpu-type v5p-8`
  - TPU: `TPU v5` (`4` local devices on this slice)
- API bench (`implementation=xla`, same shape as the v4 check):
  - shape `B=512,H=512,V=128256` (`batch=1,pos=512`)
  - fwd `steady_time_s=0.000687716` (`~744.5k tok/s`)
  - bwd `bwd_steady_time_s=0.00283721` (`~180.5k tok/s`)
- Direct API-vs-baseline check (`v_block=32768`, value+grad):
  - `_use_v4_custom_xla_vjp()` returned `False` (expected on v5p/v5).
  - API (`linear_softmax_cross_entropy_loss_xla`): `0.00290257s` (`~176.4k tok/s`)
  - baseline (`linear_softmax_cross_entropy_loss_streaming` autodiff): `0.00289034s` (`~177.1k tok/s`)
  - delta is negligible (`~0.4%`), confirming the v4-only gate preserves v5 behavior.

### 2026-02-21: What XLA is actually implementing (and why tile size is hard to extract)
- HLO inspection on v4 (`linear_softmax_cross_entropy_loss_xla`, `B=512,H=512,V=128256`) shows:
  - an explicit while-loop over vocab blocks with trip count `4` (`131072 padded vocab / 32768 block`).
  - a per-iteration `dynamic-slice` of `w` with `dynamic_slice_sizes={512,32768}`.
  - one block GEMM-equivalent op per iteration in unoptimized HLO:
    - `dot(Arg_1.9, dynamic_slice.1)` in `closed_call.7`.
  - masked logits (`where` with `-inf`), per-block `reduce_max` / `exp` / `reduce_sum`, `logaddexp` accumulation, and label-logit gather.
- In the optimized TPU HLO dump, that dot is canonicalized into a convolution-form op:
  - `convolution(...), dim_labels=bf_io->bf` with metadata tracing back to `dot_general`.
  - This is why searching for `dot(` in late dumps is often misleading.
- What is visible as "tiling":
  - layout annotations such as `bf16[512,131072]{1,0:T(8,128)(2,1)}` and `f32[512,32768]{1,0:T(8,128)}`.
  - These are layout/packing tiles (memory-layout tiling), not a direct "MXU kernel tile size" parameter.
- What is **not** directly exposed:
  - the backend-selected microkernel tile/schedule/unrolling used by TPU codegen/libtpu for the matmul-like op.
  - There is no stable single field in emitted HLO that says "the matmul tile size is X by Y".
- Practical conclusion:
  - We can reliably recover **algorithmic blocking** (`v_block_size=32768`) and loop structure from HLO.
  - We can see **layout tile annotations** (`T(8,128)` etc.).
  - We generally cannot recover a single definitive backend GEMM micro-tile from user-facing HLO text alone.
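This kind of inspection can be reproduced with JAX's lowering APIs (toy stand-in function below; the actual run lowered `linear_softmax_cross_entropy_loss_xla` on TPU, which is where the convolution-form op appears):

```python
import jax
import jax.numpy as jnp

def f(x, w):
    # Toy stand-in for the fused CE forward; enough to produce a dot_general.
    return jax.nn.logsumexp(x @ w, axis=-1).sum()

x = jnp.ones((4, 8), jnp.float32)
w = jnp.ones((8, 16), jnp.float32)

lowered = jax.jit(f).lower(x, w)
stablehlo = lowered.as_text()            # pre-optimization StableHLO
optimized = lowered.compile().as_text()  # backend-optimized HLO for the current
                                         # platform; on TPU this is where dot_general
                                         # shows up canonicalized
```

For full pipeline detail, setting `XLA_FLAGS=--xla_dump_to=<dir>` before running writes every optimization stage to disk.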

### 2026-02-21: Forced custom-VJP trial on v5p-8
- Goal:
  - Evaluate enabling the new custom VJP on v5p (currently gated off in code) by directly calling `_linear_softmax_cross_entropy_loss_streaming_custom_vjp(...)`.
- Environment:
  - `marin-us-central1` dev TPU `v5p-8`, device kind reported as `TPU v5`.
  - Gate check: `_use_v4_custom_xla_vjp() == False` (expected).
- Comparison setup:
  - same value+grad benchmark, `dtype=float32`, `v_block_size=32768`.
  - compared:
    - `api_xla` (`linear_softmax_cross_entropy_loss_xla`)
    - `custom_vjp` (forced private custom-vjp call)
    - `stream_autodiff` (`linear_softmax_cross_entropy_loss_streaming` with AD)
- Results:
  - Shape `B=512,H=512,V=128256`:
    - `api_xla`: `0.00290897s` (`~176.0k tok/s`)
    - `custom_vjp`: `0.00208500s` (`~245.6k tok/s`) **(+39.5% vs api_xla)**
    - `stream_autodiff`: `0.00315458s` (`~162.3k tok/s`)
  - Shape `B=8192,H=4096,V=128256`:
    - `api_xla`: `0.09183749s` (`~89.2k tok/s`)
    - `custom_vjp`: `0.09866696s` (`~83.0k tok/s`) **(-6.9% vs api_xla)**
    - `stream_autodiff`: `0.09211473s` (`~88.9k tok/s`)
- Interpretation:
  - On v5p, the forced custom VJP is **shape-dependent**: faster at the smaller shape, slower at the larger `H=4096` shape.
  - This supports keeping the default v4-only gate for now unless we add shape-based gating/autotune.
- Correctness spot-check on v5p (small shape `B=512,H=512,V=128256`):
  - max abs grad diff (api vs forced custom):
    - `gx_max_abs = 6.103515625e-05` (`gx_rel = 5.78e-03`)
    - `gw_max_abs = 3.0517578125e-05` (`gw_rel = 3.18e-03`)

### 2026-02-21: v5p question - streaming custom VJP vs pallas
- Direct value+grad head-to-head on `v5p-8`:
  - shape `B=512,H=512,V=128256`:
    - `pallas_tpu` (infer): `0.00337185s` (`~151.8k tok/s`)
    - `streaming_custom_vjp`: `0.00169963s` (`~301.2k tok/s`)
    - result: the streaming custom VJP is about `1.98x` faster.
  - shape `B=8192,H=4096,V=128256`:
    - `pallas_tpu` failed with scoped VMEM OOM in the JVP path (`39.04M / 16.00M`).
    - `streaming_custom_vjp` succeeded at `0.09830s` (`~83.3k tok/s`).

### 2026-02-21: XLA default switched to custom VJP
- Code change:
  - `linear_softmax_cross_entropy_loss_xla(...)` now unconditionally dispatches to `_linear_softmax_cross_entropy_loss_streaming_custom_vjp(...)`.
  - Removed the v4-only gate from active dispatch.
- Tests:
  - Removed the gate-specific test and kept the custom-VJP grad parity test.
  - `pytest -k 'xla or custom_vjp'` in `lib/levanter/tests/kernels/test_pallas_fused_cross_entropy_loss.py`:
    - `2 passed, 1 skipped`.
- v5p sanity after change (`B=512,H=512,V=128256`):
  - `api_xla`: `0.00168887s` (`~303.2k tok/s`)
  - forced custom-vjp call: `0.00169305s` (`~302.4k tok/s`)
  - confirms the API now uses the same path.

### 2026-02-21: Default backend policy update
- Changed the fused CE API default implementation order to always prefer `xla`, even when `pallas_tpu` is importable.
  - file: `lib/levanter/src/levanter/kernels/pallas/fused_cross_entropy_loss/api.py`
  - `pallas_tpu` remains available when explicitly requested via `implementation='pallas_tpu'`.
- Validation:
  - `uv run --package levanter --group test pytest lib/levanter/tests/kernels/test_pallas_fused_cross_entropy_loss.py`
  - result: `14 passed, 3 skipped`.

### 2026-02-21: Follow-up v5p check (where can pallas still win?)
- Additional large-shape head-to-head (`v5p-8`, value+grad):
  - shape `B=32768,H=4096,V=128256`, `v_block=32768`
  - `pallas_tpu`: failed with scoped VMEM OOM (`39.04M / 16.00M`)
  - `streaming_custom_vjp`: succeeded at `0.39032s` (`~83.95k tok/s`)
- Combined with earlier same-session results:
  - `B=512,H=512,V=128256`: custom-vjp `~301k tok/s` vs pallas `~152k tok/s`
  - `B=8192,H=4096,V=128256`: pallas OOM, custom-vjp `~83k tok/s`
- Practical takeaway on this env:
  - With current scoped VMEM limits, pallas is not competitive for these tested v5p backward-inclusive workloads.

### 2026-02-21: v5p rerun with higher scoped VMEM limit
- Reran with:
  - `LIBTPU_INIT_ARGS=--xla_tpu_scoped_vmem_limit_kib=50000`
  - same value+grad benchmark, `v_block=32768`.
- Results:
  - `B=512,H=512,V=128256`:
    - `pallas_tpu`: `0.00336478s` (`~152.2k tok/s`)
    - `streaming_custom_vjp`: `0.00189534s` (`~270.1k tok/s`)
  - `B=8192,H=4096,V=128256`:
    - `pallas_tpu`: `0.116103s` (`~70.6k tok/s`)
    - `streaming_custom_vjp`: `0.0923914s` (`~88.7k tok/s`)
  - `B=32768,H=4096,V=128256`:
    - `pallas_tpu`: `0.436443s` (`~75.1k tok/s`)
    - `streaming_custom_vjp`: `0.365443s` (`~89.7k tok/s`)
- Conclusion:
  - Raising scoped VMEM lets pallas run the large shapes again, but the streaming custom VJP remains faster on all tested v5p backward-inclusive shapes.

### 2026-02-21: Tokamax vs xla(custom-vjp) vs pallas on v5e-8/v6e-8 (eu-west4)
- Request:
  - compare the Tokamax kernel vs our new default `xla` path (streaming custom VJP) and our `pallas_tpu` kernel.
- target TPUs:
  - `v5e-8` in `europe-west4-b` (`infra/marin-eu-west4.yaml`)
  - `v6e-8` in `europe-west4-a` (`infra/marin-eu-west4-a.yaml`)
- shape used for all runs: `B=8192, H=4096, V=128256`.

#### Infra notes
- `v5e-8` allocation had intermittent autoscaler/preemption churn:
  - initial attempts timed out waiting for actor start.
  - one successful allocation was later terminated (`ActorDiedError`, node SIGTERM) and had to be reacquired.
- `v6e-8` allocation was stable in this session.

#### Tokamax install/runtime notes
- A dedicated Tokamax env was used on each TPU VM:
  - `uv venv .venv_tokamax --python 3.11`
  - `uv pip install tokamax`
  - `uv pip install 'jax[tpu]==0.9.0' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html`
- This produced:
  - `jax==0.9.0`, `jaxlib==0.9.0`, `libtpu==0.0.34` for Tokamax runs.
- Levanter xla/pallas runs stayed on the project-locked env (`jax==0.8.0`, `jaxlib==0.8.0`, `libtpu==0.0.24`).

#### Dtype compatibility findings (Tokamax `mosaic_tpu`)
- `bf16` failed on both `v5e` and `v6e` with a Pallas verifier error:
  - `'tpu.matmul' op Expected matmul acc to be 32-bit`
- `float32` runs were successful for Tokamax on both `v5e` and `v6e`.
- Per follow-up request, the comparison was done in a shared working dtype (`float32`).

#### Float32 comparison (value+grad, `B=8192,H=4096,V=128256`)
- v5e-8 (`europe-west4-b`):
  - `xla` (custom-vjp default):
    - fwd: `128,056 tok/s`
    - bwd: `36,612 tok/s`
    - combined (harmonic): `28,472 tok/s`
  - `pallas_tpu` (`block-sizes=infer`):
    - fwd: `128,223 tok/s`
    - bwd: `25,737 tok/s`
    - combined: `21,435 tok/s`
  - Tokamax `mosaic_tpu`:
    - fwd: `11,036 tok/s`
    - bwd: `22,999 tok/s`
    - combined: `7,458 tok/s`
- v6e-8 (`europe-west4-a`):
  - `xla` (custom-vjp default):
    - fwd: `259,456 tok/s`
    - bwd: `86,501 tok/s`
    - combined: `64,873 tok/s`
  - `pallas_tpu` (`block-sizes=infer`):
    - fwd: `243,238 tok/s`
    - bwd: `53,753 tok/s`
    - combined: `44,024 tok/s`
  - Tokamax `mosaic_tpu`:
    - fwd: `11,451 tok/s`
    - bwd: `76,094 tok/s`
    - combined: `9,953 tok/s`
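The "combined (harmonic)" rows are consistent with the harmonic combination of forward and backward throughput (assuming one fwd plus one bwd pass per token, so per-token times add):

```python
def combined_tok_per_s(fwd_tok_s: float, bwd_tok_s: float) -> float:
    # Per-token step time is the sum of the fwd and bwd per-token times,
    # so combined throughput is the harmonic combination of the two rates.
    return 1.0 / (1.0 / fwd_tok_s + 1.0 / bwd_tok_s)

# v5e-8 xla row: fwd 128,056 and bwd 36,612 -> ~28,472 combined, as logged
print(round(combined_tok_per_s(128_056, 36_612)))
# v6e-8 xla row: fwd 259,456 and bwd 86,501 -> ~64,873 combined, as logged
print(round(combined_tok_per_s(259_456, 86_501)))
```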

#### Extra bf16 context (our kernels)
- On both TPUs, our bf16 `xla`/`pallas` runs completed; `xla` remained ahead on combined throughput.
- Tokamax bf16 remained blocked by the verifier error above.

#### Bottom line
- In the only shared working dtype (`float32`), our `xla` custom-vjp path is clearly the fastest on combined throughput on both `v5e-8` and `v6e-8`.
- `pallas_tpu` remains competitive on forward but trails on backward, so its combined throughput is below `xla`.
- Tokamax `mosaic_tpu` is not competitive in this setup and currently cannot run bf16 on these TPUs due to the matmul-accumulator verification failure.