Skip to content

Commit e7c9a6c

Browse files
jing-baoqjia7
andauthored
[webgpu] make DP4AMatMulNBitsSmallMProgram shader template (#25025)
### Description This commit refactors the `DP4AMatMulNBitsSmallMProgram` to allow both `tile_size_k_vec` and `tile_size` to be configured. This change allows more flexibility for performance tuning without altering the core shader functionality. There is no functional change in this commit. ### Motivation and Context This is a preparatory change to enable `DP4AMatMulNBitsSmallMProgram` performance optimization work in subsequent commits. --------- Co-authored-by: Jiajia Qin <jiajiaqin@microsoft.com>
1 parent 0784e0a commit e7c9a6c

File tree

2 files changed

+28
-22
lines changed

2 files changed

+28
-22
lines changed

onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -325,15 +325,19 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co
325325
shader.AddInput("input_b", ShaderUsage::UseUniform);
326326
shader.AddInput("scales_b", ShaderUsage::UseUniform);
327327
shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias);
328+
329+
ORT_ENFORCE(WorkgroupSizeX() % tile_size_k_vec_ == 0 && tile_size_k_vec_ % 4 == 0, "tile_size_k_vec_ must evenly divide workgroup size X and be divisible by 4");
330+
const uint32_t sub_tile_count = WorkgroupSizeX() / tile_size_k_vec_;
331+
ORT_ENFORCE(tile_size_ % sub_tile_count == 0, "tile_size_ must be divisible by sub_tile_count");
332+
328333
// This algorithm works to compute dot product of k parallelly, by processing k at each step amongst tile_size_k_vec threads,
329334
// and utilizing the remaining threads in the workgroup to process additional rows of b in parallel (such that the values in shared memory for A can be reused).
330335
// For each load of k, the tile_size_k_vec threads also reload B tile_size/num_concurrent_b_rows times to compute partial dot products of other B rows
331336
// in order to complete all tile_size b rows in this workgroup and also reusing the loaded in register values of a.
332337

333-
// 1. Each workgroup handles tile_size_k_vec (16) * k_vectorization_in_b (32) columns (total 512) and num_concurrent_b_rows of matrix B at a time,
338+
// 1. Each workgroup handles tile_size_k_vec * k_vectorization_in_b (32) columns and num_concurrent_b_rows of matrix B at a time,
334339
// iterating over the columns to compute a partial dot product.
335340
// 2. Uses vec4 vectorization where each K represents 32 elements of matrix B
336-
constexpr uint32_t tile_size_k_vec = 16;
337341

338342
// 1. Workgroup Responsibility:
339343
// - Processes one row of matrix A
@@ -346,18 +350,19 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co
346350
// - Iterates through columns accumulating results in inter_results
347351
// - Performs final reduction sum in inter_results for output
348352
shader.AdditionalImplementation() << " const tile_size = " << tile_size_ << "u;\n"
349-
<< " const tile_size_k_vec = " << tile_size_k_vec << "u;\n"
350-
<< " const double_tile_size_k_vec = " << 2 * tile_size_k_vec << "u;\n"
353+
<< " const tile_size_k_vec = " << tile_size_k_vec_ << "u;\n"
354+
<< " const double_tile_size_k_vec = " << 2 * tile_size_k_vec_ << "u;\n"
351355
// sub_tile_count is the number of concurrent b rows processed by the workgroup.
352-
<< " const sub_tile_count = " << WorkgroupSizeX() / tile_size_k_vec << "u;\n"
353-
<< " var<workgroup> inter_results: array<array<output_element_t, tile_size_k_vec>, tile_size>;\n";
356+
<< " const sub_tile_count = " << sub_tile_count << "u;\n";
354357

