Skip to content

Commit b3db7f5

Browse files
authored
s8s8_sym_quant gtest: enable reorder, post-op coverage and fix leak/ref bugs (#557)
- Added optional const GroupScaleParam* group_scale argument to IUal::reorder(), UalDlp::reorder() and UalRef::reorder() for the sym_quant APIs. The argument defaults to nullptr, keeping existing callers source-compatible. - UalDlp::reorder() now selects the sym_quant reorder APIs aocl_get_reorder_buf_size_s8s8s32os32_sym_quant() and aocl_reorder_s8s8s32os32_sym_quant() when given s8 input, s32 accumulation, f32/bf16 output and a non-null group_scale, normalizing group_size == 0 to the full K dimension since the sym_quant APIs require a strictly positive group size. - Fixed a leak of the A pack buffer in the GEMV (m=1) path of s8s8_sym_quant kernel. - Hardened group pre-op validation in dlp_gemm_translate_to_group_postops_list() to also reject scale-factor and zero-point arrays shorter than the required length, i.e., m*(ceil(k/group_size)) for A matrix and n*(ceil(k/group_size)) for B matrix. - Fixed the reference skipping group-scale de-quantization under post-ops in RefUalPlan::execute() since it was taking the integer GEMM's needsF32Intermediate path. isS8S8GroupScale is now computed earlier and excluded from needsF32Intermediate, letting these cases fall through to the sym_quant reference which de-quantizes and then applies post-ops via applyPostOps(). AMD-Internal: [CPUPL-8537] Signed-off-by: Arnav Sharma <Arnav.Sharma@amd.com>
1 parent 455d57e commit b3db7f5

11 files changed

Lines changed: 113 additions & 52 deletions

File tree

bench/bench_gemm.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,8 @@ class OptimizedGemmBenchmark : public ConcreteUAL
150150
// Apply memory tag for B (reorder and pack are mutually exclusive)
151151
if (config.reorderB) {
152152
Matrix B_reordered;
153-
this->reorder(B_, B_reordered, a_type_, b_type_, c_type_,
154-
acc_type_);
153+
this->reorder(B_, B_reordered, a_type_, b_type_, c_type_, acc_type_,
154+
config.group_scale_param.get());
155155
B_ = std::move(B_reordered);
156156
// Reorder handles transposition; reset trans flag for GEMM call
157157
transB_ = false;

classic/frame/dlp_gemm_post_ops.c

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,6 @@ dlp_gemm_translate_to_group_postops_list(dlp_group_post_op* metadata,
137137
md_t n,
138138
md_t k)
139139
{
140-
(void)m;
141-
(void)n;
142140
if ((metadata == NULL) || (metadata->seq_length <= 0)) {
143141
dlp_gemm_set_group_post_ops_node_params(post_op_list, 0, NULL, NULL, 0,
144142
0, NULL, NULL, 0, 0,
@@ -166,25 +164,41 @@ dlp_gemm_translate_to_group_postops_list(dlp_group_post_op* metadata,
166164
if (((metadata->a_zp)->zero_point_len > 0)
167165
&& ((metadata->a_zp)->zero_point == NULL))
168166
return DLP_CLSC_NULL_POINTER;
167+
168+
if ((metadata->a_zp)->zero_point_len
169+
< (m * ((k + group_size - 1) / group_size)))
170+
return DLP_CLSC_INVALID_ZP_LEN;
169171
}
170172

171173
if (metadata->a_scl != NULL) {
172174
if (((metadata->a_scl)->scale_factor_len > 0)
173175
&& ((metadata->a_scl)->scale_factor == NULL))
174176
return DLP_CLSC_NULL_POINTER;
177+
178+
if ((metadata->a_scl)->scale_factor_len
179+
< (m * ((k + group_size - 1) / group_size)))
180+
return DLP_CLSC_INVALID_SF_LEN;
175181
}
176182

177183
if (metadata->b_zp != NULL) {
178184
/* check for validity of pre-ops */
179185
if (((metadata->b_zp)->zero_point_len > 0)
180186
&& ((metadata->b_zp)->zero_point == NULL))
181187
return DLP_CLSC_NULL_POINTER;
188+
189+
if ((metadata->b_zp)->zero_point_len
190+
< (n * ((k + group_size - 1) / group_size)))
191+
return DLP_CLSC_INVALID_ZP_LEN;
182192
}
183193

184194
if (metadata->b_scl != NULL) {
185195
if (((metadata->b_scl)->scale_factor_len > 0)
186196
&& ((metadata->b_scl)->scale_factor == NULL))
187197
return DLP_CLSC_NULL_POINTER;
198+
199+
if ((metadata->b_scl)->scale_factor_len
200+
< (n * ((k + group_size - 1) / group_size)))
201+
return DLP_CLSC_INVALID_SF_LEN;
188202
}
189203

190204
if ((metadata->a_scl != NULL) && (metadata->b_scl != NULL)

classic/frame/s8s8s32/dlp_gemm_s8s8s32_sym_quant.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ DLP_GEMV2(int8_t, int8_t, int32_t, s8s8s32o32_sym_quant)
266266
} // jc loop
267267

268268
// Release pack buffers.
269-
if (mtag_b == PACK && (pack_a_buffer_s8s8s32os32 != NULL)) {
269+
if ((mtag_a == PACK) && (pack_a_buffer_s8s8s32os32 != NULL)) {
270270
dlp_free_page_aligned(pack_a_buffer_s8s8s32os32);
271271
}
272272
}

include/classic/dlp_errors.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,11 @@ typedef enum
5252
*/
5353
DLP_CLSC_INVALID_MATRIX_TYPE, /**< Invalid matrix type specified */
5454
DLP_CLSC_INVALID_GROUP_DIMENSION, /**< Invalid group dimension specified */
55-
DLP_CLSC_TYPE_MISMATCH, /**< Data type mismatch encountered */
56-
DLP_CLSC_INVALID_JIT_KERNEL, /**< JIT kernel generation failed or no
57-
fallback kernel available */
55+
DLP_CLSC_INVALID_SF_LEN, /**< Invalid scale factor length specified */
56+
DLP_CLSC_INVALID_ZP_LEN, /**< Invalid zero point length specified */
57+
DLP_CLSC_TYPE_MISMATCH, /**< Data type mismatch encountered */
58+
DLP_CLSC_INVALID_JIT_KERNEL, /**< JIT kernel generation failed or no
59+
fallback kernel available */
5860
DLP_CLSC_INVALID_KERNEL, /**< Static kernel not found for given parameters
5961
*/
6062
DLP_CLSC_ERROR_MAX /**< Maximum error code value (for bounds checking) */

tests/adaptors/dlp/ual_dlp.cc

Lines changed: 50 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -116,12 +116,13 @@ UalDlp::toString(UALType type)
116116
* @return UALError Error code indicating success or failure
117117
*/
118118
UALError
119-
UalDlp::reorder(const Matrix& in,
120-
Matrix& out,
121-
MatrixType A_type,
122-
MatrixType B_type,
123-
MatrixType C_type,
124-
MatrixType accType)
119+
UalDlp::reorder(const Matrix& in,
120+
Matrix& out,
121+
MatrixType A_type,
122+
MatrixType B_type,
123+
MatrixType C_type,
124+
MatrixType accType,
125+
const GroupScaleParam* group_scale)
125126
{
126127
dlp_metadata_t meta;
127128
meta.error_hndl.error_code = DLP_CLSC_SUCCESS;
@@ -130,6 +131,13 @@ UalDlp::reorder(const Matrix& in,
130131
md_t effective_rows = in.getEffectiveRows();
131132
md_t effective_cols = in.getEffectiveCols();
132133

134+
// Detect symmetric-quantization reorder path:
135+
// s8 input, s32 accumulation, f32 or bf16 output, and group_scale provided.
136+
const bool sym_quant =
137+
(group_scale != nullptr) && (in.getMatrixType() == MatrixType::s8)
138+
&& (accType == MatrixType::s32)
139+
&& (C_type == MatrixType::f32 || C_type == MatrixType::bf16);
140+
133141
// Determine appropriate reorder function based on input type and GEMM
134142
// context The A, B, C types provide context for optimal reordering strategy
135143
msz_t alloc_bytes = 0;
@@ -166,14 +174,20 @@ UalDlp::reorder(const Matrix& in,
166174
effective_cols, &meta);
167175
}
168176
} else if (in.getMatrixType() == MatrixType::s8) {
169-
// For s8, consider the accumulation type and output type
170-
if (accType == MatrixType::s32) {
171-
alloc_bytes = aocl_get_reorder_buf_size_s8s8s32os32(
177+
// For s8, select sym_quant or standard reorder based on GEMM context
178+
if (sym_quant) {
179+
// group_size=0 means "full K dimension"; normalize before calling
180+
// the AOCL sym_quant APIs which require a strictly positive value.
181+
md_t gs = group_scale->getGroupSize();
182+
if (gs == 0) {
183+
gs = effective_rows; // effective_rows == K for B matrix
184+
}
185+
DLP_SYMM_STAT_QUANT symq = { gs };
186+
alloc_bytes = aocl_get_reorder_buf_size_s8s8s32os32_sym_quant(
172187
in.getLayout() == MatrixLayout::ROW_MAJOR ? 'r' : 'c',
173188
in.isTransposed() ? 't' : 'n', 'B', effective_rows,
174-
effective_cols, &meta);
189+
effective_cols, &symq, &meta);
175190
} else {
176-
// Handle other accumulation types - for now, fall back to standard
177191
alloc_bytes = aocl_get_reorder_buf_size_s8s8s32os32(
178192
in.getLayout() == MatrixLayout::ROW_MAJOR ? 'r' : 'c',
179193
in.isTransposed() ? 't' : 'n', 'B', effective_rows,
@@ -274,13 +288,31 @@ UalDlp::reorder(const Matrix& in,
274288
&meta);
275289
break;
276290
case MatrixType::s8:
277-
aocl_reorder_s8s8s32os32(
278-
layout, in.isTransposed() ? 't' : 'n', 'B',
279-
reinterpret_cast<const int8_t*>(
280-
in.getMatrixData().getMatrixPtr()),
281-
reinterpret_cast<int8_t*>(out.getMatrixData().getMatrixPtr()),
282-
effective_rows, effective_cols, in.getLeadingDimension(),
283-
&meta);
291+
if (sym_quant) {
292+
// group_size=0 means "full K"; normalize to avoid div-by-zero.
293+
md_t gs = group_scale->getGroupSize();
294+
if (gs == 0) {
295+
gs = effective_rows;
296+
}
297+
DLP_SYMM_STAT_QUANT symq = { gs };
298+
aocl_reorder_s8s8s32os32_sym_quant(
299+
layout, in.isTransposed() ? 't' : 'n', 'B',
300+
reinterpret_cast<const int8_t*>(
301+
in.getMatrixData().getMatrixPtr()),
302+
reinterpret_cast<int8_t*>(
303+
out.getMatrixData().getMatrixPtr()),
304+
effective_rows, effective_cols, in.getLeadingDimension(),
305+
&symq, &meta);
306+
} else {
307+
aocl_reorder_s8s8s32os32(
308+
layout, in.isTransposed() ? 't' : 'n', 'B',
309+
reinterpret_cast<const int8_t*>(
310+
in.getMatrixData().getMatrixPtr()),
311+
reinterpret_cast<int8_t*>(
312+
out.getMatrixData().getMatrixPtr()),
313+
effective_rows, effective_cols, in.getLeadingDimension(),
314+
&meta);
315+
}
284316
break;
285317
case MatrixType::fp16:
286318
if (A_type == MatrixType::f32 && C_type == MatrixType::f32

tests/adaptors/ref/ual_plan_ref.cc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -261,8 +261,12 @@ RefUalPlan::execute()
261261
|| (aType == MatrixType::s8 && bType == MatrixType::s8));
262262
bool isBf16Gemm = (aType == MatrixType::bf16 && bType == MatrixType::bf16);
263263
bool isF32Gemm = (aType == MatrixType::f32 && bType == MatrixType::f32);
264-
bool needsF32Intermediate =
265-
(isIntegerGemm || isBf16Gemm || isF32Gemm) && hasPostOps;
264+
265+
bool isS8S8GroupScale =
266+
(aType == MatrixType::s8 && bType == MatrixType::s8 && m_group_scale);
267+
268+
bool needsF32Intermediate = (isIntegerGemm || isBf16Gemm || isF32Gemm)
269+
&& hasPostOps && !isS8S8GroupScale;
266270

267271
if (needsF32Intermediate && !ualRef.checkValidGemmParams(A, B, C, true)) {
268272
needsF32Intermediate = false;
@@ -301,9 +305,6 @@ RefUalPlan::execute()
301305
// Uses specialized ref that handles per-group scale application during
302306
// K-panel accumulation, which is required for correct results when
303307
// group_size > 0.
304-
bool isS8S8GroupScale =
305-
(aType == MatrixType::s8 && bType == MatrixType::s8 && m_group_scale);
306-
307308
if (isS8S8GroupScale) {
308309
md_t gs = m_group_scale->getGroupSize();
309310

tests/adaptors/ref/ual_ref.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,8 @@ UalRef::reorder(const Matrix& in,
243243
MatrixType A_type,
244244
MatrixType B_type,
245245
MatrixType C_type,
246-
MatrixType accType)
246+
MatrixType accType,
247+
const GroupScaleParam* /*group_scale*/)
247248
{
248249
/*
249250
Reordering operation in reference is

tests/classic/test_gemm.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,7 +1092,7 @@ class GemmParameterizedTest : public ::testing::TestWithParam<GemmTestConfig>
10921092

10931093
dlp_reorder_status = ual_test_->reorder(
10941094
B, B_reordered, config_.a_type, config_.b_type, config_.c_type,
1095-
config_.acc_type);
1095+
config_.acc_type, config_.group_scale_param.get());
10961096

10971097
// Skip test if DLP reorder is not supported
10981098
if (dlp_reorder_status == UALError::UAL_NOT_SUPPORTED) {
@@ -1112,7 +1112,8 @@ class GemmParameterizedTest : public ::testing::TestWithParam<GemmTestConfig>
11121112
if (params_valid) {
11131113
ref_reorder_status = ual_ref_->reorder(
11141114
B_ref, B_ref_reordered, config_.a_type, config_.b_type,
1115-
config_.c_type, config_.acc_type);
1115+
config_.c_type, config_.acc_type,
1116+
config_.group_scale_param.get());
11161117

11171118
if (ref_reorder_status == UALError::UAL_SUCCESS) {
11181119
// For bf16×s4 and bf16×u4, the reference uses row-major B

tests/include/adaptors/dlp/ual_dlp.hh

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
namespace dlp::testing::classic {
3535

3636
using dlp::testing::framework::BatchGroup;
37+
using dlp::testing::framework::GroupScaleParam;
3738
using dlp::testing::framework::IUal;
3839
using dlp::testing::framework::Matrix;
3940
using dlp::testing::framework::MatrixLayout;
@@ -112,14 +113,17 @@ class UalDlp : public IUal
112113
* @param B_type Type of matrix B in GEMM context
113114
* @param C_type Type of matrix C in GEMM context
114115
* @param accType Accumulation type
116+
* @param group_scale Optional symmetric-quantization group-scale
117+
* parameters; when non-null selects the sym_quant reorder path
115118
* @return UALError Error code indicating success or failure
116119
*/
117-
UALError reorder(const Matrix& in,
118-
Matrix& out,
119-
MatrixType A_type,
120-
MatrixType B_type,
121-
MatrixType C_type,
122-
MatrixType accType) override;
120+
UALError reorder(const Matrix& in,
121+
Matrix& out,
122+
MatrixType A_type,
123+
MatrixType B_type,
124+
MatrixType C_type,
125+
MatrixType accType,
126+
const GroupScaleParam* group_scale = nullptr) override;
123127

124128
UALError batch_gemm(std::vector<BatchGroup>& groups,
125129
MatrixType accType) override;

tests/include/adaptors/ref/ual_ref.hh

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
namespace dlp::testing::classic {
3636

3737
using dlp::testing::framework::BatchGroup;
38+
using dlp::testing::framework::GroupScaleParam;
3839
using dlp::testing::framework::IUal;
3940
using dlp::testing::framework::Matrix;
4041
using dlp::testing::framework::MatrixLayout;
@@ -114,14 +115,16 @@ class UalRef : public IUal
114115
* @param B_type Type of matrix B in GEMM context
115116
* @param C_type Type of matrix C in GEMM context
116117
* @param accType Accumulation type
118+
* @param group_scale Optional symmetric-quantization group-scale parameters
117119
* @return UALError Error code indicating success or failure
118120
*/
119-
UALError reorder(const Matrix& in,
120-
Matrix& out,
121-
MatrixType A_type,
122-
MatrixType B_type,
123-
MatrixType C_type,
124-
MatrixType accType) override;
121+
UALError reorder(const Matrix& in,
122+
Matrix& out,
123+
MatrixType A_type,
124+
MatrixType B_type,
125+
MatrixType C_type,
126+
MatrixType accType,
127+
const GroupScaleParam* group_scale = nullptr) override;
125128

126129
UALError batch_gemm(std::vector<BatchGroup>& groups,
127130
MatrixType accType) override;

0 commit comments

Comments
 (0)