
Optimize Arm f32 GEMM using FMLA (by element) #679


Merged
5 commits merged into main on Apr 27, 2025

Conversation

robertknight
Owner

@robertknight robertknight commented Apr 23, 2025

Arm supports an FMLA instruction variant[^1] which multiplies one lane of a
vector by values in another vector and then accumulates into a destination.
This enables reducing the number of loads from A in each iteration of the GEMM
microkernel by a factor of 4. Instead of broadcasting one scalar value from A at
a time into a register used as an FMA operand, we can load a vector of 4
elements from A and then each FMA specifies which of the 4 elements to use.

The way this is implemented is by using `broadcast_lane` + `mul_add` generic
operations (`vdupq_laneq_f32` followed by `fmla`) which LLVM will fuse into a
single `fmla`.
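To illustrate the pattern, here is a portable sketch (not rten's actual kernel code; the `F32x4`, `broadcast_lane`, and `mul_add` names are hypothetical stand-ins for the generic SIMD operations) of how one 4-element load from A can feed four lane-indexed FMAs:

```rust
/// Minimal stand-in for a 4-lane f32 SIMD vector. Hypothetical type,
/// illustrating the generic operations described above.
#[derive(Clone, Copy)]
struct F32x4([f32; 4]);

impl F32x4 {
    fn load(src: &[f32]) -> Self {
        F32x4([src[0], src[1], src[2], src[3]])
    }

    /// Broadcast lane `LANE` to all four lanes.
    /// On Arm this corresponds to `vdupq_laneq_f32`.
    fn broadcast_lane<const LANE: usize>(self) -> Self {
        F32x4([self.0[LANE]; 4])
    }

    /// Fused multiply-add: `self * b + acc`. On Arm this corresponds to
    /// `vfmaq_f32`; LLVM fuses the preceding broadcast and this FMA into
    /// a single indexed FMLA.
    fn mul_add(self, b: Self, acc: Self) -> Self {
        let mut out = [0.0f32; 4];
        for i in 0..4 {
            out[i] = self.0[i].mul_add(b.0[i], acc.0[i]);
        }
        F32x4(out)
    }
}

fn main() {
    // One load of 4 elements from A feeds four FMAs, instead of four
    // separate scalar broadcast loads.
    let a = F32x4::load(&[1.0, 2.0, 3.0, 4.0]);
    let b = F32x4::load(&[10.0, 20.0, 30.0, 40.0]);
    let mut acc = F32x4([0.0; 4]);

    acc = a.broadcast_lane::<0>().mul_add(b, acc); // acc += a[0] * b
    acc = a.broadcast_lane::<1>().mul_add(b, acc); // acc += a[1] * b
    acc = a.broadcast_lane::<2>().mul_add(b, acc); // acc += a[2] * b
    acc = a.broadcast_lane::<3>().mul_add(b, acc); // acc += a[3] * b

    println!("{:?}", acc.0);
}
```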


Future work:

  • Apply the same operation for int8 GEMM via UDOT-by-element (probably a separate PR)

Footnotes

[^1]: https://developer.arm.com/documentation/ddi0596/2021-03/SIMD-FP-Instructions/FMLA--by-element---Floating-point-fused-Multiply-Add-to-accumulator--by-element--?lang=en

@robertknight
Owner Author

There is a `vfmaq_laneq_f32` intrinsic which exposes the fused broadcast + FMA in a more explicit way. However, using it generates worse code than using `vdupq_laneq_f32` + `vfmaq_f32`.

The broadcast lane + FMA fusion is convenient as broadcasting a lane is a more generally useful operation.
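For comparison, a minimal sketch (an assumption for illustration, not code from this PR) of the dup + FMA form on aarch64, with a portable scalar fallback so it compiles on other architectures:

```rust
/// Accumulate `a[0] * b` into `acc` using a lane broadcast plus FMA.
/// On aarch64, LLVM fuses the `vdupq_laneq_f32` + `vfmaq_f32` pair into
/// a single indexed `fmla` instruction.
#[cfg(target_arch = "aarch64")]
fn fma_lane0(a: [f32; 4], b: [f32; 4], acc: [f32; 4]) -> [f32; 4] {
    use std::arch::aarch64::*;
    unsafe {
        let av = vld1q_f32(a.as_ptr());
        let bv = vld1q_f32(b.as_ptr());
        let accv = vld1q_f32(acc.as_ptr());
        // Broadcast lane 0 of `a`, then fused multiply-add into `acc`.
        let lane0 = vdupq_laneq_f32::<0>(av);
        let r = vfmaq_f32(accv, lane0, bv);
        let mut out = [0.0f32; 4];
        vst1q_f32(out.as_mut_ptr(), r);
        out
    }
}

/// Portable fallback: the same computation in scalar code.
#[cfg(not(target_arch = "aarch64"))]
fn fma_lane0(a: [f32; 4], b: [f32; 4], acc: [f32; 4]) -> [f32; 4] {
    let mut out = [0.0f32; 4];
    for i in 0..4 {
        out[i] = a[0].mul_add(b[i], acc[i]);
    }
    out
}

fn main() {
    let r = fma_lane0([2.0, 9.0, 9.0, 9.0], [1.0, 2.0, 3.0, 4.0], [0.5; 4]);
    println!("{:?}", r);
}
```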

@robertknight robertknight changed the title Optimize Arm GEMM using FMLA (by element) Optimize Arm f32 GEMM using FMLA (by element) Apr 23, 2025
Suppress this new lint until I have time to fix existing failures.
This method is no longer `unsafe`.
On Arm this maps to the `vdupq_laneq_*` instructions. LLVM is conveniently able
to fuse `vdupq_laneq` + `vfmaq` into an indexed FMLA operation. For other
architectures we currently fall back to store + load.
@robertknight robertknight force-pushed the arm-gemm-fmla-by-element branch from 0d6495b to e5f241a Compare April 27, 2025 05:20
@robertknight robertknight force-pushed the arm-gemm-fmla-by-element branch from 31764c1 to ffeb6fd Compare April 27, 2025 05:28
@robertknight robertknight marked this pull request as ready for review April 27, 2025 05:31
@robertknight
Owner Author

robertknight commented Apr 27, 2025

Test with a ModernBERT base model on an M3 Pro:

 cargo run -p rten-cli -r -- modernbert-base.rten -s sequence_length=512 -n 10 -p

Before: ~420ms mean
After: ~400ms mean
ONNX Runtime (for comparison): ~372ms mean

@robertknight robertknight merged commit 30e2934 into main Apr 27, 2025
3 checks passed
@robertknight robertknight deleted the arm-gemm-fmla-by-element branch April 27, 2025 05:34