Optimize Arm f32 GEMM using FMLA by element #679
Conversation
The broadcast lane + FMA fusion is convenient, as broadcasting a lane is a more generally useful operation.
Suppress this new lint until I have time to fix existing failures.
This method is no longer `unsafe`.
On Arm this maps to the `vdupq_laneq_*` instructions. LLVM is conveniently able to fuse `vdupq_laneq` + `vfmaq` into an indexed FMLA operation. For other architectures we currently fall back to store + load.
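A minimal sketch of what this mapping can look like, assuming the `broadcast_lane` name from this PR's description (the crate's actual trait plumbing will differ; the fallback arm is purely illustrative):

```rust
// Sketch: on AArch64, broadcasting a lane maps to `vdupq_laneq_f32`;
// elsewhere, fall back to a store + load round-trip through memory.
#[cfg(target_arch = "aarch64")]
fn broadcast_lane<const LANE: i32>(
    v: std::arch::aarch64::float32x4_t,
) -> std::arch::aarch64::float32x4_t {
    // NEON is baseline on AArch64, so the intrinsic is always available;
    // this is why the method no longer needs to be `unsafe` for callers.
    unsafe { std::arch::aarch64::vdupq_laneq_f32::<LANE>(v) }
}

#[cfg(not(target_arch = "aarch64"))]
fn broadcast_lane<const LANE: i32>(v: [f32; 4]) -> [f32; 4] {
    // Store + load fallback: read one element back and splat it.
    [v[LANE as usize]; 4]
}
```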
Tested with a ModernBERT base model on an M3 Pro:

Before: ~420ms mean
Similar to #679, use the by-element variant of UDOT [^1] to reduce the number of loads in the int8 GEMM microkernel. Instead of performing 4 scalar 32-bit loads and broadcasting the result, perform one 128-bit load and use indexed UDOT to implicitly broadcast a 32-bit lane for each of the 4 rows. [^1]: https://developer.arm.com/documentation/100069/0609/A64-SIMD-Vector-Instructions/UDOT--vector--by-element-
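A hedged sketch of that indexed-UDOT pattern, assuming the `vdotq_laneq_u32` intrinsic from `std::arch::aarch64` (gated on the `dotprod` target feature; availability may depend on toolchain version). The function and variable names here are invented for illustration:

```rust
// Hypothetical helper: accumulate 4 rows x 4 columns of u8 dot products.
// One 128-bit load of A supplies four 32-bit lanes; each indexed UDOT
// broadcasts one lane implicitly instead of a separate scalar load.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "dotprod")]
unsafe fn udot_rows(
    acc: &mut [std::arch::aarch64::uint32x4_t; 4],
    a_quad: std::arch::aarch64::uint8x16_t, // 4 groups of 4 u8 values from A
    b: std::arch::aarch64::uint8x16_t,      // 4 columns x 4 u8 values from B
) {
    use std::arch::aarch64::vdotq_laneq_u32;
    // The const generic selects which 32-bit lane of `a_quad` to use.
    acc[0] = vdotq_laneq_u32::<0>(acc[0], b, a_quad);
    acc[1] = vdotq_laneq_u32::<1>(acc[1], b, a_quad);
    acc[2] = vdotq_laneq_u32::<2>(acc[2], b, a_quad);
    acc[3] = vdotq_laneq_u32::<3>(acc[3], b, a_quad);
}
```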
Arm supports an FMLA instruction variant [^1] which multiplies one lane of a vector by values in another vector and then accumulates into a destination. This enables reducing the number of loads from A in each iteration of the GEMM microkernel by a factor of 4. Instead of broadcasting one scalar value from A at a time into a register used as an FMA operand, we can load a vector of 4 elements from A and then each FMA specifies which of the 4 elements to use.

This is implemented using the `broadcast_lane` + `mul_add` generic operations (`vdupq_laneq_f32` followed by `fmla`), which LLVM will fuse into a single `fmla`.
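A simplified sketch of the resulting inner-loop pattern, written directly against NEON intrinsics rather than the crate's generic SIMD layer (the function and variable names are invented for illustration, not the crate's actual API):

```rust
// One step of the microkernel pattern: a single 4-wide load from A replaces
// four scalar broadcasts, and each `vdupq_laneq_f32` + `vfmaq_f32` pair is
// fused by LLVM into an indexed FMLA.
#[cfg(target_arch = "aarch64")]
unsafe fn fma_tile(
    acc: &mut [std::arch::aarch64::float32x4_t; 4],
    a: *const f32, // points at 4 consecutive f32 values from A
    b: std::arch::aarch64::float32x4_t,
) {
    use std::arch::aarch64::{vdupq_laneq_f32, vfmaq_f32, vld1q_f32};
    let a_vec = vld1q_f32(a); // one load instead of four broadcasts
    acc[0] = vfmaq_f32(acc[0], vdupq_laneq_f32::<0>(a_vec), b);
    acc[1] = vfmaq_f32(acc[1], vdupq_laneq_f32::<1>(a_vec), b);
    acc[2] = vfmaq_f32(acc[2], vdupq_laneq_f32::<2>(a_vec), b);
    acc[3] = vfmaq_f32(acc[3], vdupq_laneq_f32::<3>(a_vec), b);
}
```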
Future work:

[^1]: https://developer.arm.com/documentation/ddi0596/2021-03/SIMD-FP-Instructions/FMLA--by-element---Floating-point-fused-Multiply-Add-to-accumulator--by-element--?lang=en