355358
shader.AdditionalImplementation() << CommonFunctions(nbits_)
356359
<< R"ADDNL_FN(
357-
// Need 2 * tile_size_k_vec (32) to store a tile_A since b is quantized as 4 bits and a is quantized as 8 bits.
358-
var<workgroup> tile_A : array<vec4<u32>, 32>;
359-
// Need 4 scales value since each tile_A includes 512 (4x4x32) scalars and the block_size is 128.
360-
var<workgroup> scale_A : array<output_element_t, 4>;
360+
var<workgroup> inter_results: array<array<output_element_t, tile_size_k_vec>, tile_size>;
361+
// Need 2 * tile_size_k_vec to store a tile_A since b is quantized as 4 bits and a is quantized as 8 bits.
362+
var<workgroup> tile_A : array<vec4<u32>, double_tile_size_k_vec>;
363+
// double_tile_size_k_vec * 16 / 128
364+
const scale_a_size_in_tile_a = double_tile_size_k_vec / 8;
365+
var<workgroup> scale_A : array<output_element_t, scale_a_size_in_tile_a>;
361366
fn loadSHMA(a_global: u32, kidx_v: u32, col: u32)
362367
{
363368
let k_offset = kidx_v + col;
@@ -366,7 +371,7 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co
366371
}
367372
368373
tile_A[col] = input_a[a_global*uniforms.K16+k_offset];
369-
if (col < 4)
374+
if (col < scale_a_size_in_tile_a)
370375
{
371376
// kidx_v - covers 16 values of k in input_a
372377
scale_A[col] = scales_a[a_global*(uniforms.K/128) + kidx_v/8 + col];
@@ -391,8 +396,6 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co
391396
var own_a: vec4<u32> = tile_A[local_col * 2];
392397
var own_a1: vec4<u32> = tile_A[local_col * 2 + 1];
393398
var own_scale_a = scale_A[local_col / 4];
394-
var own_b = vec4<u32>(0);
395-
var own_b1 = vec4<u32>(0);
396399
let k_offset = kidx_v + local_col;
397400
// calculate intermediate results into inter_results.
398401
for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_count) {
@@ -404,13 +407,13 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co
404407
if (nbits_ == 4) {
405408
shader.MainFunctionBody() << R"MAIN_FN(
406409
let b_value = input_b[b_offset];
407-
own_b = DequantizedFrom4BitsTo8Bits(b_value.xy);
408-
own_b1 = DequantizedFrom4BitsTo8Bits(b_value.zw);
410+
let own_b = DequantizedFrom4BitsTo8Bits(b_value.xy);
411+
let own_b1 = DequantizedFrom4BitsTo8Bits(b_value.zw);
409412
)MAIN_FN";
410413
} else {
411414
shader.MainFunctionBody() << R"MAIN_FN(
412-
own_b = AlignWithZeroPoint(input_b[b_offset * 2]);
413-
own_b1 = AlignWithZeroPoint(input_b[b_offset * 2 + 1]);
415+
let own_b = AlignWithZeroPoint(input_b[b_offset * 2]);
416+
let own_b1 = AlignWithZeroPoint(input_b[b_offset * 2 + 1]);
414417
)MAIN_FN";
415418
}
416419
shader.MainFunctionBody() << R"MAIN_FN(
@@ -466,9 +469,11 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
466469
ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program));
467470

468471
if (M < min_M_for_tile_optimization) {
469-
constexpr uint32_t kTileSize = 32;
470-
DP4AMatMulNBitsSmallMProgram mul_program{kTileSize, nbits};
471-
uint32_t num_N_tile = (N + kTileSize - 1) / kTileSize;
472+
uint32_t tile_size_k_vec = 16;
473+
uint32_t tile_size = 32;
474+
475+
DP4AMatMulNBitsSmallMProgram mul_program{tile_size_k_vec, tile_size, nbits};
476+
uint32_t num_N_tile = (N + tile_size - 1) / tile_size;
472477
mul_program.SetWorkgroupSize(128);
473478
mul_program.SetDispatchGroupSize(M * num_N_tile);
474479
mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(kVec4Components)},
@@ -477,7 +482,7 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
477482
{scales, ProgramTensorMetadataDependency::TypeAndRank, 1}})
478483
.AddUniformVariables({M, N, K, K / 16, K / 32, block_size, num_N_tile})
479484
.AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, 1})
480-
.CacheHint(nbits);
485+
.CacheHint(nbits, tile_size_k_vec, tile_size);
481486
return context.RunProgram(mul_program);
482487
}
483488

onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class DP4AMatMulNBitsProgram final : public Program<DP4AMatMulNBitsProgram> {
3737

3838
class DP4AMatMulNBitsSmallMProgram final : public Program<DP4AMatMulNBitsSmallMProgram> {
3939
public:
40-
DP4AMatMulNBitsSmallMProgram(uint32_t tile_size, uint32_t nbits) : Program{"DP4AMatMulNBitsSmallMProgram"}, tile_size_(tile_size), nbits_(nbits) {}
40+
DP4AMatMulNBitsSmallMProgram(uint32_t tile_size_k_vec, uint32_t tile_size, uint32_t nbits) : Program{"DP4AMatMulNBitsSmallMProgram"}, tile_size_k_vec_(tile_size_k_vec), tile_size_(tile_size), nbits_(nbits) {}
4141
Status GenerateShaderCode(ShaderHelper& sh) const override;
4242
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
4343
{"M", ProgramUniformVariableDataType::Uint32},
@@ -49,6 +49,7 @@ class DP4AMatMulNBitsSmallMProgram final : public Program<DP4AMatMulNBitsSmallMP
4949
{"num_N_tile", ProgramUniformVariableDataType::Uint32});
5050

5151
private:
52+
uint32_t tile_size_k_vec_;
5253
uint32_t tile_size_;
5354
uint32_t nbits_;
5455
};

0 commit comments

Comments
 (0)