// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/common/span_utils.h"
#include "core/common/inlined_containers.h"
#include "core/providers/cpu/tensor/utils.h"
#include "core/providers/webgpu/tensor/transpose.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "core/providers/webgpu/webgpu_utils.h"

namespace {
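// gsl::span (like std::span) does not define operator==, so span equality is
// checked element-wise here.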
bool AreSpansEqual(gsl::span<const size_t> a, gsl::span<const size_t> b) {
  if (a.size() != b.size()) {
    return false;
  }

  return std::equal(a.begin(), a.end(), b.begin());
}

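// Drops all size-1 dimensions from `shape`, and drops the corresponding
// entries from `adjusted_perm`; the surviving indices still refer to positions
// in the original shape. For example, shape {1, 3, 1, 4} with perm {0, 2, 3, 1}
// yields new_shape {3, 4} and new_perm {3, 1}.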
auto SqueezeShape(const gsl::span<const int64_t>& shape,
                  const gsl::span<const size_t>& adjusted_perm,
                  onnxruntime::TensorShapeVector& new_shape,
                  onnxruntime::TensorShapeVector& new_perm) {
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] != 1) {
      new_shape.push_back(shape[i]);
    }
    if (shape[adjusted_perm[i]] != 1) {
      new_perm.push_back(adjusted_perm[i]);
    }
  }
}
}  // namespace

namespace onnxruntime {
namespace webgpu {
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    // ...

ONNX_OPERATOR_KERNEL_EX(
    // ...
        .TypeConstraint("T", WebGpuSupportedNumberTypes()),
    Transpose);

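// Host-side shader generation for the fast OIHW -> OHWI path: it only declares
// the bindings; the tiled transpose itself lives in the WGSL template.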
Status OIHW2OHWIProgram::GenerateShaderCode(ShaderHelper& shader) const {
  const auto& src = shader.AddInput("src", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
  const auto& output = shader.AddOutput("output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);

  return WGSL_TEMPLATE_APPLY(shader, "tensor/oihw_to_ohwi.wgsl.template",
                             WGSL_TEMPLATE_VARIABLE(output, output),
                             WGSL_TEMPLATE_VARIABLE(src, src));
}

Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const {
  const auto& input = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
  // ...
}

Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContextBase& context,
                              /* ... */) {
  const auto& input_shape = input.Shape();
  const auto& input_dims = input_shape.GetDims();
  int32_t rank = static_cast<int32_t>(input_shape.NumDimensions());
  TensorShapeVector output_dims(rank);

  for (int32_t i = 0; i < rank; i++) {
    output_dims[i] = input_dims[permutations[i]];
  }
  TensorShape output_shape(output_dims);

  // Check if `OIHW2OHWIProgram` can be applied.
  //
  // `OIHW2OHWIProgram` was originally designed to transpose 4D weights from OIHW
  // to OHWI format, utilizing workgroup tiling to maximize bandwidth through
  // coalesced reads and writes. While variable names reflect this origin for
  // simplicity, the shader is now generalized for broader use, supporting any
  // permutation equivalent to {0, 2, 3, 1}.
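  //
  // For example, a 4D input of shape [O, I, H, W] becomes an output of shape
  // [O, H, W, I], with output element (o, h, w, i) read from input (o, i, h, w).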
  //
  // TODO: Extend support to 2D and 3D transpositions.
  if (AreSpansEqual(permutations, AsSpan<const size_t>({0, 2, 3, 1}))) {
    const uint32_t channel_output = onnxruntime::narrow<uint32_t>(input_shape[0]);
    const uint32_t channel_input = onnxruntime::narrow<uint32_t>(input_shape[1]);
    const uint32_t kernel_height = onnxruntime::narrow<uint32_t>(input_shape[2]);
    const uint32_t kernel_width = onnxruntime::narrow<uint32_t>(input_shape[3]);

    // Calculate tiling for the input channel dimension (tiled by 64).
    const uint32_t input_channel_tiles = CeilDiv(channel_input, 64u);
    const uint32_t dispatch_size = channel_output * input_channel_tiles;
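    // One 64-thread workgroup is dispatched per (output channel,
    // input-channel tile) pair.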

    // Threshold check: only apply if the workload is large enough to saturate
    // GPU compute units. For small tensors, the overhead of the transpose
    // outweighs the gain.
    if (dispatch_size >= 128u) {
      OIHW2OHWIProgram transpose_program{};
      transpose_program.SetWorkgroupSize(64);
      transpose_program.SetDispatchGroupSize(dispatch_size);
      transpose_program.AddInput({&input,
                                  ProgramTensorMetadataDependency::TypeAndRank});
      transpose_program.AddOutput({&output,
                                   ProgramTensorMetadataDependency::TypeAndRank});
      transpose_program.AddUniformVariables({{channel_output},
                                             {channel_input},
                                             {kernel_height},
                                             {kernel_width},
                                             {input_channel_tiles},
                                             {CeilDiv(kernel_height * kernel_width, 4u)}});
      return context.RunProgram(transpose_program);
    }
  }

  TensorShapeVector new_shape{};
  TensorShapeVector new_perm{};
  // ...
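  // `use_shared` selects the tiled shared-memory transpose: the squeezed shape
  // must reduce to a plain 2D swap, either directly or after folding the
  // channels of a channels-last / channels-first layout into one dimension.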
  const bool channels_first = new_perm == TensorShapeVector({3, 1, 2});
  const bool use_shared = (new_shape.size() == 2 && new_perm[0] > new_perm[1]) || channels_last || channels_first;
  auto new_input_shape = input_shape;

  if (use_shared) {
    new_input_shape = channels_last
                          ? TensorShape({new_shape[0], new_shape[1] * new_shape[2]})
                          : channels_first
                              ? TensorShape({new_shape[0] * new_shape[1], new_shape[2]})
                              : new_shape;
    output_shape = TensorShape({new_input_shape[1], new_input_shape[0]});
  }

  uint32_t output_size = onnxruntime::narrow<uint32_t>(input_shape.Size());
  // ...
  program
      .CacheHint(absl::StrJoin(permutations, "-"))
      .AddInputs({{&input, ProgramTensorMetadataDependency::TypeAndRank, new_input_shape, 1}})
      .AddOutputs({{&output, ProgramTensorMetadataDependency::None, output_shape, 1}})
      .AddUniformVariables({{output_size}});

  if (use_shared) {
    program.SetWorkgroupSize(TILE_SIZE, TILE_SIZE, 1);
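    // One workgroup per TILE_SIZE x TILE_SIZE tile of the transposed output;
    // (n + TILE_SIZE - 1) / TILE_SIZE is a ceiling division.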
    program.SetDispatchGroupSize(static_cast<uint32_t>((output_shape[1] + TILE_SIZE - 1) / TILE_SIZE),
                                 static_cast<uint32_t>((output_shape[0] + TILE_SIZE - 1) / TILE_SIZE));
  } else {
    program.SetWorkgroupSize(64u);
    // ...