Commit 535f1f3
[webgpu] Split large inputs into smaller buffers to bypass maxStorageBufferBindingSize limit (#25962)
### Description

When an input is larger than maxStorageBufferBindingSize, use multiple binding entries for it. We refine the implementation of `getByOffset`/`setByOffset` so that if, say, `input_b` is 257MB while maxStorageBufferBindingSize is 256MB, `b.getByOffset(offset)` still returns the correct element and the caller does not need to care which binding entry holds it. Under the hood, shader code like the following is generated:

```
var<storage, read> input_b: array<vec4<u32>>;   // [0, 256MB) of input_b
var<storage, read> input_b1: array<vec4<u32>>;  // [256MB, 257MB) of input_b
```

### Motivation and Context

Qualcomm's maxStorageBufferBindingSize is 256MB, which is not enough for the phi-4 model. So for Qualcomm we customized a phi-4 model that uses the `slice` op to split the big matrix, which meant we had to maintain two different phi-4 models for different platforms.

### For reviewers

The core logic is located in:

- Shader side:
  - `shader_helper.cc`: in the shader, emit as many `@group(0) @binding(...)` declarations as there are actual buffer bindings.
  - `shader_variable.cc`: implement the `set_xxx_by_offset(global_offset, value)` and `get_xxx_by_offset(global_offset)` shader helper functions, which are used by `setByOffset`/`getByOffset` when the input exceeds maxStorageBufferBindingSize (a sketch of the idea follows below).
- WebGPU API side:
  - `webgpu_context.cc`: in the WebGPU API, create as many bind group entries as there are actual buffer bindings.
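To make the `get_xxx_by_offset` idea concrete, here is a minimal WGSL sketch of what the generated dispatch could look like for the 257MB `input_b` example above. The binding indices, the helper name, and the element count are illustrative assumptions, not the literal output of `shader_variable.cc`.

```
// Illustrative sketch only: assumed binding indices and helper shape.
// A 256MB binding of vec4<u32> holds 256MB / 16B = 16777216 elements.
@group(0) @binding(2) var<storage, read> input_b  : array<vec4<u32>>;  // [0, 256MB)
@group(0) @binding(3) var<storage, read> input_b1 : array<vec4<u32>>;  // [256MB, 257MB)

// Templates keep calling b.getByOffset(offset); the generated helper picks
// whichever binding actually holds that element.
fn get_b_by_offset(offset : u32) -> vec4<u32> {
  let elements_per_binding : u32 = 16777216u;
  if (offset < elements_per_binding) {
    return input_b[offset];
  }
  return input_b1[offset - elements_per_binding];
}
```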
1 parent a60c307 · commit 535f1f3

23 files changed: +507, -200 lines

onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul.wgsl.template

Lines changed: 18 additions & 16 deletions
@@ -6,6 +6,8 @@
 #param has_zero_points
 #param is_qualcomm

+#use .getByOffset .setByOffset
+
 #include "quantization/dp4a_matmul_common.wgsl.template"

 // This shader implements co-operative matrix multiply. The key idea here is to
@@ -57,11 +59,11 @@ fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32)
 {
 return;
 }
-tile_A[col][row] = input_a[a_global*uniforms.K16+kidx_v+col];
+tile_A[col][row] = a.getByOffset(a_global*uniforms.K16+kidx_v+col);
 if (col == 0)
 {
 // kidx_v - covers 16 values of k
-scale_A[row] = scales_a[a_global*(uniforms.K/128) + kidx_v/8];
+scale_A[row] = scales_a.getByOffset(a_global*(uniforms.K/128) + kidx_v/8);
 }
 }

@@ -74,14 +76,14 @@ fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32)
 return;
 }

-let b_value = input_b[b_global*uniforms.K16+kidx_v+col];
+let b_value = b.getByOffset(b_global*uniforms.K16+kidx_v+col);
 let block_idx = kidx_v/(block_size/16);
 let zero = mm_read_zero(b_global, block_idx, uniforms.N, uniforms.zero_blocks_per_col);
 tile_B[col][row] = DequantizedFrom4BitsTo8Bits(b_value, zero);
 if (col == 0)
 {
 // kidx_v - each kidx_v covers 16 values of k
-scale_B[row] = scales_b[b_global*(uniforms.K/block_size) + block_idx];
+scale_B[row] = scales_b.getByOffset(b_global*(uniforms.K/block_size) + block_idx);
 }
 }
 #endif
@@ -95,13 +97,13 @@ fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32)
 return;
 }

-let b_value = input_b[b_global*uniforms.K16+kidx_v+col];
+let b_value = b.getByOffset(b_global*uniforms.K16+kidx_v+col);
 tile_B[col][row] = AlignWithZeroPoint(b_value);
 if (col == 0)
 {
 // kidx_v - each kidx_v covers 16 values of k
 let block_idx = kidx_v/(block_size/16);
-scale_B[row] = scales_b[b_global*(uniforms.K/block_size) + block_idx];
+scale_B[row] = scales_b.getByOffset(b_global*(uniforms.K/block_size) + block_idx);
 #if has_zero_points
 zeroes[row] = mm_read_zero(b_global, block_idx, uniforms.N, uniforms.zero_blocks_per_col);
 #endif
@@ -117,10 +119,10 @@ fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32)
 {
 return;
 }
-let b_value = input_b[b_global*uniforms.K16+kidx_v+col];
+let b_value = b.getByOffset(b_global*uniforms.K16+kidx_v+col);
 tile_B[col][row] = DequantizedFrom2BitsTo8Bits(b_value);
 let block_idx = kidx_v/(block_size/16);
-scale_B[row] = scales_b[b_global*(uniforms.K/block_size) + block_idx];
+scale_B[row] = scales_b.getByOffset(b_global*(uniforms.K/block_size) + block_idx);
 }
 #endif

@@ -362,15 +364,15 @@ $MAIN {
 if (a_global < uniforms.M && b_global < uniforms.N)
 {
 #if is_qualcomm
-output[output_idx] = vec4<output_element_t>(lane_outputs[0], lane_outputs[1], lane_outputs[2], lane_outputs[3]);
-output[output_idx+1] = vec4<output_element_t>(lane_outputs[4], lane_outputs[5], lane_outputs[6], lane_outputs[7]);
-output[output_idx+2] = vec4<output_element_t>(lane_outputs[8], lane_outputs[9], lane_outputs[10], lane_outputs[11]);
-output[output_idx+3] = vec4<output_element_t>(lane_outputs[12], lane_outputs[13], lane_outputs[14], lane_outputs[15]);
+output.setByOffset(output_idx, vec4<output_element_t>(lane_outputs[0], lane_outputs[1], lane_outputs[2], lane_outputs[3]));
+output.setByOffset(output_idx+1, vec4<output_element_t>(lane_outputs[4], lane_outputs[5], lane_outputs[6], lane_outputs[7]));
+output.setByOffset(output_idx+2, vec4<output_element_t>(lane_outputs[8], lane_outputs[9], lane_outputs[10], lane_outputs[11]));
+output.setByOffset(output_idx+3, vec4<output_element_t>(lane_outputs[12], lane_outputs[13], lane_outputs[14], lane_outputs[15]));
 #else
-output[output_idx] = lane_output1;
-output[output_idx+1] = lane_output2;
-output[output_idx+2] = lane_output3;
-output[output_idx+3] = lane_output4;
+output.setByOffset(output_idx, lane_output1);
+output.setByOffset(output_idx+1, lane_output2);
+output.setByOffset(output_idx+2, lane_output3);
+output.setByOffset(output_idx+3, lane_output4);
 #endif
 }
 } // MAIN

onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc

Lines changed: 29 additions & 16 deletions
@@ -10,39 +10,47 @@ namespace contrib {
 namespace webgpu {

 Status DP4AMatMulQuantizeProgram::GenerateShaderCode(ShaderHelper& shader) const {
-shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
-shader.AddOutput("output", ShaderUsage::UseUniform);
-shader.AddOutput("scales", ShaderUsage::UseUniform);
-return WGSL_TEMPLATE_APPLY(shader, "quantization/dp4a_quantize.wgsl.template");
+const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
+const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform);
+const auto& scales = shader.AddOutput("scales", ShaderUsage::UseUniform);
+return WGSL_TEMPLATE_APPLY(shader, "quantization/dp4a_quantize.wgsl.template",
+                           WGSL_TEMPLATE_VARIABLE(a, a),
+                           WGSL_TEMPLATE_VARIABLE(output, output),
+                           WGSL_TEMPLATE_VARIABLE(scales, scales));
 }

 Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
-shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
-shader.AddInput("scales_a", ShaderUsage::UseUniform);
-shader.AddInput("input_b", ShaderUsage::UseUniform);
-shader.AddInput("scales_b", ShaderUsage::UseUniform);
+const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
+const auto& scales_a = shader.AddInput("scales_a", ShaderUsage::UseUniform);
+const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform);
+const auto& scales_b = shader.AddInput("scales_b", ShaderUsage::UseUniform);
 if (has_zero_points_) {
 shader.AddInput("zero_points", ShaderUsage::UseUniform);
 }
-shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias);
+const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias);
 return WGSL_TEMPLATE_APPLY(shader, "quantization/dp4a_matmul.wgsl.template",
                            WGSL_TEMPLATE_PARAMETER(block_size, block_size_),
                            WGSL_TEMPLATE_PARAMETER(has_zero_points, has_zero_points_),
                            WGSL_TEMPLATE_PARAMETER(is_qualcomm, is_qualcomm_),
                            WGSL_TEMPLATE_PARAMETER(n_bits, nbits_),
-                           WGSL_TEMPLATE_PARAMETER(output_type_i32, true));
+                           WGSL_TEMPLATE_PARAMETER(output_type_i32, true),
+                           WGSL_TEMPLATE_VARIABLE(a, a),
+                           WGSL_TEMPLATE_VARIABLE(b, b),
+                           WGSL_TEMPLATE_VARIABLE(output, output),
+                           WGSL_TEMPLATE_VARIABLE(scales_a, scales_a),
+                           WGSL_TEMPLATE_VARIABLE(scales_b, scales_b));
 }

 // scale_A components = 1, b components = 4, output components = 1
 Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) const {
-shader.AddInput("input_a", ShaderUsage::UseUniform);
-shader.AddInput("scales_a", ShaderUsage::UseUniform);
-shader.AddInput("input_b", ShaderUsage::UseUniform);
-shader.AddInput("scales_b", ShaderUsage::UseUniform);
+const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform);
+const auto& scales_a = shader.AddInput("scales_a", ShaderUsage::UseUniform);
+const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform);
+const auto& scales_b = shader.AddInput("scales_b", ShaderUsage::UseUniform);
 if (has_zero_points_) {
 shader.AddInput("zero_points", ShaderUsage::UseUniform);
 }
-shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias);
+const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias);

 ORT_ENFORCE(WorkgroupSizeX() % tile_size_k_vec_ == 0 && tile_size_k_vec_ % 4 == 0, "tile_size_k_vec_ must evenly divide workgroup size X and be divisible by 4");
 const uint32_t sub_tile_count = WorkgroupSizeX() / tile_size_k_vec_;
@@ -55,7 +63,12 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co
                            WGSL_TEMPLATE_PARAMETER(single_scale_weights, single_scale_weights_),
                            WGSL_TEMPLATE_PARAMETER(sub_tile_count, sub_tile_count),
                            WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_),
-                           WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec_));
+                           WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec_),
+                           WGSL_TEMPLATE_VARIABLE(a, a),
+                           WGSL_TEMPLATE_VARIABLE(b, b),
+                           WGSL_TEMPLATE_VARIABLE(output, output),
+                           WGSL_TEMPLATE_VARIABLE(scales_a, scales_a),
+                           WGSL_TEMPLATE_VARIABLE(scales_b, scales_b));
 }

 Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,

onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_small_m.wgsl.template

Lines changed: 12 additions & 9 deletions
@@ -8,6 +8,9 @@
 #param n_bits
 #param has_zero_points

+#use .getByOffset .setByOffset
+
+
 #include "quantization/dp4a_matmul_common.wgsl.template"

 // This algorithm works to compute dot product of k in parallel, by processing k at each step amongst tile_size_k_vec threads,
@@ -47,11 +50,11 @@ fn loadSHMA(a_global: u32, kidx_v: u32, col: u32)
 return;
 }

-tile_A[col] = input_a[a_global*uniforms.K16+k_offset];
+tile_A[col] = a.getByOffset(a_global*uniforms.K16+k_offset);
 if (col < scale_a_size_in_tile_a)
 {
 // kidx_v - covers 16 values of k in input_a
-scale_A[col] = scales_a[a_global*(uniforms.K/128) + kidx_v/8 + col];
+scale_A[col] = scales_a.getByOffset(a_global*(uniforms.K/128) + kidx_v/8 + col);
 }
 }

@@ -70,7 +73,7 @@ $MAIN {
 #endif
 #if single_scale_weights
 let zero = mm_read_zero(0, 0, uniforms.N, uniforms.zero_blocks_per_col);
-let own_scale_b = scales_b[0];
+let own_scale_b = scales_b.getByOffset(0);
 #endif

 for (var kidx_v:u32 = 0; kidx_v < uniforms.K32; kidx_v += tile_size_k_vec)
@@ -95,24 +98,24 @@ $MAIN {
 let b_offset = b_global * uniforms.K32 + k_offset;
 #if !single_scale_weights
 let zero = mm_read_zero(b_global, block_idx, uniforms.N, uniforms.zero_blocks_per_col);
-let own_scale_b = scales_b[b_global * uniforms.K / uniforms.block_size + block_idx];
+let own_scale_b = scales_b.getByOffset(b_global * uniforms.K / uniforms.block_size + block_idx);
 #endif
 #if n_bits == 4
-let b_value = input_b[b_offset];
+let b_value = b.getByOffset(b_offset);
 let own_b = DequantizedFrom4BitsTo8Bits(b_value.xy, zero);
 let own_b1 = DequantizedFrom4BitsTo8Bits(b_value.zw, zero);
 inter_results[row_offset + local_row][local_col] += SDP8AI(own_a, own_b, own_a1, own_b1, own_scale_a * own_scale_b);
 #elif n_bits == 8
-let own_b = AlignWithZeroPoint(input_b[b_offset * 2]);
-let own_b1 = AlignWithZeroPoint(input_b[b_offset * 2 + 1]);
+let own_b = AlignWithZeroPoint(b.getByOffset(b_offset * 2));
+let own_b1 = AlignWithZeroPoint(b.getByOffset(b_offset * 2 + 1));
 #if has_zero_points
 inter_results[row_offset + local_row][local_col] += SDP8AI(own_a, own_b, own_a1, own_b1, own_scale_a * own_scale_b, zero);
 #else
 inter_results[row_offset + local_row][local_col] += SDP8AI(own_a, own_b, own_a1, own_b1, own_scale_a * own_scale_b);
 #endif

 #elif n_bits == 2
-let b_value = input_b[b_offset];
+let b_value = b.getByOffset(b_offset);
 let own_b = DequantizedFrom2BitsTo8Bits(b_value.x);
 let own_b1 = DequantizedFrom2BitsTo8Bits(b_value.y);
 inter_results[row_offset + local_row][local_col] += SDP8AI(own_a, own_b, own_a1, own_b1, own_scale_a * own_scale_b);
@@ -131,7 +134,7 @@ $MAIN {
 let b_global = b_global_base + local_idx;
 let output_idx = a_global * uniforms.N + b_global;
 if (b_global < uniforms.N) {
-output[output_idx] = output_value;
+output.setByOffset(output_idx, output_value);
 }
 }
 } // MAIN

onnxruntime/contrib_ops/webgpu/quantization/dp4a_quantize.wgsl.template

Lines changed: 9 additions & 7 deletions
@@ -5,6 +5,8 @@
 // Quantizes input matrix A for DP4A computation
 // This shader quantizes float values to 8-bit signed integers using pack4x8snorm

+#use .getByOffset .setByOffset
+
 var<workgroup> a_values : array<array<input_a_value_t, 32>, 2>;
 var<workgroup> max_values : array<input_a_value_t, 4>;

@@ -13,7 +15,7 @@ fn readInput(offset: u32) -> input_a_value_t
 if (offset >= uniforms.output_size) {
 return input_a_value_t(0);
 }
-return input_a[offset];
+return a.getByOffset(offset);
 }

 $MAIN {
@@ -26,11 +28,11 @@ $MAIN {
 let max_temp = max(max_val.xy, max_val.zw);
 let scale = max(max_temp[0], max_temp[1]);
 let norm_a = local_a/scale;
-output[global_idx] = pack4x8snorm(vec4<f32>(norm_a));
+output.setByOffset(global_idx, pack4x8snorm(vec4<f32>(norm_a)));
 if (local_idx % 32 == 0)
 {
 // 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f.
-scales[workgroup_idx * 2 + local_idx / 32] = scale/127;
+scales.setByOffset(workgroup_idx * 2 + local_idx / 32, scale/127);
 }
 } else if (sg_size == 16) {
 let local_a = readInput(global_idx);
@@ -53,11 +55,11 @@ $MAIN {
 let max_temp = max(max_val.xy, max_val.zw);
 let scale = max(max_temp[0], max_temp[1]);
 let norm_a = local_a/scale;
-output[global_idx] = pack4x8snorm(vec4<f32>(norm_a));
+output.setByOffset(global_idx, pack4x8snorm(vec4<f32>(norm_a)));
 if (local_idx % 32 == 0)
 {
 // 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f.
-scales[workgroup_idx * 2 + local_idx / 32] = scale/127;
+scales.setByOffset(workgroup_idx * 2 + local_idx / 32, scale/127);
 }
 } else {
 let local_row = local_idx / 32u;
@@ -78,11 +80,11 @@ $MAIN {
 let max_temp = max(max_val.xy, max_val.zw);
 let scale = max(max_temp[0], max_temp[1]);
 let norm_a = a_values[local_row][local_col]/scale;
-output[global_idx] = pack4x8snorm(vec4<f32>(norm_a));
+output.setByOffset(global_idx, pack4x8snorm(vec4<f32>(norm_a)));
 if (local_col == 0u)
 {
 // 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f.
-scales[workgroup_idx * 2 + local_row] = scale/127;
+scales.setByOffset(workgroup_idx * 2 + local_row, scale/127);
 }
 }
 }

onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc

Lines changed: 16 additions & 8 deletions
@@ -42,13 +42,13 @@ ONNX_OPERATOR_KERNEL_EX(
 MatMulNBits);

 Status MatMulNBitsWideTileProgram::GenerateShaderCode(ShaderHelper& shader) const {
-shader.AddInput("input_a", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
-shader.AddInput("input_b", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
-shader.AddInput("scales", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
+const auto& a = shader.AddInput("input_a", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
+const auto& b = shader.AddInput("input_b", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
+const auto& scales = shader.AddInput("scales", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
 if (has_zero_points_) {
 shader.AddInput("zero_points", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
 }
-shader.AddOutput("output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
+const auto& output = shader.AddOutput("output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);

 const uint32_t workgroup_size = WorkgroupSizeX() * WorkgroupSizeY();
 ORT_ENFORCE(tile_m_ == workgroup_size / 8, "tile_m must be workgroup_size / 8.");
@@ -59,18 +59,22 @@ Status MatMulNBitsWideTileProgram::GenerateShaderCode(ShaderHelper& shader) cons
                            WGSL_TEMPLATE_PARAMETER(has_zero_points, has_zero_points_),
                            WGSL_TEMPLATE_PARAMETER(nbits, nbits_),
                            WGSL_TEMPLATE_PARAMETER(tile_m, tile_m_),
-                           WGSL_TEMPLATE_PARAMETER(tile_n, tile_n_));
+                           WGSL_TEMPLATE_PARAMETER(tile_n, tile_n_),
+                           WGSL_TEMPLATE_VARIABLE(a, a),
+                           WGSL_TEMPLATE_VARIABLE(b, b),
+                           WGSL_TEMPLATE_VARIABLE(output, output),
+                           WGSL_TEMPLATE_VARIABLE(scales, scales));
 }

 // Apply similar idea with DP4AMatMulNBitsSmallMProgram algorithm.
 Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
 const auto& a = shader.AddInput("input_a", ShaderUsage::UseValueTypeAlias);
 const auto& b = shader.AddInput("input_b");
-shader.AddInput("scales_b");
+const auto& scales_b = shader.AddInput("scales_b");
 if (has_zero_points_) {
 shader.AddInput("zero_points", ShaderUsage::UseUniform);
 }
-shader.AddOutput("output", ShaderUsage::UseElementTypeAlias);
+const auto& output = shader.AddOutput("output", ShaderUsage::UseElementTypeAlias);

 const uint32_t components_a = a.NumComponents();
 const uint32_t components_b = b.NumComponents() / 4;  // b is stored as uint32 which includes 4 uint8.
@@ -92,7 +96,11 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
                            WGSL_TEMPLATE_PARAMETER(sub_tile_count, sub_tile_count),
                            WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_),
                            WGSL_TEMPLATE_PARAMETER(tile_size_k, tile_size_k),
-                           WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec));
+                           WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec),
+                           WGSL_TEMPLATE_VARIABLE(a, a),
+                           WGSL_TEMPLATE_VARIABLE(b, b),
+                           WGSL_TEMPLATE_VARIABLE(output, output),
+                           WGSL_TEMPLATE_VARIABLE(scales_b, scales_b));
 }

 Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {