Skip to content

Commit 2ece1c1

Browse files
authored
webgpu: optimize Gemm and MatMul using subgroup feature (#26433)
### Description
Optimize WebGPU Gemm and MatMul by adding an Intel subgroup-based execution path, selected at runtime via `CanApplyGemmIntel` / `CanApplyMatMulIntel`.

### Motivation and Context
Uses the WebGPU subgroup feature to improve GEMM/MatMul performance on supported Intel GPUs.
1 parent e40d5b7 commit 2ece1c1

File tree

12 files changed

+631
-40
lines changed

12 files changed

+631
-40
lines changed

onnxruntime/core/providers/webgpu/math/gemm.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include "core/providers/webgpu/math/gemm.h"
55
#include "core/providers/webgpu/math/gemm_packed.h"
6+
#include "core/providers/webgpu/vendor/intel/math/gemm.h"
67

78
#include <vector>
89

@@ -147,6 +148,10 @@ Status Gemm::ComputeInternal(ComputeContext& context) const {
147148
return context.RunProgram(program);
148149
}
149150

151+
if (intel::CanApplyGemmIntel(context, M, N, K, transA_, transB_)) {
152+
return intel::ApplyGemmIntel(A, B, C, transA_, transB_, alpha_, beta_, context);
153+
}
154+
150155
return ApplyGemmPacked(A, B, C, transA_, transB_, alpha_, beta_, context);
151156
}
152157

onnxruntime/core/providers/webgpu/math/gemm_packed.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ Status GemmProgram::GenerateShaderCode(ShaderHelper& shader) const {
3131
const auto& a = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
3232
const auto& b = shader.AddInput("b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
3333

34-
MatMulReadFnSource(shader, a, b, nullptr, transA_, transB_, is_vec4_);
34+
MatMulReadFnSource(shader, a, b, nullptr, transA_, transB_);
3535
}
3636
if (is_vec4_) {
3737
ORT_RETURN_IF_ERROR(MakeMatMulPackedVec4Source(shader, elements_per_thread, WorkgroupSizeX(), WorkgroupSizeY(), data_type, nullptr, transA_, transB_, alpha_, need_handle_matmul_, output_components_, /*tile_inner*/ 32, need_split_k, split_dim_inner_));
@@ -45,7 +45,7 @@ Status GemmProgram::GenerateShaderCode(ShaderHelper& shader) const {
4545
}
4646

4747
const ProgramVariableDataType output_var_type = this->Outputs()[0].var_type;
48-
MatMulWriteFnSource(shader, output, c, /* is_gemm = */ true, c_components_, output_components_, c_is_scalar_, /*activation_snippet*/ "", /*is_channels_last*/ false, need_split_k, output_var_type);
48+
MatMulWriteFnSource(shader, output, c, /* is_gemm = */ true, c_components_, c_is_scalar_, /*activation_snippet*/ "", /*is_channels_last*/ false, need_split_k, output_var_type);
4949

5050
return Status::OK();
5151
}

onnxruntime/core/providers/webgpu/math/gemm_utils.cc

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ void HandleMaybeBiasForMatMul(ShaderHelper& shader,
4949
shader.AdditionalImplementation() << " value = value + output_value_t(" << (is_channels_last ? bias->GetByOffset("colIn") : bias->GetByOffset("row")) << ");\n";
5050
}
5151
shader.AdditionalImplementation() << " " << activation_snippet << "\n"
52-
<< output.SetByIndices("coords", "value") << "\n";
52+
<< " " << output.SetByIndices("coords", "value") << "\n";
5353
}
5454

