Skip to content

fix(perf): retune Gemma1 decode#24

Merged
inureyes merged 1 commit into
mainfrom
fix/perf-retune-gemma1-decode
May 18, 2026
Merged

fix(perf): retune Gemma1 decode#24
inureyes merged 1 commit into
mainfrom
fix/perf-retune-gemma1-decode

Conversation

@inureyes
Copy link
Copy Markdown
Member

Summary

Brings the older Gemma v1 hot path up to the same dtype-discipline + activation-shape level as the newer Gemma versions. On gemma-2b-4bit (M5 Max), decode throughput jumps from 69.34 → 214.49 tok/s — closing most of the gap to the mlx-lm baseline of 223.27 tok/s.

Cherry-picked from mlxcel-internal commit 1b4937a8. The internal benchmarks/, docs/model_tests_m5max.md, and docs_internal/performance/... paths from the original commit are intentionally excluded (none of them exist in this repo — same pattern as #20, #22).

What changed

src/models/gemma.rs — three coordinated fixes:

  1. Quantized GeGLU activation now uses the same tanh-approx GELU path that mlx-lm picks for gelu_pytorch_tanh configs, instead of the exact-erf GELU. Bit-equivalent to the upstream reference.
  2. Embedding scaling no longer promotes the whole activation tensor to fp32 via a scalar multiply — the scalar is cast to the activation dtype first, matching the pattern already used in gemma2 / gemma3.
  3. Decode-time hints — enables the same layer-pipeline hint and maskless padded-prefill hooks the newer Gemma paths use, so Gemma v1 benefits from the same scheduler-level optimisations.

src/lib/mlxcel-core/src/utils.rs — 1-line supporting helper change to make the dtype-preserving scalar mul reusable.

Verification

  • make verify-fmt — clean
  • make verify-clippy (CI-faithful: --all-targets --features metal,accelerate -- -D warnings) — clean in 14s (warm cache)
  • make verify-test skipped (15-30 min release run); the upstream commit is already validated against the M5 Max sweep with the throughput numbers above.

Switch Gemma v1 quantized GeGLU to the same tanh-approx GELU path used by mlx-lm for gelu_pytorch_tanh configs, avoid scalar full-array promotion in embedding scaling, and enable the same layer pipeline hint / maskless padded prefill hooks used by the newer Gemma paths.

Measured gemma-2b-4bit on M5 Max: baseline chat-template decode 69.34 tok/s, patched final decode 214.49 tok/s, and mlx-lm baseline 223.27 tok/s.
@inureyes inureyes added status:review Under review type:bug Bug fixes, error corrections, or issue resolutions type:performance Performance improvements priority:high High priority area:models Model architectures, weights, loading, metadata labels May 18, 2026
@inureyes inureyes merged commit 38adc81 into main May 18, 2026
1 check passed
@inureyes inureyes deleted the fix/perf-retune-gemma1-decode branch May 18, 2026 13:11
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

area:models Model architectures, weights, loading, metadata priority:high High priority status:review Under review type:bug Bug fixes, error corrections, or issue resolutions type:performance Performance improvements

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant