Commit cee825d
[MLAS] Fix rotary interleaved NEON kernel (#26390)
The logic of interleaved NEON kernel is not correct from code review:
1. **Test Code Logic:**
The test code `test_rope.h` allocates the `sin` and `cos` tables based
on the `interleaved` flag:
```c++
size_t table_len = interleaved ? rotary_emb_dim / 2 : rotary_emb_dim;
std::vector<float> sin_data(table_len);
std::vector<float> cos_data(table_len);
```
For the `interleaved = true` case, the test creates `sin` and `cos`
tables of length `rotary_emb_dim / 2`.
2. **AVX2 (fp32) Kernel Logic (`interleaved = true`):**
This kernel loads the `sin`/`cos` data using an index of `i / 2`:
```c++
float32x8_t sin_val = _mm256_loadu_ps(sin_data + i / 2);
float32x8_t cos_val = _mm256_loadu_ps(cos_data + i / 2);
```
This logic expects a `sin`/`cos` table of length `rotary_emb_dim / 2`.
**Conclusion: The AVX2 (fp32) kernel is consistent with the test code.**
3. **NEON (fp16) Kernel Logic (`interleaved = true`):**
This kernel loads the `sin`/`cos` data using an index of `i`:
```c++
// Enters loop with sin_val = MlasLoadFloat16x8(sin + i);
//...
// Inside loop, for next iteration:
sin_val = MlasLoadFloat16x8(sin + i + 16);
```
This logic expects a `sin`/`cos` table of length `rotary_emb_dim`.
**Conclusion: The NEON (fp16) kernel is NOT consistent with the test
code.**
### Regression Test
```
cmake --build build/Linux/Release --config Release --target onnxruntime_mlas_test && ./build/Linux/Release/onnxruntime_mlas_test --gtest_filter=NeonFp16RoPE*
```
Before applying the fix, the test failed:
```
[ FAILED ] NeonFp16RoPE.ShortExecute (13 ms)
onnxruntime/onnxruntime/test/mlas/unittest/test_rope_neon_fp16.cpp:66: Failure
Value of: CloseEnough(output_impl[i].ToFloat(), output_ref[i].ToFloat())
Actual: false
Expected: true
Expected bits: 19491 (16.546875) Actual bits: 56596 (-325) @[16], rotary_emb_dim=24, interleaved=true
```
After applying the fix, test passed.
### Summary
The `RopeKernel_Avx2_fp32_Impl<true>` kernel correctly aligns with the
test code (and the fallback implementation) by expecting a `sin`/`cos`
table of length `rotary_emb_dim / 2`.
The `RopeKernel_Fp16_Impl<true>` (NEON) kernel incorrectly expects a
table of length `rotary_emb_dim`. When run against the provided test,
the NEON kernel will read past the end of the `sin_data` and `cos_data`
vectors.
---------
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>1 parent d5aa986 commit cee825d
File tree
2 files changed
+122
-18
lines changed- onnxruntime
- core/mlas/lib
- test/mlas/unittest
2 files changed
+122
-18
lines changedLines changed: 18 additions & 18 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
150 | 150 | | |
151 | 151 | | |
152 | 152 | | |
153 | | - | |
154 | | - | |
| 153 | + | |
| 154 | + | |
155 | 155 | | |
156 | 156 | | |
157 | 157 | | |
| |||
163 | 163 | | |
164 | 164 | | |
165 | 165 | | |
166 | | - | |
167 | | - | |
| 166 | + | |
| 167 | + | |
168 | 168 | | |
169 | 169 | | |
170 | 170 | | |
| |||
181 | 181 | | |
182 | 182 | | |
183 | 183 | | |
184 | | - | |
185 | | - | |
| 184 | + | |
| 185 | + | |
186 | 186 | | |
187 | 187 | | |
188 | 188 | | |
| |||
201 | 201 | | |
202 | 202 | | |
203 | 203 | | |
204 | | - | |
205 | | - | |
206 | | - | |
207 | | - | |
208 | | - | |
209 | | - | |
| 204 | + | |
| 205 | + | |
| 206 | + | |
| 207 | + | |
| 208 | + | |
| 209 | + | |
210 | 210 | | |
211 | 211 | | |
212 | 212 | | |
| |||
224 | 224 | | |
225 | 225 | | |
226 | 226 | | |
227 | | - | |
228 | | - | |
229 | | - | |
230 | | - | |
| 227 | + | |
| 228 | + | |
| 229 | + | |
| 230 | + | |
231 | 231 | | |
232 | 232 | | |
233 | 233 | | |
| |||
241 | 241 | | |
242 | 242 | | |
243 | 243 | | |
244 | | - | |
245 | | - | |
| 244 | + | |
| 245 | + | |
246 | 246 | | |
247 | 247 | | |
248 | 248 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
| 104 | + | |
0 commit comments