5555
void HandleMatMulWithSplitK(
@@ -127,60 +127,61 @@ void MatMulReadFnSource(ShaderHelper& shader,
127127
const ShaderVariableHelper& b,
128128
const ShaderIndicesHelper* batch_dims,
129129
bool transA,
130-
bool transB,
131-
bool is_vec4) {
132-
int components = is_vec4 ? 4 : 1;
130+
bool transB) {
131+
const int a_components = a.NumComponents();
133132
const std::string data_type = "output_element_t";
134-
const std::string type_string = MakeScalarOrVectorType(components, data_type);
133+
std::string type_string = MakeScalarOrVectorType(a_components, data_type);
135134

136135
shader.AdditionalImplementation()
137136
<< "fn mm_readA(batch: i32, row: i32, colIn: i32 "
138137
<< (batch_dims
139138
? ", batch_indices: batch_dims_indices_t"
140139
: "")
141-
<< ") -> " << type_string << " {\n "
142-
<< " var value = " << type_string << "(0);\n"
143-
<< " let col = colIn * " << components << ";\n";
140+
<< ") -> " << type_string << " {\n"
141+
<< " var value = " << type_string << "(0);\n"
142+
<< " let col = colIn * " << a_components << ";\n";
144143
if (transA) {
145-
shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_inner) && col < i32(uniforms.dim_a_outer)) {\n";
144+
shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_inner) && col < i32(uniforms.dim_a_outer)) {\n";
146145
} else {
147-
shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_inner)) {\n";
146+
shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_inner)) {\n";
148147
}
149-
shader.AdditionalImplementation() << " var a_indices: a_indices_t;\n";
148+
shader.AdditionalImplementation() << " var a_indices: a_indices_t;\n";
150149

151150
if (batch_dims) {
152-
shader.AdditionalImplementation() << ConvertOutputBatchIndicesToInputBatchIndices("a", a, a.Rank() - 2, batch_dims ? batch_dims->Rank() : 0, " batch_indices ") << "\n";
151+
shader.AdditionalImplementation() << ConvertOutputBatchIndicesToInputBatchIndices("a", a, a.Rank() - 2, batch_dims ? batch_dims->Rank() : 0, " batch_indices ");
153152
}
154-
shader.AdditionalImplementation() << a.IndicesSet("a_indices", a.Rank() - 2, "u32(row)") << "\n"
155-
<< a.IndicesSet("a_indices", a.Rank() - 1, "u32(colIn)") << "\n"
156-
<< " value = " << a.GetByIndices("a_indices") << ";\n"
157-
<< " }\n"
158-
<< " return value;\n"
153+
shader.AdditionalImplementation() << " " << a.IndicesSet("a_indices", a.Rank() - 2, "u32(row)") << "\n"
154+
<< " " << a.IndicesSet("a_indices", a.Rank() - 1, "u32(colIn)") << "\n"
155+
<< " value = " << a.GetByIndices("a_indices") << ";\n"
156+
<< " }\n"
157+
<< " return value;\n"
159158
<< "}\n\n";
160159

161160
// Add the mm_readB function
161+
const int b_components = b.NumComponents();
162+
type_string = MakeScalarOrVectorType(b_components, data_type);
162163
shader.AdditionalImplementation()
163164
<< "fn mm_readB(batch: i32, row: i32, colIn: i32 "
164165
<< (batch_dims
165166
? ", batch_indices: batch_dims_indices_t"
166167
: "")
167-
<< ") -> " << type_string << " {\n "
168-
<< " var value = " << type_string << "(0);\n"
169-
<< " let col = colIn * " << components << ";\n";
168+
<< ") -> " << type_string << " {\n"
169+
<< " var value = " << type_string << "(0);\n"
170+
<< " let col = colIn * " << b_components << ";\n";
170171

171172
if (transB) {
172-
shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_b_outer) && col < i32(uniforms.dim_inner)) {\n";
173+
shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_b_outer) && col < i32(uniforms.dim_inner)) {\n";
173174
} else {
174-
shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_inner) && col < i32(uniforms.dim_b_outer)) {\n";
175+
shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_inner) && col < i32(uniforms.dim_b_outer)) {\n";
175176
}
176177

