Commit 1caa3e6
authored
[MLAS] Fix Flaky LuT GEMM Tests by Replacing Gather with Shuffle (#27174)
## Problem Description
The `MatMulNBitsLutGemm` test suite, specifically
`Float32_2Bits_Symmetric_256x256_BlkLen64`, was observing intermittent
failures (flakiness).
The failure manifested as numerical mismatches exceeding the tolerance,
suggesting non-deterministic behavior in the kernel execution.
## Root Cause Analysis
The issue was traced to the usage of `_mm256_i32gather_ps` in
sqnbitgemm_lut_kernel_avx2.cpp
While the gather indices were technically calculating addresses within
the bounds of the allocated buffer, gather instructions on certain AVX2
hardware implementations can exhibit non-deterministic behavior or
subtle performance/prefetching artifacts when operating on specific
stride patterns (in this case, gathering with a stride of 4 floats).
## Solution
This PR replaces the `_mm256_i32gather_ps` instruction with a sequence
of **contiguous loads (`_mm256_loadu_ps`) followed by deterministic
shuffles**.
### How it works:
1. **Contiguous Load**: We load 4 contiguous vectors of 8 floats
elements using `_mm256_loadu_ps`. This is always memory-safe and
deterministic.
2. **Deterministic Shuffle**: We apply a verified sequence of `unpack`
and `permutevar8x32` instructions to rearrange these 32 linearly loaded
elements into the exact same stride-4 layout that the gather instruction
produced.
### Benefits:
* **Stability**: Eliminates the hardware-dependent non-determinism of
gather.
* **Safety**: Usage of `loadu` guarantees we only touch memory within
the explicit range of the 32 elements we intend to load.
* **Correctness**: The shuffle logic was verified against the reference
gather behavior using a C++ reproduction script to ensure bit-exact
layout equivalence.
### Performance
Micro-benchmark on MatMulNBitsLutGemm (256x256, BlkLen=64).
Original (Gather): ~55.55 us
Fixed (Load+Shuffle): ~57.79 us
Delta: +2.24 us (~4% slower)
The slight performance regression is expected because replacing a single
hardware gather instruction with a sequence of loadu, unpack, and
permute instructions adds instruction count overhead. However, this is a
necessary tradeoff to ensure deterministic behavior and memory safety
across all AVX2 implementations.
## Verification
* **Tests**: All 9 tests in `MatMulNBitsLutGemm` passed successfully
(including the previously flaky `BlkLen64` case).1 parent a91b2fd commit 1caa3e6
File tree
2 files changed
+46
-14
lines changed- onnxruntime
- core/mlas/lib
- test/contrib_ops
2 files changed
+46
-14
lines changedLines changed: 42 additions & 12 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
187 | 187 | | |
188 | 188 | | |
189 | 189 | | |
| 190 | + | |
| 191 | + | |
| 192 | + | |
| 193 | + | |
| 194 | + | |
| 195 | + | |
| 196 | + | |
| 197 | + | |
| 198 | + | |
| 199 | + | |
| 200 | + | |
| 201 | + | |
| 202 | + | |
| 203 | + | |
| 204 | + | |
| 205 | + | |
| 206 | + | |
| 207 | + | |
| 208 | + | |
| 209 | + | |
| 210 | + | |
| 211 | + | |
| 212 | + | |
| 213 | + | |
| 214 | + | |
| 215 | + | |
| 216 | + | |
| 217 | + | |
| 218 | + | |
| 219 | + | |
190 | 220 | | |
191 | 221 | | |
192 | 222 | | |
193 | | - | |
194 | | - | |
195 | | - | |
196 | | - | |
197 | | - | |
198 | | - | |
| 223 | + | |
| 224 | + | |
| 225 | + | |
199 | 226 | | |
200 | 227 | | |
201 | 228 | | |
202 | 229 | | |
203 | 230 | | |
| 231 | + | |
| 232 | + | |
| 233 | + | |
204 | 234 | | |
| 235 | + | |
| 236 | + | |
205 | 237 | | |
206 | 238 | | |
207 | 239 | | |
| |||
222 | 254 | | |
223 | 255 | | |
224 | 256 | | |
225 | | - | |
226 | | - | |
| 257 | + | |
227 | 258 | | |
228 | 259 | | |
229 | 260 | | |
230 | 261 | | |
231 | | - | |
232 | | - | |
233 | | - | |
234 | | - | |
| 262 | + | |
| 263 | + | |
| 264 | + | |
235 | 265 | | |
236 | 266 | | |
237 | 267 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
371 | 371 | | |
372 | 372 | | |
373 | 373 | | |
374 | | - | |
375 | | - | |
| 374 | + | |
| 375 | + | |
| 376 | + | |
| 377 | + | |
376 | 378 | | |
377 | 379 | | |
378 | 380 | | |
| |||
0 commit comments