Commit 34a0b15

[webgpu] Support any batch size for dp4a matmul path (#26884)
This pull request adds support for batched matrix multiplication in the DP4A quantized matmul WebGPU kernels and their associated C++ code and tests. The changes update the kernel code, tensor shapes, dispatch logic, and test infrastructure to properly handle a `batch_count` greater than 1, enabling efficient batched execution.
1 parent 5bc10a3 commit 34a0b15
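
The core of the change is a flat-index decomposition: the 1-D dispatch size grows by a factor of `batch_count`, and each workgroup recovers its (batch, M-tile, N-tile) coordinates from `workgroup_idx`. A minimal host-side sketch of that decomposition (illustrative names, not code from this commit):

```cpp
#include <cassert>
#include <cstdint>

// Recover (batch, m_tile, n_tile) from a flat workgroup index, mirroring the
// index math the updated WGSL kernels use below.
struct TileCoord {
  uint32_t batch;
  uint32_t m_tile;
  uint32_t n_tile;
};

TileCoord DecomposeWorkgroupIdx(uint32_t workgroup_idx,
                                uint32_t num_M_tile, uint32_t num_N_tile) {
  return TileCoord{
      workgroup_idx / (num_M_tile * num_N_tile),  // batch
      (workgroup_idx / num_N_tile) % num_M_tile,  // row tile within the batch
      workgroup_idx % num_N_tile};                // column tile
}

int main() {
  // batch_count = 2, num_M_tile = 3, num_N_tile = 4 -> 24 workgroups total.
  TileCoord c = DecomposeWorkgroupIdx(17, 3, 4);
  assert(c.batch == 1 && c.m_tile == 1 && c.n_tile == 1);  // 17 = 1*12 + 1*4 + 1
  return 0;
}
```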

File tree

6 files changed: +137 -41 lines changed


onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul.wgsl.template

Lines changed: 10 additions & 6 deletions
@@ -54,18 +54,18 @@ var<workgroup> scale_B : array<output_element_t, tile_size>;
 var<workgroup> zeroes : array<i32, tile_size>;
 #endif

-fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32)
+fn loadSHMA(batch:u32, a_global_base:u32, kidx_v:u32, row: u32, col: u32)
 {
   let a_global = a_global_base + row;
   if (a_global >= uniforms.M)
   {
     return;
   }
-  tile_A[col][row] = a.getByOffset(a_global*uniforms.K16+kidx_v+col);
+  tile_A[col][row] = a.getByOffset(batch*uniforms.M*uniforms.K16+a_global*uniforms.K16+kidx_v+col);
   if (col == 0)
   {
     // kidx_v - covers 16 values of k
-    scale_A[row] = scales_a.getByOffset(a_global*(uniforms.K/128) + kidx_v/8);
+    scale_A[row] = scales_a.getByOffset(batch*uniforms.M*(uniforms.K/128) + a_global*(uniforms.K/128) + kidx_v/8);
   }
 }

@@ -154,7 +154,11 @@ $MAIN {
 #endif
   // During the load phase we use all 256 threads to load 64 rows of A/B.
   // For each row we load tile_size_k_vec (2) vectorized elements, which are 32 elements of K.
-  let a_global_base = u32(workgroup_idx / uniforms.num_N_tile) * tile_size;
+  let batch = workgroup_idx / (uniforms.num_M_tile * uniforms.num_N_tile);
+  if (batch >= uniforms.batch_count) {
+    return;
+  }
+  let a_global_base = u32((workgroup_idx / uniforms.num_N_tile) % uniforms.num_M_tile) * tile_size;
   let b_global_base = (workgroup_idx % uniforms.num_N_tile) * tile_size;
   let load_AorB = u32(local_idx/128);
   let load_row = u32((local_idx%128)/2);

@@ -199,7 +203,7 @@ $MAIN {
   // Load Phase: Populate shared memory for the workgroup.
   if (load_AorB == 0)
   {
-    loadSHMA(a_global_base, kidx_v, load_row, load_col);
+    loadSHMA(batch, a_global_base, kidx_v, load_row, load_col);
   }
   else
   {

@@ -380,7 +384,7 @@ $MAIN {

   let a_global = a_global_base + base_A + a_idx;
   let b_global = b_global_base + base_B;
-  let output_idx = ((a_global) * uniforms.N + b_global)/4;
+  let output_idx = (batch * uniforms.M * uniforms.N + a_global * uniforms.N + b_global)/4;
 #if has_bias
 #if has_weight_idx
   let b_bias_offset = uniforms.weight_idx * uniforms.N;
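
The new offsets treat the batch as the outermost dimension of quantized A, viewed as a [batch_count, M, K/16] array of vec4<u32> elements, with one A scale per 128 values of K. A hedged C++ sketch of the same arithmetic as the `getByOffset` calls above (illustrative names):

```cpp
#include <cstdint>

// Offset into quantized A, viewed as [batch_count, M, K/16] vec4<u32> elements;
// mirrors a.getByOffset(...) in the kernel above.
uint32_t AOffset(uint32_t batch, uint32_t a_global, uint32_t kidx_v, uint32_t col,
                 uint32_t M, uint32_t K16) {
  return batch * M * K16 + a_global * K16 + kidx_v + col;
}

// Offset into the A scales, laid out as [batch_count, M, K/128]; kidx_v counts
// 16-value chunks of K, so kidx_v / 8 selects the enclosing 128-value block.
uint32_t AScaleOffset(uint32_t batch, uint32_t a_global, uint32_t kidx_v,
                      uint32_t M, uint32_t K) {
  return batch * M * (K / 128) + a_global * (K / 128) + kidx_v / 8;
}
```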

onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc

Lines changed: 13 additions & 11 deletions
@@ -83,6 +83,7 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co

 Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,
                                   const Tensor* zero_points, const Tensor* bias,
+                                  uint32_t batch_count,
                                   uint32_t M,
                                   uint32_t N,
                                   uint32_t K,

@@ -101,15 +102,15 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
   DP4AMatMulQuantizeProgram quantize_program;
   quantize_program.SetWorkgroupSize(64);
   uint32_t tile_size = 64 * kVec4Components;
-  quantize_program.SetDispatchGroupSize((M * K + tile_size - 1) / tile_size, 1, 1);
-  TensorShape a_quant_shape{1, M, K / kU32Components};
+  quantize_program.SetDispatchGroupSize((batch_count * M * K + tile_size - 1) / tile_size, 1, 1);
+  TensorShape a_quant_shape{batch_count, M, K / kU32Components};
   Tensor a_quant = context.CreateGPUTensor(DataTypeImpl::GetType<uint32_t>(), a_quant_shape);
-  TensorShapeVector a_scales_dims({1, 1, M, K / kBlockSizeA});
+  TensorShapeVector a_scales_dims({batch_count, 1, M, K / kBlockSizeA});
   Tensor a_scale = context.CreateGPUTensor(a->DataType(), a_scales_dims);
   quantize_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(kVec4Components)}})
       .AddOutputs({{&a_quant, ProgramTensorMetadataDependency::Rank, a_quant.Shape(), 1},
                    {&a_scale, ProgramTensorMetadataDependency::Rank, 1}})
-      .AddUniformVariable({M * K / kU32Components});
+      .AddUniformVariable({batch_count * M * K / kU32Components});
   ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program));

   const bool has_zero_points = zero_points != nullptr;

@@ -128,12 +129,12 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
   DP4AMatMulNBitsSmallMProgram mul_program{tile_size_k_vec, tile_size_n, nbits, has_zero_points, has_bias, has_weight_idx, single_scale_weights};
   uint32_t num_N_tile = (N + tile_size_n - 1) / tile_size_n;
   mul_program.SetWorkgroupSize(128);
-  mul_program.SetDispatchGroupSize(M * num_N_tile);
+  mul_program.SetDispatchGroupSize(batch_count * M * num_N_tile);
   mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(kVec4Components)},
                          {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, 1},
                          {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(b_components * kU32Components)},
                          {scales, ProgramTensorMetadataDependency::TypeAndRank, 1}})
-      .AddUniformVariables({M, N, K, K / 16, K / 32, block_size, num_N_tile, zero_blocks_per_col, weight_index})
+      .AddUniformVariables({batch_count, M, N, K, K / 16, K / 32, block_size, num_N_tile, zero_blocks_per_col, weight_index})
       .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, 1})
       .CacheHint(nbits, tile_size_k_vec, tile_size_n, has_zero_points, single_scale_weights, has_bias, has_weight_idx);
   if (has_zero_points) {

@@ -146,22 +147,24 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
   }

   constexpr uint32_t kTileSize = 64;
-  TensorShape reshaped_y_shape{1, M, N / kVec4Components};
+  TensorShape reshaped_y_shape{batch_count, M, N / kVec4Components};
   uint32_t num_M_tile = (M + kTileSize - 1) / kTileSize;
   uint32_t num_N_tile = (N + kTileSize - 1) / kTileSize;
   bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"};
   DP4AMatMulNBitsProgram mul_program{block_size, nbits, has_zero_points, has_bias, has_weight_idx, is_qualcomm};
   mul_program.SetWorkgroupSize(256);
-  mul_program.SetDispatchGroupSize(num_M_tile * num_N_tile);
+  mul_program.SetDispatchGroupSize(batch_count * num_M_tile * num_N_tile);
   mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(kVec4Components)},
                          {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, 1},
                          {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>((nbits / 2) * kU32Components)},
                          {scales, ProgramTensorMetadataDependency::TypeAndRank, 1}})
-      .AddUniformVariables({{static_cast<uint32_t>(M)},
+      .AddUniformVariables({{static_cast<uint32_t>(batch_count)},
+                            {static_cast<uint32_t>(M)},
                             {static_cast<uint32_t>(N)},
                             {static_cast<uint32_t>(K)},
                             {static_cast<uint32_t>(K / 8)},
                             {static_cast<uint32_t>(K / 16)},
+                            {num_M_tile},
                             {num_N_tile},
                             {zero_blocks_per_col},
                             {weight_index}})

@@ -179,7 +182,6 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
 bool CanApplyDP4AMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& context,
                                    uint64_t accuracy_level,
                                    uint32_t block_size,
-                                   uint32_t batch_count,
                                    uint32_t N,
                                    uint32_t K,
                                    uint32_t components_k) {

@@ -189,7 +191,7 @@ bool CanApplyDP4AMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& context,
   bool use_dp4a = context.HasFeature(wgpu::FeatureName::Subgroups) &&
                   context.AdapterInfo().vendor != std::string_view{"apple"};
   return (accuracy_level == 4 && block_size % 32 == 0 &&
-          batch_count == 1 && components_k == 4 && K % 128 == 0 && N % 16 == 0 &&
+          components_k == 4 && K % 128 == 0 && N % 16 == 0 &&
           use_dp4a);
 }
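
Both the quantize and matmul dispatches now scale linearly with `batch_count`, and `num_M_tile`/`num_N_tile` are ceil-divisions, so a partial final tile is padded rather than dropped. A quick numeric check under the same `kTileSize = 64` used above; the `batch >= uniforms.batch_count` guard in the shader presumably covers any padding introduced when the flat dispatch is re-split across dispatch dimensions:

```cpp
#include <cstdint>
#include <iostream>

int main() {
  constexpr uint32_t kTileSize = 64;
  const uint32_t batch_count = 3, M = 100, N = 256;
  const uint32_t num_M_tile = (M + kTileSize - 1) / kTileSize;  // ceil(100/64) = 2
  const uint32_t num_N_tile = (N + kTileSize - 1) / kTileSize;  // 256/64 = 4
  std::cout << batch_count * num_M_tile * num_N_tile << "\n";   // 24 workgroups
  return 0;
}
```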

onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h

Lines changed: 4 additions & 1 deletion
@@ -32,11 +32,13 @@ class DP4AMatMulNBitsProgram final : public Program<DP4AMatMulNBitsProgram> {
         is_qualcomm_(is_qualcomm) {}
   Status GenerateShaderCode(ShaderHelper& sh) const override;
   WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
+      {"batch_count", ProgramUniformVariableDataType::Uint32},
       {"M", ProgramUniformVariableDataType::Uint32},
       {"N", ProgramUniformVariableDataType::Uint32},
       {"K", ProgramUniformVariableDataType::Uint32},
       {"K8", ProgramUniformVariableDataType::Uint32},
       {"K16", ProgramUniformVariableDataType::Uint32},
+      {"num_M_tile", ProgramUniformVariableDataType::Uint32},
       {"num_N_tile", ProgramUniformVariableDataType::Uint32},
       {"zero_blocks_per_col", ProgramUniformVariableDataType::Uint32},
       {"weight_idx", ProgramUniformVariableDataType::Uint32});

@@ -64,6 +66,7 @@ class DP4AMatMulNBitsSmallMProgram final : public Program<DP4AMatMulNBitsSmallMP
         single_scale_weights_(single_scale_weights) {}
   Status GenerateShaderCode(ShaderHelper& sh) const override;
   WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
+      {"batch_count", ProgramUniformVariableDataType::Uint32},
       {"M", ProgramUniformVariableDataType::Uint32},
       {"N", ProgramUniformVariableDataType::Uint32},
       {"K", ProgramUniformVariableDataType::Uint32},

@@ -86,6 +89,7 @@ class DP4AMatMulNBitsSmallMProgram final : public Program<DP4AMatMulNBitsSmallMP

 Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,
                                   const Tensor* zero_points, const Tensor* bias,
+                                  uint32_t batch_count,
                                   uint32_t M,
                                   uint32_t N,
                                   uint32_t K,

@@ -100,7 +104,6 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
 bool CanApplyDP4AMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& context,
                                    uint64_t accuracy_level,
                                    uint32_t block_size,
-                                   uint32_t batch_count,
                                    uint32_t N,
                                    uint32_t K,
                                    uint32_t components_k);
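
Note that WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES appears to bind uniforms positionally: the declaration order here has to line up one-to-one with the values passed to AddUniformVariables in dp4a_matmul_nbits.cc, which is why `batch_count` is prepended and `num_M_tile` is inserted at the matching slot on both sides of the interface.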

onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_small_m.wgsl.template

Lines changed: 10 additions & 6 deletions
@@ -45,23 +45,27 @@ var<workgroup> tile_A : array<vec4<u32>, double_tile_size_k_vec>;
 const scale_a_size_in_tile_a = double_tile_size_k_vec / 8;
 var<workgroup> scale_A : array<output_element_t, scale_a_size_in_tile_a>;

-fn loadSHMA(a_global: u32, kidx_v: u32, col: u32)
+fn loadSHMA(batch: u32, a_global: u32, kidx_v: u32, col: u32)
 {
   let k_offset = kidx_v + col;
   if (k_offset >= uniforms.K16) {
     return;
   }

-  tile_A[col] = a.getByOffset(a_global*uniforms.K16+k_offset);
+  tile_A[col] = a.getByOffset(batch*uniforms.M*uniforms.K16+a_global*uniforms.K16+k_offset);
   if (col < scale_a_size_in_tile_a)
   {
     // kidx_v - covers 16 values of k in input_a
-    scale_A[col] = scales_a.getByOffset(a_global*(uniforms.K/128) + kidx_v/8 + col);
+    scale_A[col] = scales_a.getByOffset(batch*uniforms.M*(uniforms.K/128) + a_global*(uniforms.K/128) + kidx_v/8 + col);
   }
 }

 $MAIN {
-  let a_global = u32(workgroup_idx / uniforms.num_N_tile);
+  let batch = workgroup_idx / (uniforms.M * uniforms.num_N_tile);
+  if (batch >= uniforms.batch_count) {
+    return;
+  }
+  let a_global = u32((workgroup_idx / uniforms.num_N_tile) % uniforms.M);
   let b_global_base = (workgroup_idx % uniforms.num_N_tile) * tile_size;
   // Handle each workgroup threads as a block of [sub_tile_count][tile_size_k_vec]
   let local_col = local_idx % tile_size_k_vec;

@@ -95,7 +99,7 @@ $MAIN {
   // Load Phase: Populate shared memory for the workgroup.
   if (local_idx < double_tile_size_k_vec)
   {
-    loadSHMA(a_global, kidx_v * 2, local_idx);
+    loadSHMA(batch, a_global, kidx_v * 2, local_idx);
   }
   workgroupBarrier();
   var own_a: vec4<u32> = tile_A[local_col * 2];

@@ -153,7 +157,7 @@ $MAIN {
     output_value += inter_results[local_idx][b];
   }
   let b_global = b_global_base + local_idx;
-  let output_idx = a_global * uniforms.N + b_global;
+  let output_idx = batch * uniforms.M * uniforms.N + a_global * uniforms.N + b_global;
   if (b_global < uniforms.N) {
 #if has_bias
     let bias_value = bias[b_global + b_bias_offset];
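
The small-M kernel assigns one row of A per workgroup rather than a 64-row tile, so `uniforms.M` itself plays the role of the row-tile count in the decomposition. A hedged sketch of the same index math with concrete numbers (illustrative, not kernel code):

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // M = 5 rows, num_N_tile = 3 column tiles -> 15 workgroups per batch.
  const uint32_t M = 5, num_N_tile = 3, workgroup_idx = 13;
  const uint32_t batch = workgroup_idx / (M * num_N_tile);     // 13 / 15 = 0
  const uint32_t a_global = (workgroup_idx / num_N_tile) % M;  // (13/3) % 5 = 4
  const uint32_t n_tile = workgroup_idx % num_N_tile;          // 13 % 3 = 1
  assert(batch == 0 && a_global == 4 && n_tile == 1);
  return 0;
}
```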

onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc

Lines changed: 2 additions & 2 deletions
@@ -220,8 +220,8 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,

   // On FP32 only GPUs, integer math is faster than FP32 therefore always use DP4A independent of length of M.
   if ((M >= kMinMForTileOptimization || y->DataType() == DataTypeImpl::GetType<float>() || context.AdapterInfo().vendor == std::string_view{"qualcomm"}) &&
-      CanApplyDP4AMatrixMatMulNBits(context, accuracy_level, block_size, batch_count, N, K, components_a)) {
-    return ApplyDP4AMatrixMatMulNBits(a, b, scales, zero_points, bias, M, N, K, block_size, zero_blocks_per_col, kMinMForTileOptimization, static_cast<uint32_t>(nbits), context, y, weight_index);
+      CanApplyDP4AMatrixMatMulNBits(context, accuracy_level, block_size, N, K, components_a)) {
+    return ApplyDP4AMatrixMatMulNBits(a, b, scales, zero_points, bias, batch_count, M, N, K, block_size, zero_blocks_per_col, kMinMForTileOptimization, static_cast<uint32_t>(nbits), context, y, weight_index);
   }

   // WideTileProgram
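
With the `batch_count == 1` restriction dropped from CanApplyDP4AMatrixMatMulNBits and `batch_count` forwarded to ApplyDP4AMatrixMatMulNBits instead, the DP4A path can now be selected for any batch size; eligibility is gated only by accuracy level, block size, K/N alignment, component width, and subgroup support.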
