Skip to content

Commit 526e527

Browse files
ORT 1.24.0 release cherry pick round 4 (#27202)
| Commit | Commit Title | Author |
| :--- | :--- | :--- |
| `6861526` | [MLAS] Fix Data Race in MlasLutGemm by Serializing LUT Generation (#27179) | tianleiwu |
| `592bcb4` | remove coloredlogs (#27135) | tianleiwu |
| `0f153de` | Add API GetTensorElementTypeAndShapeDataReference (#27175) | adrianlizarraga |
| `1caa3e6` | [MLAS] Fix Flaky LuT GEMM Tests by Replacing Gather with Shuffle (#27174) | tianleiwu |

---------

Co-authored-by: Adrian Lizarraga <adlizarraga@microsoft.com>
1 parent 50d4c84 commit 526e527

38 files changed (+313, −98 lines changed)

dockerfiles/Dockerfile.source

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@ RUN cd /code && /bin/bash ./build.sh --allow_running_as_root --skip_submodule_sy
1616
FROM mcr.microsoft.com/azurelinux/base/python:3
1717
COPY --from=0 /code/build/Linux/Release/dist /root
1818
COPY --from=0 /code/dockerfiles/LICENSE-IMAGE.txt /code/LICENSE-IMAGE.txt
19-
RUN tdnf install -y ca-certificates python3-setuptools python3-wheel python3-pip python3-numpy python3-flatbuffers python3-packaging python3-protobuf python3-mpmath python3-sympy && python3 -m pip install coloredlogs humanfriendly && python3 -m pip install --no-index --find-links /root onnxruntime && rm -rf /root/*.whl
19+
RUN tdnf install -y ca-certificates python3-setuptools python3-wheel python3-pip python3-numpy python3-flatbuffers python3-packaging python3-protobuf python3-mpmath python3-sympy && python3 -m pip install humanfriendly && python3 -m pip install --no-index --find-links /root onnxruntime && rm -rf /root/*.whl

docs/python/requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ furo
1111
pyquickhelper
1212
pandas
1313
pydot
14-
coloredlogs
1514
flatbuffers
1615
numpy<2.0.0
1716
packaging

include/onnxruntime/core/session/onnxruntime_c_api.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7195,6 +7195,31 @@ struct OrtApi {
71957195
* \since 1.24
71967196
*/
71977197
ORT_API_T(void, RunOptionsSetSyncStream, _Inout_ OrtRunOptions* options, _In_ OrtSyncStream* sync_stream);
7198+
7199+
/** \brief Get the element data type and shape for an OrtValue that represents a Tensor (scalar, dense, or sparse).
7200+
*
7201+
* \note This function is an alternative to ::GetTensorTypeAndShape() that does not allocate a new array for
7202+
* the shape data. The OrtValue instance's internal shape data is returned directly.
7203+
*
7204+
* \note Returns an error if the underlying OrtValue is not a Tensor.
7205+
*
7206+
* \param[in] value The OrtValue instance.
7207+
* \param[out] elem_type Output parameter set to the tensor element data type.
7208+
* \param[out] shape_data Output parameter set to the OrtValue instance's internal shape data array.
7209+
* For a scalar, `shape_data` is NULL and `shape_data_count` is 0.
7210+
* Must not be released as it is owned by the OrtValue instance. This pointer becomes invalid
7211+
* when the OrtValue is released or if the underlying shape data is updated or reallocated.
7212+
* \param[out] shape_data_count Output parameter set to the number of elements in `shape_data`.
7213+
* `shape_data_count` is 0 for a scalar.
7214+
*
7215+
* \snippet{doc} snippets.dox OrtStatus Return Value
7216+
*
7217+
* \since Version 1.24.
7218+
*/
7219+
ORT_API2_STATUS(GetTensorElementTypeAndShapeDataReference, _In_ const OrtValue* value,
7220+
_Out_ ONNXTensorElementDataType* elem_type,
7221+
_Outptr_result_maybenull_ const int64_t** shape_data,
7222+
_Out_ size_t* shape_data_count);
71987223
};
71997224

72007225
/*

include/onnxruntime/core/session/onnxruntime_cxx_api.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2220,6 +2220,19 @@ struct ConstValueImpl : Base<T> {
22202220
const R* GetSparseTensorValues() const;
22212221

22222222
#endif
2223+
2224+
/// <summary>
2225+
/// Returns the tensor's element type and a reference to the tensor's internal shape data. The shape data is owned
2226+
/// by the Ort::Value and becomes invalid when the Ort::Value is destroyed or if the underlying shape data is
2227+
/// updated or reallocated.
2228+
///
2229+
/// For a scalar, shape.shape is nullptr and shape.shape_len is 0.
2230+
///
2231+
/// Wraps OrtApi::GetTensorElementTypeAndShapeDataReference.
2232+
/// </summary>
2233+
/// <param name="elem_type">Output parameter set to the element's data type.</param>
2234+
/// <param name="shape">Output parameter set to the OrtValue instance's shape data and number of elements.</param>
2235+
void GetTensorElementTypeAndShapeDataReference(ONNXTensorElementDataType& elem_type, Shape& shape) const;
22232236
};
22242237

22252238
template <typename T>

include/onnxruntime/core/session/onnxruntime_cxx_inline.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2377,6 +2377,13 @@ inline const R* ConstValueImpl<T>::GetSparseTensorValues() const {
23772377

23782378
#endif
23792379

2380+
template <typename T>
2381+
void ConstValueImpl<T>::GetTensorElementTypeAndShapeDataReference(ONNXTensorElementDataType& elem_type,
2382+
Shape& shape) const {
2383+
ThrowOnError(GetApi().GetTensorElementTypeAndShapeDataReference(this->p_, &elem_type, &shape.shape,
2384+
&shape.shape_len));
2385+
}
2386+
23802387
template <typename T>
23812388
void ValueImpl<T>::FillStringTensor(const char* const* s, size_t s_len) {
23822389
ThrowOnError(GetApi().FillStringTensor(this->p_, s, s_len));

onnxruntime/core/framework/tensor_type_and_shape.cc

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,64 @@ std::unique_ptr<OrtTensorTypeAndShapeInfo> OrtTensorTypeAndShapeInfo::GetTensorS
310310
return GetTensorShapeAndTypeHelper(type, shape, dim_params);
311311
}
312312

313+
ORT_API_STATUS_IMPL(OrtApis::GetTensorElementTypeAndShapeDataReference, _In_ const OrtValue* value,
314+
_Out_ ONNXTensorElementDataType* elem_type,
315+
_Outptr_result_maybenull_ const int64_t** shape_data,
316+
_Out_ size_t* shape_data_count) {
317+
API_IMPL_BEGIN
318+
if (!value->IsAllocated() || (!value->IsTensor() && !value->IsSparseTensor())) {
319+
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
320+
"Input parameter `value` must contain a constructed tensor or sparse tensor");
321+
}
322+
323+
if (elem_type == nullptr) {
324+
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
325+
"Output parameter `elem_type` must not be NULL");
326+
}
327+
328+
if (shape_data == nullptr) {
329+
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
330+
"Output parameter `shape_data` must not be NULL");
331+
}
332+
333+
if (shape_data_count == nullptr) {
334+
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
335+
"Output parameter `shape_data_count` must not be NULL");
336+
}
337+
338+
gsl::span<const int64_t> shape_span;
339+
onnxruntime::MLDataType ml_data_type = nullptr;
340+
ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
341+
342+
if (value->IsTensor()) {
343+
const Tensor& tensor = value->Get<onnxruntime::Tensor>();
344+
ml_data_type = tensor.DataType();
345+
shape_span = tensor.Shape().GetDims();
346+
} else {
347+
#if !defined(DISABLE_SPARSE_TENSORS)
348+
const SparseTensor& tensor = value->Get<onnxruntime::SparseTensor>();
349+
ml_data_type = tensor.DataType();
350+
shape_span = tensor.DenseShape().GetDims();
351+
#else
352+
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "SparseTensor is not supported in this build.");
353+
#endif
354+
}
355+
356+
if (ml_data_type != nullptr) {
357+
type = MLDataTypeToOnnxRuntimeTensorElementDataType(ml_data_type);
358+
}
359+
360+
if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) {
361+
return OrtApis::CreateStatus(ORT_FAIL, "Tensor does not have a valid or supported tensor element data type");
362+
}
363+
364+
*elem_type = type;
365+
*shape_data = shape_span.empty() ? nullptr : shape_span.data();
366+
*shape_data_count = shape_span.size();
367+
return nullptr;
368+
API_IMPL_END
369+
}
370+
313371
ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape,
314372
_In_ const OrtValue* v, _Outptr_ OrtTensorTypeAndShapeInfo** out) {
315373
API_IMPL_BEGIN

onnxruntime/core/mlas/lib/qlutgemm.cpp

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -548,32 +548,32 @@ MlasLutGemm(
548548

549549
// const int num_groups = static_cast<int>(K / BlkLen);
550550

551-
// Parallelize over M (batch dimension)
552-
// Each iteration processes one row of the activation matrix
551+
// Iterate over M (batch dimension)
552+
// Each iteration processes one row of the activation matrix.
553+
// NOTE: This loop is intentionally serialized. Previous attempts to parallelize
554+
// using MlasTrySimpleParallel caused flaky test failures (race conditions)
555+
// when M > 1 (e.g., Batch32 case). Since GenerateLUT is lightweight,
556+
// serial execution ensures correctness with negligible performance impact.
553557
// TODO(vraspar): Ideally we have to do block parallelism here
554558

555-
MlasTrySimpleParallel(
556-
threadpool,
557-
static_cast<size_t>(M),
558-
[&](ptrdiff_t ine11) {
559-
const size_t row_offset = static_cast<size_t>(ine11) * K;
560-
const size_t lut_offset = static_cast<size_t>(ine11) * K * 4; // 4 bytes per K element for 2-bit LUT
561-
const size_t scale_bias_offset = static_cast<size_t>(ine11) * lut_scales_size;
562-
563-
// Call the dispatch function for this row
564-
// ggml_tmac_mul_mat_task_init
565-
Dispatch->GenerateLUT(
566-
const_cast<float*>(a_float + row_offset), // Input activation for this row
567-
qlut + lut_offset, // Output LUT for this row
568-
lut_scales + scale_bias_offset, // Scales for this row
569-
lut_biases + scale_bias_offset, // Biases for this row
570-
M,
571-
K,
572-
N,
573-
tmac_params.act_group_size
574-
);
575-
}
576-
);
559+
for (size_t ine11 = 0; ine11 < static_cast<size_t>(M); ine11++) {
560+
const size_t row_offset = ine11 * K;
561+
const size_t lut_offset = ine11 * K * 4; // 4 bytes per K element for 2-bit LUT
562+
const size_t scale_bias_offset = ine11 * lut_scales_size;
563+
564+
// Call the dispatch function for this row
565+
// ggml_tmac_mul_mat_task_init
566+
Dispatch->GenerateLUT(
567+
const_cast<float*>(a_float + row_offset), // Input activation for this row
568+
qlut + lut_offset, // Output LUT for this row
569+
lut_scales + scale_bias_offset, // Scales for this row
570+
lut_biases + scale_bias_offset, // Biases for this row
571+
M,
572+
K,
573+
N,
574+
tmac_params.act_group_size
575+
);
576+
}
577577

578578
// all relevant LUT's have been generated
579579
// equivalent of lut_mul_mat's ggml_backend_tmac_mul_mat function ggml_barrier line

onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -187,21 +187,53 @@ get_bias_scale()
187187
return 3;
188188
}
189189

190+
static inline void
191+
MlasAvx2LoaduDeinterleave32Ps(const float* src, __m256& v0, __m256& v1, __m256& v2, __m256& v3)
192+
{
193+
// Process 32 activations contiguously using loadu + shuffle.
194+
// This allows us to mix neighbors (src[4i], src[4i+1], src[4i+2], src[4i+3]) across lanes,
195+
// which matches the T-MAC weight packing.
196+
// We use loadu + shuffle instead of gather to avoid potential issues with gather
197+
// on some hardware and ensure deterministic behavior.
198+
__m256 vec_b0 = _mm256_loadu_ps(src + 0);
199+
__m256 vec_b1 = _mm256_loadu_ps(src + 8);
200+
__m256 vec_b2 = _mm256_loadu_ps(src + 16);
201+
__m256 vec_b3 = _mm256_loadu_ps(src + 24);
202+
203+
__m256 t0 = _mm256_unpacklo_ps(vec_b0, vec_b1);
204+
__m256 t1 = _mm256_unpackhi_ps(vec_b0, vec_b1);
205+
__m256 t2 = _mm256_unpacklo_ps(vec_b2, vec_b3);
206+
__m256 t3 = _mm256_unpackhi_ps(vec_b2, vec_b3);
207+
208+
__m256 u0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t2)));
209+
__m256 u1 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t2)));
210+
__m256 u2 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t1), _mm256_castps_pd(t3)));
211+
__m256 u3 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t1), _mm256_castps_pd(t3)));
212+
213+
const __m256i perm_idx = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
214+
v0 = _mm256_permutevar8x32_ps(u0, perm_idx);
215+
v1 = _mm256_permutevar8x32_ps(u1, perm_idx);
216+
v2 = _mm256_permutevar8x32_ps(u2, perm_idx);
217+
v3 = _mm256_permutevar8x32_ps(u3, perm_idx);
218+
}
219+
190220
void
191221
partial_max_g4_int8_k8(float* lut_scales, const float* b)
192222
{
193-
// TODO(vraspar): add support for arm neon
194-
const __m256i vec_bi = _mm256_set_epi32(112, 96, 80, 64, 48, 32, 16, 0);
195-
__m256 vec_b0 = _mm256_i32gather_ps(b + 0, vec_bi, 1);
196-
__m256 vec_b1 = _mm256_i32gather_ps(b + 1, vec_bi, 1);
197-
__m256 vec_b2 = _mm256_i32gather_ps(b + 2, vec_bi, 1);
198-
__m256 vec_b3 = _mm256_i32gather_ps(b + 3, vec_bi, 1);
223+
__m256 vec_b0, vec_b1, vec_b2, vec_b3;
224+
MlasAvx2LoaduDeinterleave32Ps(b, vec_b0, vec_b1, vec_b2, vec_b3);
225+
199226
const __m256 vec_sign = _mm256_set1_ps(-0.0f);
200227
__m256 vec_babs0 = _mm256_andnot_ps(vec_sign, vec_b0);
201228
__m256 vec_babs1 = _mm256_andnot_ps(vec_sign, vec_b1);
202229
__m256 vec_babs2 = _mm256_andnot_ps(vec_sign, vec_b2);
203230
__m256 vec_babs3 = _mm256_andnot_ps(vec_sign, vec_b3);
231+
232+
// The upper bound for the LUT values (mixtures of 4 activations) is the sum
233+
// of their absolute values.
204234
__m256 abssum = _mm256_add_ps(_mm256_add_ps(vec_babs0, vec_babs1), _mm256_add_ps(vec_babs2, vec_babs3));
235+
236+
// Reduce max across lanes to find the global maximum sum in this chunk.
205237
__m128 max4 = _mm_max_ps(_mm256_extractf128_ps(abssum, 1), _mm256_castps256_ps128(abssum));
206238
max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
207239
max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
@@ -222,16 +254,14 @@ lut_ctor_g4_int8_impl(
222254
)
223255
{
224256
__m256 vec_lut[16];
225-
float biases = 0.0;
226-
const __m256i vec_bi = _mm256_set_epi32(112, 96, 80, 64, 48, 32, 16, 0);
257+
float biases = 0.0f;
227258
float scales = *lut_scales;
228259
float t_scales = scales ? 1.0f / scales : 0.0f;
229260

230261
for (int k = 0; k < act_k / 32; ++k) {
231-
__m256 vec_b0 = _mm256_i32gather_ps(b + k * 32 + 0, vec_bi, 1);
232-
__m256 vec_b1 = _mm256_i32gather_ps(b + k * 32 + 1, vec_bi, 1);
233-
__m256 vec_b2 = _mm256_i32gather_ps(b + k * 32 + 2, vec_bi, 1);
234-
__m256 vec_b3 = _mm256_i32gather_ps(b + k * 32 + 3, vec_bi, 1);
262+
const float* b_chunk = b + k * 32;
263+
__m256 vec_b0, vec_b1, vec_b2, vec_b3;
264+
MlasAvx2LoaduDeinterleave32Ps(b_chunk, vec_b0, vec_b1, vec_b2, vec_b3);
235265

236266
PRAGMA_UNROLL
237267
for (int g = 1; g < 16; g += 2) {

onnxruntime/core/session/onnxruntime_c_api.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4802,6 +4802,8 @@ static constexpr OrtApi ort_api_1_to_24 = {
48024802
&OrtApis::EpAssignedNode_GetDomain,
48034803
&OrtApis::EpAssignedNode_GetOperatorType,
48044804
&OrtApis::RunOptionsSetSyncStream,
4805+
&OrtApis::GetTensorElementTypeAndShapeDataReference,
4806+
// End of Version 24 - DO NOT MODIFY ABOVE (see above text for more information)
48054807
};
48064808

48074809
// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
@@ -4838,6 +4840,7 @@ static_assert(offsetof(OrtApi, SetEpDynamicOptions) / sizeof(void*) == 284, "Siz
48384840

48394841
static_assert(offsetof(OrtApi, GetEpApi) / sizeof(void*) == 317, "Size of version 22 API cannot change");
48404842
static_assert(offsetof(OrtApi, CreateExternalInitializerInfo) / sizeof(void*) == 389, "Size of version 23 API cannot change");
4843+
static_assert(offsetof(OrtApi, GetTensorElementTypeAndShapeDataReference) / sizeof(void*) == 414, "Size of version 24 API cannot change");
48414844

48424845
// So that nobody forgets to finish an API version, this check will serve as a reminder:
48434846
static_assert(std::string_view(ORT_VERSION) == "1.24.0",

onnxruntime/core/session/ort_apis.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -808,4 +808,9 @@ ORT_API_STATUS_IMPL(EpAssignedSubgraph_GetNodes, _In_ const OrtEpAssignedSubgrap
808808
ORT_API_STATUS_IMPL(EpAssignedNode_GetName, _In_ const OrtEpAssignedNode* ep_node, _Outptr_ const char** out);
809809
ORT_API_STATUS_IMPL(EpAssignedNode_GetDomain, _In_ const OrtEpAssignedNode* ep_node, _Outptr_ const char** out);
810810
ORT_API_STATUS_IMPL(EpAssignedNode_GetOperatorType, _In_ const OrtEpAssignedNode* ep_node, _Outptr_ const char** out);
811+
812+
ORT_API_STATUS_IMPL(GetTensorElementTypeAndShapeDataReference, _In_ const OrtValue* value,
813+
_Out_ ONNXTensorElementDataType* elem_type,
814+
_Outptr_result_maybenull_ const int64_t** shape_data,
815+
_Out_ size_t* shape_data_count);
811816
} // namespace OrtApis

0 commit comments

Comments (0)