
Commit 6861526

[MLAS] Fix Data Race in MlasLutGemm by Serializing LUT Generation (#27179)
## Problem Description

The `MatMulNBitsLutGemm.Float32_2Bits_Asymmetric_Batch32_256x256` test was exhibiting flaky behavior (failure rate ~2-20%) with numerical mismatches. Investigation revealed a **race condition** in the [GenerateLUT](https://github.com/microsoft/onnxruntime/blob/38dfc91f38fe53da9eaf7e9fb9b158904eb3cd5b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp#L326) step within [MlasLutGemm](https://github.com/microsoft/onnxruntime/blob/38dfc91f38fe53da9eaf7e9fb9b158904eb3cd5b/onnxruntime/core/mlas/inc/mlas_qnbit.h#L328).

When the batch size `M > 1`, [MlasLutGemm](https://github.com/microsoft/onnxruntime/blob/38dfc91f38fe53da9eaf7e9fb9b158904eb3cd5b/onnxruntime/core/mlas/inc/mlas_qnbit.h#L328) attempted to parallelize the LUT generation over the batch dimension using `MlasTrySimpleParallel`. However, the underlying [GenerateLUT](https://github.com/microsoft/onnxruntime/blob/38dfc91f38fe53da9eaf7e9fb9b158904eb3cd5b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp#L326) implementation (specifically its shared use of `lut_scales`/`lut_biases` or internal buffers) is not thread-safe for concurrent execution on the same destination buffers or related state. This corrupted the Look-Up Tables or scales, causing random output errors.

## Solution

This PR modifies [onnxruntime/core/mlas/lib/qlutgemm.cpp](https://github.com/microsoft/onnxruntime/blob/38dfc91f38fe53da9eaf7e9fb9b158904eb3cd5b/onnxruntime/core/mlas/lib/qlutgemm.cpp) to **serialize the [GenerateLUT](file:///home/tlwu/onnxruntime/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp#324-355) loop**. Instead of using `MlasTrySimpleParallel`, a simple `for` loop now processes each row of the batch sequentially.
**Performance Impact:** The [GenerateLUT](https://github.com/microsoft/onnxruntime/blob/38dfc91f38fe53da9eaf7e9fb9b158904eb3cd5b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp#L326) step is computationally lightweight compared to the subsequent [TMACComputeGemm](https://github.com/microsoft/onnxruntime/blob/38dfc91f38fe53da9eaf7e9fb9b158904eb3cd5b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp#L505) matrix multiplication. Serializing this setup step has negligible impact on overall inference latency (micro-benchmarks showed no measurable regression), but it effectively eliminates the race condition.

## Verification

* **Reproduction:** The issue was reliably reproduced by running `MatMulNBitsLutGemm.Float32_2Bits_Asymmetric_Batch32_256x256` in a loop (failing ~1 in 5 times).
* **Verification:** After applying the fix, the same test passed **50/50 iterations** consistently.
* **Regression Testing:** Standard `MatMulNBitsLutGemm` tests (including `BlkLen64` and `M=1` cases) continue to pass.
1 parent db383a9 commit 6861526

File tree

1 file changed: +24 −24 lines


onnxruntime/core/mlas/lib/qlutgemm.cpp

Lines changed: 24 additions & 24 deletions
```diff
@@ -548,32 +548,32 @@ MlasLutGemm(

     // const int num_groups = static_cast<int>(K / BlkLen);

-    // Parallelize over M (batch dimension)
-    // Each iteration processes one row of the activation matrix
+    // Iterate over M (batch dimension)
+    // Each iteration processes one row of the activation matrix.
+    // NOTE: This loop is intentionally serialized. Previous attempts to parallelize
+    // using MlasTrySimpleParallel caused flaky test failures (race conditions)
+    // when M > 1 (e.g., Batch32 case). Since GenerateLUT is lightweight,
+    // serial execution ensures correctness with negligible performance impact.
     // TODO(vraspar): Ideally we have to do block parallelism here

-    MlasTrySimpleParallel(
-        threadpool,
-        static_cast<size_t>(M),
-        [&](ptrdiff_t ine11) {
-            const size_t row_offset = static_cast<size_t>(ine11) * K;
-            const size_t lut_offset = static_cast<size_t>(ine11) * K * 4;  // 4 bytes per K element for 2-bit LUT
-            const size_t scale_bias_offset = static_cast<size_t>(ine11) * lut_scales_size;
-
-            // Call the dispatch function for this row
-            // ggml_tmac_mul_mat_task_init
-            Dispatch->GenerateLUT(
-                const_cast<float*>(a_float + row_offset),  // Input activation for this row
-                qlut + lut_offset,                         // Output LUT for this row
-                lut_scales + scale_bias_offset,            // Scales for this row
-                lut_biases + scale_bias_offset,            // Biases for this row
-                M,
-                K,
-                N,
-                tmac_params.act_group_size
-            );
-        }
-    );
+    for (size_t ine11 = 0; ine11 < static_cast<size_t>(M); ine11++) {
+        const size_t row_offset = ine11 * K;
+        const size_t lut_offset = ine11 * K * 4;  // 4 bytes per K element for 2-bit LUT
+        const size_t scale_bias_offset = ine11 * lut_scales_size;
+
+        // Call the dispatch function for this row
+        // ggml_tmac_mul_mat_task_init
+        Dispatch->GenerateLUT(
+            const_cast<float*>(a_float + row_offset),  // Input activation for this row
+            qlut + lut_offset,                         // Output LUT for this row
+            lut_scales + scale_bias_offset,            // Scales for this row
+            lut_biases + scale_bias_offset,            // Biases for this row
+            M,
+            K,
+            N,
+            tmac_params.act_group_size
+        );
+    }

     // all relevant LUT's have been generated
     // equivalent of lut_mul_mat's ggml_backend_tmac_mul_mat function ggml_barrier line
```
