Skip to content

Commit cee825d

Browse files
tianleiwuCopilot
andauthored
[MLAS] Fix rotary interleaved NEON kernel (#26390)
The logic of interleaved NEON kernel is not correct from code review: 1. **Test Code Logic:** The test code `test_rope.h` allocates the `sin` and `cos` tables based on the `interleaved` flag: ```c++ size_t table_len = interleaved ? rotary_emb_dim / 2 : rotary_emb_dim; std::vector<float> sin_data(table_len); std::vector<float> cos_data(table_len); ``` For the `interleaved = true` case, the test creates `sin` and `cos` tables of length `rotary_emb_dim / 2`. 2. **AVX2 (fp32) Kernel Logic (`interleaved = true`):** This kernel loads the `sin`/`cos` data using an index of `i / 2`: ```c++ float32x8_t sin_val = _mm256_loadu_ps(sin_data + i / 2); float32x8_t cos_val = _mm256_loadu_ps(cos_data + i / 2); ``` This logic expects a `sin`/`cos` table of length `rotary_emb_dim / 2`. **Conclusion: The AVX2 (fp32) kernel is consistent with the test code.** 3. **NEON (fp16) Kernel Logic (`interleaved = true`):** This kernel loads the `sin`/`cos` data using an index of `i`: ```c++ // Enters loop with sin_val = MlasLoadFloat16x8(sin + i); //... // Inside loop, for next iteration: sin_val = MlasLoadFloat16x8(sin + i + 16); ``` This logic expects a `sin`/`cos` table of length `rotary_emb_dim`. **Conclusion: The NEON (fp16) kernel is NOT consistent with the test code.** ### Regression Test ``` cmake --build build/Linux/Release --config Release --target onnxruntime_mlas_test && ./build/Linux/Release/onnxruntime_mlas_test --gtest_filter=NeonFp16RoPE* ``` Before applying the fix, the test failed: ``` [ FAILED ] NeonFp16RoPE.ShortExecute (13 ms) onnxruntime/onnxruntime/test/mlas/unittest/test_rope_neon_fp16.cpp:66: Failure Value of: CloseEnough(output_impl[i].ToFloat(), output_ref[i].ToFloat()) Actual: false Expected: true Expected bits: 19491 (16.546875) Actual bits: 56596 (-325) @[16], rotary_emb_dim=24, interleaved=true ``` After applying the fix, test passed. ### Summary The `RopeKernel_Avx2_fp32_Impl<true>` kernel correctly aligns with the test code (and the fallback implementation) by expecting a `sin`/`cos` table of length `rotary_emb_dim / 2`. The `RopeKernel_Fp16_Impl<true>` (NEON) kernel incorrectly expects a table of length `rotary_emb_dim`. When run against the provided test, the NEON kernel will read past the end of the `sin_data` and `cos_data` vectors. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent d5aa986 commit cee825d

File tree

2 files changed

+122
-18
lines changed

2 files changed

+122
-18
lines changed

onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon_fp16.cpp

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,8 @@ RopeKernel_Fp16_Impl<true>(
150150
if (i + 15 < dim) {
151151
float16x8_t x0 = MlasLoadFloat16x8(input + i);
152152
float16x8_t x1 = MlasLoadFloat16x8(input + i + 8);
153-
float16x8_t sin_val = MlasLoadFloat16x8(sin + i);
154-
float16x8_t cos_val = MlasLoadFloat16x8(cos + i);
153+
float16x8_t sin_val = MlasLoadFloat16x8(sin + i / 2);
154+
float16x8_t cos_val = MlasLoadFloat16x8(cos + i / 2);
155155
for (; i + 31 < dim; i += 16) {
156156
float16x8_t real = vuzp1q_f16(x0, x1);
157157
float16x8_t imag = vuzp2q_f16(x0, x1);
@@ -163,8 +163,8 @@ RopeKernel_Fp16_Impl<true>(
163163
MlasStoreFloat16x8(output + i + 8, y1);
164164
x0 = MlasLoadFloat16x8(input + i + 16);
165165
x1 = MlasLoadFloat16x8(input + i + 24);
166-
sin_val = MlasLoadFloat16x8(sin + i + 16);
167-
cos_val = MlasLoadFloat16x8(cos + i + 16);
166+
sin_val = MlasLoadFloat16x8(sin + (i + 16) / 2);
167+
cos_val = MlasLoadFloat16x8(cos + (i + 16) / 2);
168168
}
169169
float16x8_t real = vuzp1q_f16(x0, x1);
170170
float16x8_t imag = vuzp2q_f16(x0, x1);
@@ -181,8 +181,8 @@ RopeKernel_Fp16_Impl<true>(
181181
float16x4_t x1 = MlasLoadFloat16x4(input + i + 4);
182182
float16x4_t real = vuzp1_f16(x0, x1);
183183
float16x4_t imag = vuzp2_f16(x0, x1);
184-
float16x4_t sin_val = MlasLoadFloat16x4(sin + i);
185-
float16x4_t cos_val = MlasLoadFloat16x4(cos + i);
184+
float16x4_t sin_val = MlasLoadFloat16x4(sin + i / 2);
185+
float16x4_t cos_val = MlasLoadFloat16x4(cos + i / 2);
186186
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
187187
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
188188
float16x4_t y0 = vzip1_f16(real_out, imag_out);
@@ -201,12 +201,12 @@ RopeKernel_Fp16_Impl<true>(
201201
imag = MlasLoadLaneFloat16x4<1>(input + i + 3, imag);
202202
real = MlasLoadLaneFloat16x4<2>(input + i + 4, real);
203203
imag = MlasLoadLaneFloat16x4<2>(input + i + 5, imag);
204-
sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
205-
sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val);
206-
sin_val = MlasLoadLaneFloat16x4<2>(sin + i + 2, sin_val);
207-
cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
208-
cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val);
209-
cos_val = MlasLoadLaneFloat16x4<2>(cos + i + 2, cos_val);
204+
sin_val = MlasLoadLaneFloat16x4<0>(sin + i / 2, sin_val);
205+
sin_val = MlasLoadLaneFloat16x4<1>(sin + i / 2 + 1, sin_val);
206+
sin_val = MlasLoadLaneFloat16x4<2>(sin + i / 2 + 2, sin_val);
207+
cos_val = MlasLoadLaneFloat16x4<0>(cos + i / 2, cos_val);
208+
cos_val = MlasLoadLaneFloat16x4<1>(cos + i / 2 + 1, cos_val);
209+
cos_val = MlasLoadLaneFloat16x4<2>(cos + i / 2 + 2, cos_val);
210210
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
211211
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
212212
MlasStoreLaneFloat16x4<0>(output + i, real_out);
@@ -224,10 +224,10 @@ RopeKernel_Fp16_Impl<true>(
224224
imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag);
225225
real = MlasLoadLaneFloat16x4<1>(input + i + 2, real);
226226
imag = MlasLoadLaneFloat16x4<1>(input + i + 3, imag);
227-
sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
228-
sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val);
229-
cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
230-
cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val);
227+
sin_val = MlasLoadLaneFloat16x4<0>(sin + i / 2, sin_val);
228+
sin_val = MlasLoadLaneFloat16x4<1>(sin + i / 2 + 1, sin_val);
229+
cos_val = MlasLoadLaneFloat16x4<0>(cos + i / 2, cos_val);
230+
cos_val = MlasLoadLaneFloat16x4<1>(cos + i / 2 + 1, cos_val);
231231
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
232232
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
233233
MlasStoreLaneFloat16x4<0>(output + i, real_out);
@@ -241,8 +241,8 @@ RopeKernel_Fp16_Impl<true>(
241241
float16x4_t cos_val = MlasZeroFloat16x4();
242242
real = MlasLoadLaneFloat16x4<0>(input + i, real);
243243
imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag);
244-
sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
245-
cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
244+
sin_val = MlasLoadLaneFloat16x4<0>(sin + i / 2, sin_val);
245+
cos_val = MlasLoadLaneFloat16x4<0>(cos + i / 2, cos_val);
246246
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
247247
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
248248
MlasStoreLaneFloat16x4<0>(output + i, real_out);
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
/*++
2+
3+
Copyright (c) Microsoft Corporation. All rights reserved.
4+
5+
Licensed under the MIT License.
6+
7+
Module Name:
8+
9+
test_rope_neon_fp16.cpp
10+
11+
Abstract:
12+
13+
Tests for MLAS fp16 RoPE on NEON.
14+
15+
--*/
16+
17+
#include <vector>
18+
#include <cmath>
19+
20+
#include "core/mlas/inc/mlas.h"
21+
22+
#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)
23+
24+
#include "test_util.h"
25+
#include "core/mlas/lib/mlasi.h"
26+
#include "core/mlas/lib/rotary_embedding.h"
27+
#include "core/mlas/lib/rotary_embedding_kernel_neon.h"
28+
29+
class MlasNeonFp16RoPETest : public MlasTestBase {
30+
private:
31+
const float Pi = 2 * std::acos(0.0f);
32+
33+
void Test(size_t rotary_emb_dim, bool interleaved) {
34+
// Per kernel logic (both fallback and optimized), the sin/cos tables
35+
// are always half the rotary embedding dimension.
36+
const size_t table_len = rotary_emb_dim / 2;
37+
38+
std::vector<MLAS_FP16> input(rotary_emb_dim);
39+
std::vector<MLAS_FP16> sin_data(table_len);
40+
std::vector<MLAS_FP16> cos_data(table_len);
41+
std::vector<MLAS_FP16> output_ref(rotary_emb_dim);
42+
std::vector<MLAS_FP16> output_impl(rotary_emb_dim);
43+
44+
// Initialize input data
45+
for (size_t i = 0; i < rotary_emb_dim; ++i) {
46+
input[i] = MLAS_FP16(static_cast<float>(i + 1));
47+
}
48+
49+
// Initialize sin/cos tables
50+
for (size_t i = 0; i < table_len; ++i) {
51+
float theta = static_cast<float>(i) / 1000.0f * Pi;
52+
sin_data[i] = MLAS_FP16(std::sin(theta));
53+
cos_data[i] = MLAS_FP16(std::cos(theta));
54+
}
55+
56+
// Call fallback implementation
57+
MlasRotaryEmbedOneRow_FallBack<MLAS_FP16>(input.data(), sin_data.data(), cos_data.data(), rotary_emb_dim, interleaved, output_ref.data());
58+
59+
// Call dispatched implementation (which should pick up the NEON kernel)
60+
MlasRotaryEmbedOneRow<MLAS_FP16>(input.data(), sin_data.data(), cos_data.data(), rotary_emb_dim, interleaved, output_impl.data());
61+
62+
// Compare results
63+
for (size_t i = 0; i < rotary_emb_dim; i++) {
64+
ASSERT_TRUE(CloseEnough(output_impl[i].ToFloat(), output_ref[i].ToFloat()))
65+
<< "Expected bits: " << output_ref[i].val << " (" << output_ref[i].ToFloat() << ")"
66+
<< " Actual bits: " << output_impl[i].val << " (" << output_impl[i].ToFloat() << ")"
67+
<< " @[" << i << "], "
68+
<< "rotary_emb_dim=" << rotary_emb_dim << ", interleaved=" << interleaved;
69+
}
70+
}
71+
72+
public:
73+
static const char* GetTestSuiteName() {
74+
return "NeonFp16RoPE";
75+
}
76+
77+
void ExecuteShort(void) override {
78+
// Test dimensions that cover main loops and various remainders
79+
Test(6, false);
80+
Test(6, true);
81+
Test(16, false);
82+
Test(16, true);
83+
Test(24, false);
84+
Test(24, true);
85+
Test(32, false);
86+
Test(32, true);
87+
Test(42, false);
88+
Test(42, true);
89+
Test(64, false);
90+
Test(64, true);
91+
Test(70, false);
92+
Test(70, true);
93+
}
94+
};
95+
96+
static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) {
97+
size_t count = 0;
98+
if (is_short_execute) {
99+
count += MlasDirectShortExecuteTests<MlasNeonFp16RoPETest>::RegisterShortExecute();
100+
}
101+
return count;
102+
});
103+
104+
#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)

0 commit comments

Comments
 (0)