Skip to content

Commit b1e5290

Browse files
committed
[webgpu] im2col matmul
1 parent d55ade0 commit b1e5290

File tree

5 files changed

+534
-7
lines changed

5 files changed

+534
-7
lines changed

onnxruntime/core/providers/webgpu/nn/conv.cc

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// Licensed under the MIT License.
33
#include "core/providers/webgpu/nn/conv.h"
44
#include "core/providers/webgpu/nn/conv2d_mm.h"
5+
#include "core/providers/webgpu/nn/im2col_matmul.h"
56
#include "core/providers/webgpu/shader_helper.h"
67
#include "core/providers/webgpu/webgpu_supported_types.h"
78
#include "core/providers/webgpu/tensor/transpose.h"
@@ -99,10 +100,34 @@ Status Conv<is_channels_last, is_fused>::ComputeInternal(ComputeContext& context
99100
modified_input_output_shapes.push_back(bias->Shape());
100101
}
101102
modified_input_output_shapes.push_back(TensorShape(output_shape_vector));
103+
104+
const auto input_height = input_shape[is_channels_last ? 1 : 2];
105+
const auto input_width = input_shape[is_channels_last ? 2 : 3];
106+
const auto input_channels = input_shape[is_channels_last ? 3 : 1];
107+
const auto kernel_height = kernel_shape[2];
108+
const auto kernel_width = kernel_shape[3];
109+
const auto output_height = output_shape_vector[is_channels_last ? 1 : 2];
110+
const auto output_width = output_shape_vector[is_channels_last ? 2 : 3];
111+
102112
uint32_t auto_pad_adjust = conv_attrs_.auto_pad == AutoPadType::SAME_LOWER ? 1 : 0;
103113
auto pad0 = conv_attrs_.auto_pad == AutoPadType::NOTSET ? pads[0] : (pads[0] + pads[2] + auto_pad_adjust) / 2;
104114
auto pad1 = conv_attrs_.auto_pad == AutoPadType::NOTSET ? pads[1] : (pads[1] + pads[3] + auto_pad_adjust) / 2;
105115
std::vector<uint32_t> updated_pads{pad0, pad1};
116+
117+
if (CanApplyIm2ColMatMulProgram(context,
118+
is_channels_last,
119+
activation_.activation_kind_,
120+
kernel_shape,
121+
conv_attrs_.auto_pad,
122+
onnxruntime::narrow<uint32_t>(conv_attrs_.group))) {
123+
return ApplyIm2ColMatMulProgram(context,
124+
is_channels_last,
125+
dilations,
126+
pads,
127+
strides,
128+
output);
129+
}
130+
106131
if (conv_attrs_.group > 1) {
107132
Tensor transposed_kernel;
108133
if (is_channels_last) {
@@ -128,13 +153,6 @@ Status Conv<is_channels_last, is_fused>::ComputeInternal(ComputeContext& context
128153
}
129154
return context.RunProgram(program);
130155
}
131-
const auto input_height = input_shape[is_channels_last ? 1 : 2];
132-
const auto input_width = input_shape[is_channels_last ? 2 : 3];
133-
const auto input_channels = input_shape[is_channels_last ? 3 : 1];
134-
const auto kernel_height = kernel_shape[2];
135-
const auto kernel_width = kernel_shape[3];
136-
const auto output_height = output_shape_vector[is_channels_last ? 1 : 2];
137-
const auto output_width = output_shape_vector[is_channels_last ? 2 : 3];
138156

139157
const auto same_size = is_channels_last && input_height == kernel_height && input_width == kernel_width && pads[0] == 0 && pads[1] == 0;
140158
if (same_size || (kernel_height == 1 && kernel_width == 1 && pads[0] == 0 && pads[1] == 0 && strides[0] == 1 && strides[1] == 1)) {
Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
#include <algorithm>
#include <array>
#include <iterator>
#include <string>
#include <utility>
#include <vector>

#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "core/providers/webgpu/webgpu_utils.h"
#include "core/providers/webgpu/nn/im2col_matmul.h"
#include "core/providers/webgpu/nn/activation_util.h"
13+
14+
namespace onnxruntime {
15+
namespace webgpu {
16+
17+
namespace {
18+
19+
// Integer ceiling division: smallest q such that q * denominator >= numerator.
// Intended for non-negative values with a positive denominator (all call sites
// pass uint32_t). constexpr so it can also be evaluated in constant expressions.
template <typename T>
inline constexpr T ceil_div(T numerator, T denominator) {
  return (numerator + denominator - 1) / denominator;
}
23+
24+
// Picks a (tile_m, tile_n) pair for the im2col-matmul shader.
// Prefers the larger tile, but falls back when the resulting dispatch would be
// too small (< 128 workgroups) to keep the GPU busy.
std::pair<uint32_t, uint32_t> ChooseTileSize(uint32_t im2col_m, uint32_t im2col_n) {
  // Preferred (tile_m, tile_n) pairs in descending order of preference.
  // constexpr std::array: the original built a std::vector on every call,
  // which heap-allocates for a compile-time-constant table.
  constexpr std::array<std::pair<uint32_t, uint32_t>, 2> kTileSizes = {{
      {32, 64},
      {16, 64},
  }};

  for (const auto& [tile_m, tile_n] : kTileSizes) {
    const uint32_t dispatch_m = (im2col_m + tile_m - 1) / tile_m;
    const uint32_t dispatch_n = (im2col_n + tile_n - 1) / tile_n;

    // Accept the first tiling that yields enough workgroups.
    if (dispatch_m * dispatch_n >= 128) {
      return {tile_m, tile_n};
    }
  }

  // If none of the tile sizes meet the dispatch >= 128 requirement,
  // use the smallest (last) configuration.
  return kTileSizes.back();
}
47+
48+
// Add support for more devices and tile size configurations.
49+
bool IsDeviceSupported(ComputeContext& context) {
50+
const wgpu::AdapterInfo& adapter_info = context.AdapterInfo();
51+
52+
if (adapter_info.vendor == std::string_view("intel")) {
53+
if (adapter_info.architecture == std::string_view("xe-2lpg")) {
54+
return true;
55+
}
56+
}
57+
58+
return false;
59+
}
60+
61+
} // namespace
62+
63+
// Generates the WGSL shader that copies the conv weight from OIHW layout into
// an OHWI-layout tensor (shader body lives in nn/oihw_to_ohwi.wgsl.template).
Status OIHW2OHWIProgram::GenerateShaderCode(ShaderHelper& shader) const {
  const auto& src = shader.AddInput("src", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
  const auto& output = shader.AddOutput("output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);

  return WGSL_TEMPLATE_APPLY(shader, "nn/oihw_to_ohwi.wgsl.template",
                             WGSL_TEMPLATE_VARIABLE(output, output),
                             WGSL_TEMPLATE_VARIABLE(src, src));
}
71+
72+
// Generates the tiled im2col-matmul WGSL shader
// (nn/im2col_matmul.wgsl.template), specialized by bias presence, tile sizes,
// and whether subgroup operations are used.
Status Im2ColMatMulProgram::GenerateShaderCode(ShaderHelper& shader) const {
  const auto& src = shader.AddInput("src", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
  const auto& weight = shader.AddInput("weight", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
  // Bias is optional; it is only declared when the program was constructed
  // with has_bias == true (the template is specialized accordingly below).
  if (has_bias_) {
    shader.AddInput("bias", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
  }
  const auto& output = shader.AddOutput("output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);

  // The WGSL template only implements these tile configurations
  // (see ChooseTileSize in im2col_matmul.cc).
  ORT_ENFORCE(tile_m_ == 16 || tile_m_ == 32, "tile_m must be 16 or 32.");
  ORT_ENFORCE(tile_n_ == 64, "tile_n must be 64.");

  return WGSL_TEMPLATE_APPLY(shader, "nn/im2col_matmul.wgsl.template",
                             WGSL_TEMPLATE_PARAMETER(has_bias, has_bias_),
                             WGSL_TEMPLATE_PARAMETER(tile_m, tile_m_),
                             WGSL_TEMPLATE_PARAMETER(tile_n, tile_n_),
                             WGSL_TEMPLATE_PARAMETER(use_subgroup, use_subgroup_),
                             WGSL_TEMPLATE_VARIABLE(output, output),
                             WGSL_TEMPLATE_VARIABLE(src, src),
                             WGSL_TEMPLATE_VARIABLE(weight, weight));
}
92+
93+
// Runs Conv as an im2col + matmul in two GPU passes:
//   1) transpose the OIHW weight into OHWI layout so that the reduction
//      dimension K = kernel_h * kernel_w * channel_in is contiguous, then
//   2) run a tiled matmul shader over the implicitly im2col'ed input:
//      output[M, N] with M = output_h * output_w, N = channel_out.
// Preconditions (channels-last input, group == 1, no fused activation,
// NOTSET auto_pad, channel_in % 4 == 0) are checked by
// CanApplyIm2ColMatMulProgram before this is called.
Status ApplyIm2ColMatMulProgram(ComputeContext& context,
                                bool is_channels_last,
                                const std::vector<uint32_t>& dilations,
                                const std::vector<uint32_t>& pads,
                                const std::vector<uint32_t>& strides,
                                Tensor* output) {
  const auto* input = context.Input<Tensor>(0);
  const auto* kernel = context.Input<Tensor>(1);
  const bool has_bias = context.InputCount() > 2;
  const auto* bias = has_bias ? context.Input<Tensor>(2) : nullptr;

  // Pass 1: transpose OIHW weight to OHWI.
  TensorShape kernel_shape = kernel->Shape();
  const uint32_t channel_output = onnxruntime::narrow<uint32_t>(kernel_shape[0]);
  const uint32_t channel_input = onnxruntime::narrow<uint32_t>(kernel_shape[1]);
  const uint32_t kernel_height = onnxruntime::narrow<uint32_t>(kernel_shape[2]);
  const uint32_t kernel_width = onnxruntime::narrow<uint32_t>(kernel_shape[3]);

  TensorShape nhwc_kernel_shape{channel_output, kernel_height, kernel_width, channel_input};
  Tensor nhwc_kernel = context.CreateGPUTensor(kernel->DataType(), nhwc_kernel_shape);
  OIHW2OHWIProgram transpose_program{};
  transpose_program.SetWorkgroupSize(64);

  // One workgroup row per output channel, tiled over input channels in
  // chunks of 64 (the workgroup size).
  const uint32_t Ci_tiles = ceil_div(channel_input, 64u);
  transpose_program.SetDispatchGroupSize(channel_output, Ci_tiles);

  transpose_program.AddInput({kernel,
                              ProgramTensorMetadataDependency::TypeAndRank});
  transpose_program.AddOutput({&nhwc_kernel,
                               ProgramTensorMetadataDependency::TypeAndRank});
  transpose_program.AddUniformVariables({{channel_output},
                                         {channel_input},
                                         {kernel_height},
                                         {kernel_width},
                                         {Ci_tiles},
                                         // H_W_tiles covers the kernel_h * kernel_w spatial plane,
                                         // vectorized by 4. BUG FIX: was
                                         // ceil_div(kernel_height * kernel_height, 4u), which is only
                                         // correct for square kernels.
                                         {ceil_div(kernel_height * kernel_width, 4u)}});
  ORT_RETURN_IF_ERROR(context.RunProgram(transpose_program));

  // Pass 2: im2col-matmul.
  const TensorShape input_shape = input->Shape();
  const TensorShape output_shape = output->Shape();

  const uint32_t batch = onnxruntime::narrow<uint32_t>(input_shape[0]);
  const uint32_t input_height = onnxruntime::narrow<uint32_t>(input_shape[is_channels_last ? 1 : 2]);
  const uint32_t input_width = onnxruntime::narrow<uint32_t>(input_shape[is_channels_last ? 2 : 3]);
  const uint32_t output_height = onnxruntime::narrow<uint32_t>(output_shape[is_channels_last ? 1 : 2]);
  const uint32_t output_width = onnxruntime::narrow<uint32_t>(output_shape[is_channels_last ? 2 : 3]);

  // GEMM view of the convolution: output[M, N] = im2col(input)[M, K] x weight[K, N].
  const uint32_t im2col_m = output_height * output_width;
  const uint32_t im2col_k = kernel_height * kernel_width * channel_input;
  const uint32_t im2col_n = channel_output;

  const auto [tile_m, tile_n] = ChooseTileSize(im2col_m, im2col_n);
  const uint32_t workgroup_size = tile_n;
  const bool use_subgroup = true;
  Im2ColMatMulProgram im2col_mm_program{has_bias, tile_m, tile_n, use_subgroup};
  im2col_mm_program.SetWorkgroupSize(workgroup_size);

  // One workgroup per (tile_m x tile_n) output tile, one z-slice per batch.
  const uint32_t M_tiles = ceil_div(im2col_m, tile_m);
  const uint32_t N_tiles = ceil_div(im2col_n, tile_n);
  im2col_mm_program.SetDispatchGroupSize(M_tiles, N_tiles, batch);

  // Input and (transposed) weight are accessed with 4-component vectorization.
  im2col_mm_program.AddInput({input,
                              ProgramTensorMetadataDependency::TypeAndRank,
                              4});
  im2col_mm_program.AddInput({&nhwc_kernel,
                              ProgramTensorMetadataDependency::TypeAndRank,
                              4});
  if (has_bias) {
    im2col_mm_program.AddInput({bias,
                                ProgramTensorMetadataDependency::TypeAndRank});
  }
  im2col_mm_program.AddOutput({output,
                               ProgramTensorMetadataDependency::TypeAndRank});
  im2col_mm_program.AddUniformVariables({{batch},
                                         {input_height},
                                         {input_width},
                                         {channel_input},
                                         {kernel_height},
                                         {kernel_width},
                                         {output_height},
                                         {output_width},
                                         {im2col_m},
                                         {im2col_k},
                                         {im2col_n},
                                         {M_tiles},
                                         {N_tiles},
                                         // K_tiles: K is vec4-packed, then processed 4 packs at a time.
                                         {ceil_div(ceil_div(im2col_k, 4u), 4u)},
                                         {dilations},
                                         {pads},
                                         {strides}});
  // These values specialize the generated shader, so they must be part of the
  // program cache key.
  im2col_mm_program.CacheHint(has_bias, tile_m, tile_n, use_subgroup);

  return context.RunProgram(im2col_mm_program);
}
188+
189+
// Decides whether the im2col-matmul fast path can handle this Conv node.
// Returns false whenever any precondition of ApplyIm2ColMatMulProgram is
// not met, in which case the caller falls back to the generic conv paths.
bool CanApplyIm2ColMatMulProgram(ComputeContext& context,
                                 const bool is_channels_last,
                                 const ActivationKind activation_kind,
                                 const TensorShape kernel_shape,
                                 const AutoPadType auto_pad,
                                 const uint32_t group) {
  if (!IsDeviceSupported(context)) {
    return false;
  }

  // TODO: Support !is_channels_last
  // TODO: Support fuse
  // TODO: Support auto pad
  // TODO: Support group conv
  const bool layout_ok = is_channels_last;
  const bool activation_ok = activation_kind == ActivationKind::None;
  const bool pad_mode_ok = auto_pad == AutoPadType::NOTSET;
  const bool group_ok = group == 1;
  if (!(layout_ok && activation_ok && pad_mode_ok && group_ok)) {
    return false;
  }

  // TODO: Support conv2d_1x1
  const uint32_t kh = onnxruntime::narrow<uint32_t>(kernel_shape[2]);
  const uint32_t kw = onnxruntime::narrow<uint32_t>(kernel_shape[3]);
  if (kh == 1 && kw == 1) {
    return false;
  }

  // TODO: Support channel input vec1
  // The shader reads the input channel dimension with vec4 accesses.
  const uint32_t ci = onnxruntime::narrow<uint32_t>(kernel_shape[1]);
  return ci % 4 == 0;
}
222+
223+
} // namespace webgpu
224+
} // namespace onnxruntime
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
#include <vector>
6+
#include <string>
7+
8+
#include "core/common/inlined_containers.h"
9+
#include "core/framework/tensor_shape.h"
10+
#include "core/framework/tensor.h"
11+
#include "core/providers/cpu/nn/conv_attributes.h"
12+
#include "core/providers/webgpu/program.h"
13+
#include "core/providers/webgpu/webgpu_supported_types.h"
14+
#include "core/providers/webgpu/shader_helper.h"
15+
#include "core/providers/webgpu/webgpu_kernel.h"
16+
#include "core/providers/webgpu/nn/fuse_utils.h"
17+
18+
namespace onnxruntime {
19+
namespace webgpu {
20+
21+
// Transpose OIHW Weight to OHWI
// Shader program that rewrites a 4-D conv weight tensor from OIHW layout
// into OHWI layout (shader source: nn/oihw_to_ohwi.wgsl.template).
class OIHW2OHWIProgram final : public Program<OIHW2OHWIProgram> {
 public:
  OIHW2OHWIProgram() : Program("OIHW2OHWI") {}

  Status GenerateShaderCode(ShaderHelper& shader) const override;

  // Uniforms: the four weight dimensions (O, I, H, W) plus the tile counts
  // the dispatch is sized with (Ci_tiles over I, H_W_tiles over the H*W
  // plane) — see ApplyIm2ColMatMulProgram for the values supplied.
  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
      {"O", ProgramUniformVariableDataType::Uint32},
      {"I", ProgramUniformVariableDataType::Uint32},
      {"H", ProgramUniformVariableDataType::Uint32},
      {"W", ProgramUniformVariableDataType::Uint32},
      {"Ci_tiles", ProgramUniformVariableDataType::Uint32},
      {"H_W_tiles", ProgramUniformVariableDataType::Uint32});
};
36+
37+
// Tiled matmul program that computes Conv via implicit im2col:
// the output is viewed as an [im2col_m x im2col_n] matrix and each
// workgroup produces one (tile_m x tile_n) tile of it
// (shader source: nn/im2col_matmul.wgsl.template).
class Im2ColMatMulProgram final : public Program<Im2ColMatMulProgram> {
 public:
  Im2ColMatMulProgram(bool has_bias,
                      uint32_t tile_m,
                      uint32_t tile_n,
                      bool use_subgroup) : Program("Im2ColMatMul"),
                                           has_bias_(has_bias),
                                           tile_m_(tile_m),
                                           tile_n_(tile_n),
                                           use_subgroup_(use_subgroup) {}

  Status GenerateShaderCode(ShaderHelper& shader) const override;

  // Uniforms: source/kernel/output geometry, the GEMM dimensions
  // (im2col_m/k/n) and their tile counts, and the conv attributes
  // (dilations, pads, strides).
  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
      {"batch", ProgramUniformVariableDataType::Uint32},
      {"src_h", ProgramUniformVariableDataType::Uint32},
      {"src_w", ProgramUniformVariableDataType::Uint32},
      {"channel_i", ProgramUniformVariableDataType::Uint32},
      {"kernel_h", ProgramUniformVariableDataType::Uint32},
      {"kernel_w", ProgramUniformVariableDataType::Uint32},
      {"output_h", ProgramUniformVariableDataType::Uint32},
      {"output_w", ProgramUniformVariableDataType::Uint32},
      {"im2col_m", ProgramUniformVariableDataType::Uint32},
      {"im2col_k", ProgramUniformVariableDataType::Uint32},
      {"im2col_n", ProgramUniformVariableDataType::Uint32},
      {"M_tiles", ProgramUniformVariableDataType::Uint32},
      {"N_tiles", ProgramUniformVariableDataType::Uint32},
      {"K_tiles", ProgramUniformVariableDataType::Uint32},
      {"dilations", ProgramUniformVariableDataType::Uint32},
      {"pads", ProgramUniformVariableDataType::Uint32},
      {"strides", ProgramUniformVariableDataType::Uint32});

 private:
  // Shader specialization parameters; these also form the program cache hint.
  bool has_bias_;

  uint32_t tile_m_;
  uint32_t tile_n_;
  bool use_subgroup_;
};
76+
77+
// Returns true when the im2col-matmul fast path supports this Conv node
// (checks device, layout, fused activation, auto_pad mode, group count,
// kernel size, and input-channel alignment).
bool CanApplyIm2ColMatMulProgram(ComputeContext& context,
                                 const bool is_channels_last,
                                 const ActivationKind activation_kind,
                                 const TensorShape kernel_shape,
                                 const AutoPadType auto_pad,
                                 const uint32_t group);

// Executes Conv as a weight transpose (OIHW -> OHWI) followed by a tiled
// im2col-matmul program, writing the result into `output`. Callers must
// check CanApplyIm2ColMatMulProgram first.
Status ApplyIm2ColMatMulProgram(ComputeContext& context,
                                const bool is_channels_last,
                                const std::vector<uint32_t>& dilations,
                                const std::vector<uint32_t>& pads,
                                const std::vector<uint32_t>& strides,
                                Tensor* output);
90+
91+
} // namespace webgpu
92+
} // namespace onnxruntime

0 commit comments

Comments
 (0)