Skip to content

Commit 2aaf21b

Browse files
authored
[webgpu] Optimize generic 4D Transpose using OIHW2OHWI Program (#26942)
### Description

This PR migrates the `OIHW2OHWI` Program from `Im2ColMatMul` to the `Transpose` operator. By centralizing this logic, we leverage the specialized shader to optimize generic 4D transpositions (specifically the {0, 2, 3, 1} permutation pattern) while reducing code duplication. While this shader is capable of supporting 2D/3D transpositions, those optimizations are reserved for follow-up PRs.

### Motivation and Context

See above.
1 parent 62e8ba9 commit 2aaf21b

File tree

5 files changed

+99
-66
lines changed

5 files changed

+99
-66
lines changed

onnxruntime/core/providers/webgpu/nn/im2col_matmul.cc

Lines changed: 5 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#include "core/providers/webgpu/webgpu_utils.h"
88
#include "core/providers/webgpu/nn/im2col_matmul.h"
9+
#include "core/providers/webgpu/nn/conv.h"
910
#include "core/providers/webgpu/nn/activation_util.h"
1011

1112
namespace onnxruntime {
@@ -52,15 +53,6 @@ bool IsDeviceSupported(const ComputeContextBase& context) {
5253

5354
} // namespace
5455

55-
Status OIHW2OHWIProgram::GenerateShaderCode(ShaderHelper& shader) const {
56-
const auto& src = shader.AddInput("src", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
57-
const auto& output = shader.AddOutput("output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
58-
59-
return WGSL_TEMPLATE_APPLY(shader, "nn/oihw_to_ohwi.wgsl.template",
60-
WGSL_TEMPLATE_VARIABLE(output, output),
61-
WGSL_TEMPLATE_VARIABLE(src, src));
62-
}
63-
6456
Status Im2ColMatMulProgram::GenerateShaderCode(ShaderHelper& shader) const {
6557
const auto& src = shader.AddInput("src", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
6658
const auto& weight = shader.AddInput("weight", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
@@ -93,34 +85,16 @@ Status ApplyIm2ColMatMulProgram(ComputeContext& context,
9385
const bool has_bias = context.InputCount() > 2;
9486
const auto* bias = has_bias ? context.Input<Tensor>(2) : nullptr;
9587

96-
// Transpose OIHW Weight to OHWI
97-
// TODO: Move to `Transpose`
98-
// TODO: Use prepack
9988
TensorShape weight_shape = weight->Shape();
10089
const uint32_t channel_output = onnxruntime::narrow<uint32_t>(weight_shape[0]);
10190
const uint32_t channel_input = onnxruntime::narrow<uint32_t>(weight_shape[1]);
10291
const uint32_t kernel_height = onnxruntime::narrow<uint32_t>(weight_shape[2]);
10392
const uint32_t kernel_width = onnxruntime::narrow<uint32_t>(weight_shape[3]);
10493

105-
TensorShape ohwi_weight_shape{channel_output, kernel_height, kernel_width, channel_input};
106-
Tensor ohwi_weight = context.CreateGPUTensor(weight->DataType(), ohwi_weight_shape);
107-
OIHW2OHWIProgram transpose_program{};
108-
transpose_program.SetWorkgroupSize(64);
109-
110-
const uint32_t Ci_tiles = CeilDiv(channel_input, 64u);
111-
transpose_program.SetDispatchGroupSize(channel_output, Ci_tiles);
112-
113-
transpose_program.AddInput({weight,
114-
ProgramTensorMetadataDependency::TypeAndRank});
115-
transpose_program.AddOutput({&ohwi_weight,
116-
ProgramTensorMetadataDependency::TypeAndRank});
117-
transpose_program.AddUniformVariables({{channel_output},
118-
{channel_input},
119-
{kernel_height},
120-
{kernel_width},
121-
{Ci_tiles},
122-
{CeilDiv(kernel_height * kernel_height, 4u)}});
123-
ORT_RETURN_IF_ERROR(context.RunProgram(transpose_program));
94+
// Transpose OIHW Weight to OHWI
95+
// TODO: Use prepack
96+
Tensor ohwi_weight;
97+
ORT_RETURN_IF_ERROR(TransposeKernel(context, weight, weight->Shape(), &ohwi_weight, {0, 2, 3, 1}));
12498

12599
// im2col-matmul
126100
const TensorShape src_shape = src->Shape();

onnxruntime/core/providers/webgpu/nn/im2col_matmul.h

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,6 @@
1818
namespace onnxruntime {
1919
namespace webgpu {
2020

21-
// Transpose OIHW Weight to OHWI
22-
class OIHW2OHWIProgram final : public Program<OIHW2OHWIProgram> {
23-
public:
24-
OIHW2OHWIProgram() : Program("OIHW2OHWI") {}
25-
26-
Status GenerateShaderCode(ShaderHelper& shader) const override;
27-
28-
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
29-
{"O", ProgramUniformVariableDataType::Uint32},
30-
{"I", ProgramUniformVariableDataType::Uint32},
31-
{"H", ProgramUniformVariableDataType::Uint32},
32-
{"W", ProgramUniformVariableDataType::Uint32},
33-
{"Ci_tiles", ProgramUniformVariableDataType::Uint32},
34-
{"H_W_tiles", ProgramUniformVariableDataType::Uint32});
35-
};
36-
3721
class Im2ColMatMulProgram final : public Program<Im2ColMatMulProgram> {
3822
public:
3923
Im2ColMatMulProgram(bool has_bias,

onnxruntime/core/providers/webgpu/nn/oihw_to_ohwi.wgsl.template renamed to onnxruntime/core/providers/webgpu/tensor/oihw_to_ohwi.wgsl.template

File renamed without changes.

onnxruntime/core/providers/webgpu/tensor/transpose.cc

Lines changed: 78 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
22
// Licensed under the MIT License.
33

4+
#include "core/common/span_utils.h"
45
#include "core/common/inlined_containers.h"
56
#include "core/providers/cpu/tensor/utils.h"
67
#include "core/providers/webgpu/tensor/transpose.h"
@@ -9,6 +10,30 @@
910
#include "core/providers/webgpu/webgpu_supported_types.h"
1011
#include "core/providers/webgpu/webgpu_utils.h"
1112

13+
namespace {
14+
bool AreSpansEqual(gsl::span<const size_t> a, gsl::span<const size_t> b) {
15+
if (a.size() != b.size()) {
16+
return false;
17+
}
18+
19+
return std::equal(a.begin(), a.end(), b.begin());
20+
}
21+
22+
auto SqueezeShape(const gsl::span<const int64_t>& shape,
23+
const gsl::span<const size_t>& adjusted_perm,
24+
onnxruntime::TensorShapeVector& new_shape,
25+
onnxruntime::TensorShapeVector& new_perm) {
26+
for (size_t i = 0; i < shape.size(); ++i) {
27+
if (shape[i] != 1) {
28+
new_shape.push_back(shape[i]);
29+
}
30+
if (shape[adjusted_perm[i]] != 1) {
31+
new_perm.push_back(adjusted_perm[i]);
32+
}
33+
}
34+
};
35+
} // namespace
36+
1237
namespace onnxruntime {
1338
namespace webgpu {
1439
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
@@ -47,19 +72,14 @@ ONNX_OPERATOR_KERNEL_EX(
4772
.TypeConstraint("T", WebGpuSupportedNumberTypes()),
4873
Transpose);
4974

50-
auto SqueezeShape(const gsl::span<const int64_t>& shape,
51-
const gsl::span<const size_t>& adjusted_perm,
52-
TensorShapeVector& new_shape,
53-
TensorShapeVector& new_perm) {
54-
for (size_t i = 0; i < shape.size(); ++i) {
55-
if (shape[i] != 1) {
56-
new_shape.push_back(shape[i]);
57-
}
58-
if (shape[adjusted_perm[i]] != 1) {
59-
new_perm.push_back(adjusted_perm[i]);
60-
}
61-
}
62-
};
75+
Status OIHW2OHWIProgram::GenerateShaderCode(ShaderHelper& shader) const {
76+
const auto& src = shader.AddInput("src", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
77+
const auto& output = shader.AddOutput("output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
78+
79+
return WGSL_TEMPLATE_APPLY(shader, "tensor/oihw_to_ohwi.wgsl.template",
80+
WGSL_TEMPLATE_VARIABLE(output, output),
81+
WGSL_TEMPLATE_VARIABLE(src, src));
82+
}
6383

6484
Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const {
6585
const auto& input = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
@@ -106,12 +126,52 @@ Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContextBase& context,
106126
const auto& input_shape = input.Shape();
107127
const auto& input_dims = input_shape.GetDims();
108128
int32_t rank = static_cast<int32_t>(input_shape.NumDimensions());
109-
110129
TensorShapeVector output_dims(rank);
111130

112131
for (int32_t i = 0; i < rank; i++) {
113132
output_dims[i] = input_dims[permutations[i]];
114133
}
134+
TensorShape output_shape(output_dims);
135+
136+
// Check if `OIHW2OHWIProgram` can be applied.
137+
//
138+
// `OIHW2OHWIProgram` was originally designed to transpose 4D weights from OIHW
139+
// to OHWI format, utilizing workgroup tiling to maximize bandwidth through
140+
// coalesced reads and writes. While variable names reflect this origin for
141+
// simplicity, the shader is now generalized for broader use, supporting any
142+
// permutation equivalent to {0, 2, 3, 1}.
143+
//
144+
// TODO: Extend support to 2D and 3D transpositions.
145+
if (AreSpansEqual(permutations, AsSpan<const size_t>({0, 2, 3, 1}))) {
146+
const uint32_t channel_output = onnxruntime::narrow<uint32_t>(input_shape[0]);
147+
const uint32_t channel_input = onnxruntime::narrow<uint32_t>(input_shape[1]);
148+
const uint32_t kernel_height = onnxruntime::narrow<uint32_t>(input_shape[2]);
149+
const uint32_t kernel_width = onnxruntime::narrow<uint32_t>(input_shape[3]);
150+
151+
// Calculate tiling for the input channel dimension (tiled by 64)
152+
const uint32_t input_channel_tiles = CeilDiv(channel_input, 64u);
153+
const uint32_t dispatch_size = channel_output * input_channel_tiles;
154+
155+
// Threshold check: Only apply if the workload is large enough to saturate
156+
// GPU compute units. For small tensors, the overhead of the transpose
157+
// outweighs the gain.
158+
if (dispatch_size >= 128u) {
159+
OIHW2OHWIProgram transpose_program{};
160+
transpose_program.SetWorkgroupSize(64);
161+
transpose_program.SetDispatchGroupSize(dispatch_size);
162+
transpose_program.AddInput({&input,
163+
ProgramTensorMetadataDependency::TypeAndRank});
164+
transpose_program.AddOutput({&output,
165+
ProgramTensorMetadataDependency::TypeAndRank});
166+
transpose_program.AddUniformVariables({{channel_output},
167+
{channel_input},
168+
{kernel_height},
169+
{kernel_width},
170+
{input_channel_tiles},
171+
{CeilDiv(kernel_height * kernel_width, 4u)}});
172+
return context.RunProgram(transpose_program);
173+
}
174+
}
115175

116176
TensorShapeVector new_shape{};
117177
TensorShapeVector new_perm{};
@@ -120,15 +180,14 @@ Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContextBase& context,
120180
const bool channels_first = new_perm == TensorShapeVector({3, 1, 2});
121181
const bool use_shared = (new_shape.size() == 2 && new_perm[0] > new_perm[1]) || channels_last || channels_first;
122182
auto new_input_shape = input_shape;
123-
TensorShape new_output_shape(output_dims);
124183

125184
if (use_shared) {
126185
new_input_shape = channels_last
127186
? TensorShape({new_shape[0], new_shape[1] * new_shape[2]})
128187
: channels_first
129188
? TensorShape({new_shape[0] * new_shape[1], new_shape[2]})
130189
: new_shape;
131-
new_output_shape = TensorShape({new_input_shape[1], new_input_shape[0]});
190+
output_shape = TensorShape({new_input_shape[1], new_input_shape[0]});
132191
}
133192

134193
uint32_t output_size = onnxruntime::narrow<uint32_t>(input_shape.Size());
@@ -137,13 +196,13 @@ Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContextBase& context,
137196
program
138197
.CacheHint(absl::StrJoin(permutations, "-"))
139198
.AddInputs({{&input, ProgramTensorMetadataDependency::TypeAndRank, new_input_shape, 1}})
140-
.AddOutputs({{&output, ProgramTensorMetadataDependency::None, new_output_shape, 1}})
199+
.AddOutputs({{&output, ProgramTensorMetadataDependency::None, output_shape, 1}})
141200
.AddUniformVariables({{output_size}});
142201

143202
if (use_shared) {
144203
program.SetWorkgroupSize(TILE_SIZE, TILE_SIZE, 1);
145-
program.SetDispatchGroupSize(static_cast<uint32_t>((new_output_shape[1] + TILE_SIZE - 1) / TILE_SIZE),
146-
static_cast<uint32_t>(((new_output_shape[0] + TILE_SIZE - 1) / TILE_SIZE)));
204+
program.SetDispatchGroupSize(static_cast<uint32_t>((output_shape[1] + TILE_SIZE - 1) / TILE_SIZE),
205+
static_cast<uint32_t>(((output_shape[0] + TILE_SIZE - 1) / TILE_SIZE)));
147206
} else {
148207
program.SetWorkgroupSize(64u);
149208

onnxruntime/core/providers/webgpu/tensor/transpose.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,22 @@
1111
namespace onnxruntime {
1212
namespace webgpu {
1313

14+
// Transpose OIHW Weight to OHWI
15+
class OIHW2OHWIProgram final : public Program<OIHW2OHWIProgram> {
16+
public:
17+
OIHW2OHWIProgram() : Program("OIHW2OHWI") {}
18+
19+
Status GenerateShaderCode(ShaderHelper& shader) const override;
20+
21+
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
22+
{"O", ProgramUniformVariableDataType::Uint32},
23+
{"I", ProgramUniformVariableDataType::Uint32},
24+
{"H", ProgramUniformVariableDataType::Uint32},
25+
{"W", ProgramUniformVariableDataType::Uint32},
26+
{"Ci_tiles", ProgramUniformVariableDataType::Uint32},
27+
{"H_W_tiles", ProgramUniformVariableDataType::Uint32});
28+
};
29+
1430
class Transpose final : public WebGpuKernel, public TransposeBase {
1531
public:
1632
Transpose(const OpKernelInfo& info) : WebGpuKernel{info}, TransposeBase{info} {

0 commit comments

Comments
 (0)