177-
shader.AdditionalImplementation() << " var b_indices: b_indices_t;\n"
178+
shader.AdditionalImplementation() << " var b_indices: b_indices_t;\n"
178179
<< ConvertOutputBatchIndicesToInputBatchIndices("b", b, b.Rank() - 2, batch_dims ? batch_dims->Rank() : 0, "batch_indices")
179-
<< b.IndicesSet("b_indices", b.Rank() - 2, "u32(row)") << "\n"
180-
<< b.IndicesSet("b_indices", b.Rank() - 1, "u32(colIn)") << "\n"
181-
<< " value = " << b.GetByIndices("b_indices") << ";\n"
182-
<< " }\n"
183-
<< " return value;\n"
180+
<< " " << b.IndicesSet("b_indices", b.Rank() - 2, "u32(row)") << "\n"
181+
<< " " << b.IndicesSet("b_indices", b.Rank() - 1, "u32(colIn)") << "\n"
182+
<< " value = " << b.GetByIndices("b_indices") << ";\n"
183+
<< " }\n"
184+
<< " return value;\n"
184185
<< "}\n\n";
185186
}
186187

@@ -189,19 +190,19 @@ void MatMulWriteFnSource(ShaderHelper& shader,
189190
const ShaderVariableHelper* bias,
190191
bool is_gemm,
191192
int c_components,
192-
int output_components,
193193
bool c_is_scalar,
194194
std::string activation_snippet,
195195
bool is_channels_last,
196196
bool use_split_k,
197197
ProgramVariableDataType output_variable_type) {
198+
const int output_components = output.NumComponents();
198199
shader.AdditionalImplementation()
199-
<< "fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: output_value_t) { \n";
200+
<< "fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: output_value_t) {\n";
200201

201202
shader.AdditionalImplementation() << " let col = colIn * " << output_components << ";\n";
202203

203-
shader.AdditionalImplementation() << "if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_b_outer)) { \n"
204-
<< " var value = valueIn; \n";
204+
shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_b_outer)) {\n"
205+
<< " var value = valueIn;\n";
205206

206207
if (use_split_k) {
207208
// Set output when MatMul is performed with Split-K.

onnxruntime/core/providers/webgpu/math/gemm_utils.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,13 @@ void MatMulReadFnSource(ShaderHelper& shader,
1313
const ShaderVariableHelper& b,
1414
const ShaderIndicesHelper* batch_dims,
1515
bool transA,
16-
bool transB,
17-
bool is_vec4);
16+
bool transB);
1817

1918
void MatMulWriteFnSource(ShaderHelper& shader,
2019
const ShaderVariableHelper& output,
2120
const ShaderVariableHelper* bias,
2221
bool is_gemm,
2322
int c_components,
24-
int output_components,
2523
bool c_is_scalar,
2624
std::string activation_snippet = "",
2725
bool is_channels_last = false,

onnxruntime/core/providers/webgpu/math/matmul.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "core/providers/webgpu/webgpu_supported_types.h"
99
#include "core/providers/webgpu/nn/fuse_utils.h"
1010
#include "core/providers/webgpu/data_transfer.h"
11+
#include "core/providers/webgpu/vendor/intel/math/matmul.h"
1112
#include "core/providers/webgpu/webgpu_utils.h"
1213

1314
namespace onnxruntime {
@@ -163,6 +164,10 @@ Status MatMul::ComputeInternal(ComputeContext& context) const {
163164
inputs.push_back(bias);
164165
}
165166

167+
if (intel::CanApplyMatMulIntel(context, helper.M(), helper.N(), helper.K())) {
168+
return intel::ApplyMatMulIntel(context, Activation(), inputs, output_tensor);
169+
}
170+
166171
return ComputeMatMul(&context, Activation(), inputs, output_tensor, false);
167172
}
168173

onnxruntime/core/providers/webgpu/math/matmul_packed.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ Status MatMulProgram::GenerateShaderCode(ShaderHelper& shader) const {
3333
std::string apply_activation = GetActivationSnippet(activation_, "output_value_t", "output_element_t");
3434
ProgramVariableDataType output_var_type = this->Outputs()[0].var_type;
3535
// declare the read and write functions
36-
MatMulReadFnSource(shader, a, b, &batch_dims, /*transA = */ false, /*transB = */ false, is_vec4_);
37-
MatMulWriteFnSource(shader, output, bias, /* is_gemm = */ false, 1, is_vec4_ ? 4 : 1, false, apply_activation, is_channels_last_, need_split_k, output_var_type);
36+
MatMulReadFnSource(shader, a, b, &batch_dims, /*transA = */ false, /*transB = */ false);
37+
MatMulWriteFnSource(shader, output, bias, /* is_gemm = */ false, 1, false, apply_activation, is_channels_last_, need_split_k, output_var_type);
3838
std::string data_type = "a_element_t";
3939
// generate the main function
4040
if (is_vec4_) {
@@ -65,7 +65,7 @@ Status MatMulFillBiasOrZeroBeforeSplitKProgram::GenerateShaderCode(ShaderHelper&
6565
// `use_split_k` is true only when we do the actual MatMul with Split-K.
6666
const uint32_t bias_components = output_components_;
6767
MatMulWriteFnSource(
68-
shader, output, bias, is_gemm_, bias_components, output_components_, bias_is_scalar_,
68+
shader, output, bias, is_gemm_, bias_components, bias_is_scalar_,
6969
/*activation_snippet*/ "", /*is_channels_last*/ true, /*use_split_k*/ false);
7070

7171
shader.MainFunctionBody() << " let output_components = " << output_components_ << ";\n";
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "core/providers/webgpu/vendor/intel/math/gemm.h"
5+
#include "core/providers/webgpu/vendor/intel/math/gemm_subgroup.h"
6+
#include "core/providers/webgpu/math/gemm_utils.h"
7+
8+
namespace onnxruntime {
9+
namespace webgpu {
10+
namespace intel {
11+
12+
Status GemmSubgroupProgram::GenerateShaderCode(ShaderHelper& shader) const {
  // Generates the WGSL for the subgroup-optimized GEMM program.
  // Register the output first so the binding order matches program setup.
  const ShaderVariableHelper& output = shader.AddOutput(
      "output",
      ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);

  // The A*B part is only emitted when both matmul inputs are non-empty
  // (WebGPU cannot bind zero-sized buffers); otherwise only the bias/output
  // handling below is generated.
  if (need_handle_matmul_) {
    const ShaderUsage input_usage = ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias |
                                    ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias;
    const ShaderVariableHelper& a = shader.AddInput("a", input_usage);
    const ShaderVariableHelper& b = shader.AddInput("b", input_usage);

    // Emit the mm_readA/mm_readB helper functions; GEMM is 2-D, so no batch dims.
    MatMulReadFnSource(shader, a, b, nullptr, transA_, transB_);
  }

  // Emit the subgroup-based main matmul source.
  ORT_RETURN_IF_ERROR(MakeMatMulSubgroupSource(shader, elements_per_thread_, nullptr, is_vec4_, transA_, transB_,
                                               alpha_, need_handle_matmul_));

  // `C` (bias) is optional; only bind it when the beta*C term contributes.
  const ShaderVariableHelper* bias = nullptr;
  if (need_handle_bias_) {
    bias = &shader.AddInput("c", ShaderUsage::UseUniform);
  }
  // Emit the mm_write helper (GEMM flavor).
  MatMulWriteFnSource(shader, output, bias, /*is_gemm=*/true, c_components_, c_is_scalar_);

  return Status::OK();
}
36+
37+
bool CanApplyGemmIntel(const ComputeContext& context, int64_t M, int64_t N, int64_t K, bool transA, bool transB) {
38+
return CanApplySubgroup(context, M, N, K, transA, transB);
39+
}
40+
41+
// Runs GEMM (Y = alpha * op(A) * op(B) + beta * C) using the Intel
// subgroup-optimized WebGPU program. Callers are expected to have checked
// CanApplyGemmIntel() first.
// - a, b: matmul inputs (2-D); c: optional bias, may be broadcast or scalar.
// - transA/transB: whether op() transposes A/B.
// - Returns the status of dispatching the program (OK for empty output).
Status ApplyGemmIntel(const Tensor* a,
                      const Tensor* b,
                      const Tensor* c,
                      bool transA,
                      bool transB,
                      float alpha,
                      float beta,
                      ComputeContext& context) {
  const auto& a_shape = a->Shape();
  const auto& b_shape = b->Shape();

  const uint32_t M = onnxruntime::narrow<uint32_t>(transA ? a_shape[1] : a_shape[0]);
  const uint32_t K = onnxruntime::narrow<uint32_t>(transA ? a_shape[0] : a_shape[1]);
  const uint32_t N = onnxruntime::narrow<uint32_t>(transB ? b_shape[0] : b_shape[1]);

  std::vector<int64_t> output_dims{M, N};
  auto* y = context.Output(0, output_dims);
  const int64_t output_size = y->Shape().Size();

  // Nothing to compute for an empty output.
  if (output_size == 0) {
    return Status::OK();
  }

  // WebGPU doesn't support binding a zero-sized buffer, so we need to check if A or B is empty.
  const bool need_handle_matmul = a_shape.Size() > 0 && b_shape.Size() > 0;
  // Explicit comparisons instead of implicit pointer/float-to-bool conversions:
  // bias contributes only when C is provided and beta is non-zero.
  const bool need_handle_bias = c != nullptr && beta != 0.0f;

  const bool is_vec4 = b_shape[1] % 4 == 0;
  // Components for A, B.
  const int a_components = 1;
  const int b_components = is_vec4 ? 4 : 1;
  // Components for Y.
  const int output_components = (is_vec4 && N % 4 == 0) ? 4 : 1;
  // Components for C (stays scalar unless it matches the vectorized output).
  int c_components = 1;

  bool c_is_scalar = false;
  if (need_handle_bias) {
    const auto& c_shape = c->Shape();
    const int64_t c_last_dim = c_shape[c_shape.NumDimensions() - 1];
    // `C` in GEMM might be broadcast to the output, and broadcasting requires the components to be consistent.
    // So we use vec4 for C when its last dimension is N, and the output is also a vec4.
    // Cast N to avoid a signed/unsigned comparison.
    c_components = (c_last_dim == static_cast<int64_t>(N) && output_components == 4) ? 4 : 1;
    c_is_scalar = c_shape.Size() == 1;
  }

  // {x, y, z} elements computed per logical thread; y scales with M and vec4-ness.
  InlinedVector<int64_t> elements_per_thread = InlinedVector<int64_t>({4, intel::ElementsPerThreadY(is_vec4, M), 1});
  // Ceil-divide the output tile grid by (workgroup size * elements per thread).
  const uint32_t dispatch_x = narrow<uint32_t>((N + kSubgroupLogicalWorkGroupSizeX * elements_per_thread[0] - 1) /
                                               (kSubgroupLogicalWorkGroupSizeX * elements_per_thread[0]));
  const uint32_t dispatch_y = narrow<uint32_t>((M + kSubgroupLogicalWorkGroupSizeY * elements_per_thread[1] - 1) /
                                               (kSubgroupLogicalWorkGroupSizeY * elements_per_thread[1]));

  GemmSubgroupProgram program{transA, transB, alpha, need_handle_bias, need_handle_matmul, c_components, c_is_scalar,
                              is_vec4, elements_per_thread};

  if (need_handle_matmul) {
    program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_components},
                       {b, ProgramTensorMetadataDependency::TypeAndRank, b_components}});
  }

  if (need_handle_bias) {
    program.AddInput({c, ProgramTensorMetadataDependency::TypeAndRank, c_components});
  }

  // Cache key covers everything that changes the generated WGSL.
  program.CacheHint(alpha, transA, transB, c_is_scalar, absl::StrJoin(elements_per_thread, "-"))
      .AddOutputs({{y, ProgramTensorMetadataDependency::TypeAndRank, output_components}})
      .SetDispatchGroupSize(dispatch_x, dispatch_y, 1)
      .SetWorkgroupSize(kSubgroupLogicalWorkGroupSizeX * kSubgroupLogicalWorkGroupSizeY, 1, 1)
      .AddUniformVariables({{alpha},
                            {beta},
                            {M}, /* dim_a_outer */
                            {N}, /* dim_b_outer */
                            {K}} /* dim_inner */
      );

  return context.RunProgram(program);
}
118+
119+
} // namespace intel
120+
} // namespace webgpu
121+
} // namespace onnxruntime

0 commit comments

Comments
 (0)