Commit 6861526
authored
[MLAS] Fix Data Race in MlasLutGemm by Serializing LUT Generation (#27179)
## Problem Description
The `MatMulNBitsLutGemm.Float32_2Bits_Asymmetric_Batch32_256x256` test
was exhibiting flaky behavior (failure rate ~2-20%) with numerical
mismatches.
Investigation revealed a **race condition** in the
[GenerateLUT](https://github.com/microsoft/onnxruntime/blob/38dfc91f38fe53da9eaf7e9fb9b158904eb3cd5b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp#L326)
step within
[MlasLutGemm](https://github.com/microsoft/onnxruntime/blob/38dfc91f38fe53da9eaf7e9fb9b158904eb3cd5b/onnxruntime/core/mlas/inc/mlas_qnbit.h#L328).
When the batch size `M > 1`,
[MlasLutGemm](https://github.com/microsoft/onnxruntime/blob/38dfc91f38fe53da9eaf7e9fb9b158904eb3cd5b/onnxruntime/core/mlas/inc/mlas_qnbit.h#L328)
attempted to parallelize the LUT generation over the batch dimension
using `MlasTrySimpleParallel`. However, the underlying
[GenerateLUT](https://github.com/microsoft/onnxruntime/blob/38dfc91f38fe53da9eaf7e9fb9b158904eb3cd5b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp#L326)
implementation (specifically shared usage of `lut_scales`/`lut_biases`
or internal buffers) is not thread-safe for concurrent execution on the
same destination buffers or related state. This led to corruption of the
Look-Up Tables or scales, causing random output errors.
## Solution
This PR modifies
[onnxruntime/core/mlas/lib/qlutgemm.cpp](https://github.com/microsoft/onnxruntime/blob/38dfc91f38fe53da9eaf7e9fb9b158904eb3cd5b/onnxruntime/core/mlas/lib/qlutgemm.cpp)
to **serialize the
[GenerateLUT](file:///home/tlwu/onnxruntime/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp#324-355)
loop**.
Instead of using `MlasTrySimpleParallel`, we now use a simple `for` loop
to process each row of the batch sequentially.
**Performance Impact:**
The
[GenerateLUT](https://github.com/microsoft/onnxruntime/blob/38dfc91f38fe53da9eaf7e9fb9b158904eb3cd5b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp#L326)
step is computationally lightweight compared to the subsequent
[TMACComputeGemm](https://github.com/microsoft/onnxruntime/blob/38dfc91f38fe53da9eaf7e9fb9b158904eb3cd5b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp#L505)
matrix multiplication. Serializing this setup step has negligible impact
on overall inference latency (micro-benchmarks showed no measurable
regression), but effectively eliminates the race condition.
## Verification
* **Reproduction:** The issue was reliably reproduced by running
`MatMulNBitsLutGemm.Float32_2Bits_Asymmetric_Batch32_256x256` in a loop
(failing ~1 in 5 times).
* **Verification:** After applying the fix, the same test passed **50/50
iterations** consistently.
* **Regression Testing:** Standard `MatMulNBitsLutGemm` tests (including
`BlkLen64` and `M=1` cases) continue to pass.1 parent db383a9 commit 6861526
1 file changed
+24
-24
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
548 | 548 | | |
549 | 549 | | |
550 | 550 | | |
551 | | - | |
552 | | - | |
| 551 | + | |
| 552 | + | |
| 553 | + | |
| 554 | + | |
| 555 | + | |
| 556 | + | |
553 | 557 | | |
554 | 558 | | |
555 | | - | |
556 | | - | |
557 | | - | |
558 | | - | |
559 | | - | |
560 | | - | |
561 | | - | |
562 | | - | |
563 | | - | |
564 | | - | |
565 | | - | |
566 | | - | |
567 | | - | |
568 | | - | |
569 | | - | |
570 | | - | |
571 | | - | |
572 | | - | |
573 | | - | |
574 | | - | |
575 | | - | |
576 | | - | |
| 559 | + | |
| 560 | + | |
| 561 | + | |
| 562 | + | |
| 563 | + | |
| 564 | + | |
| 565 | + | |
| 566 | + | |
| 567 | + | |
| 568 | + | |
| 569 | + | |
| 570 | + | |
| 571 | + | |
| 572 | + | |
| 573 | + | |
| 574 | + | |
| 575 | + | |
| 576 | + | |
577 | 577 | | |
578 | 578 | | |
579 | 579 | | |
| |||
0 commit comments