Skip to content

Commit 6b4f9c4

Browse files
[WebGPU EP] Batch Norm Implementation (#23525)
Increases operator coverage for webgpu ep. --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 1fce51b commit 6b4f9c4

File tree

4 files changed

+206
-11
lines changed

4 files changed

+206
-11
lines changed
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "core/common/inlined_containers.h"
5+
#include "core/providers/webgpu/nn/batch_norm.h"
6+
#include "core/providers/cpu/nn/batch_norm_helper.h"
7+
#include "core/providers/cpu/tensor/utils.h"
8+
#include "core/providers/webgpu/shader_helper.h"
9+
#include "core/providers/webgpu/webgpu_supported_types.h"
10+
11+
namespace onnxruntime {
12+
namespace webgpu {
13+
14+
// Registers BatchNormalization<is_nhwc> with the WebGPU execution provider through
// ONNX_OPERATOR_VERSIONED_KERNEL_EX, forwarding `domain`, `start` and `end` unchanged.
// The type constraint "T" is restricted to WebGpuSupportedFloatTypes().
#define WEBGPU_BATCH_NORM_VERSIONED_KERNEL(start, end, domain, is_nhwc) \
15+
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
16+
BatchNormalization, \
17+
domain, \
18+
start, \
19+
end, \
20+
kWebGpuExecutionProvider, \
21+
(*KernelDefBuilder::Create()) \
22+
.TypeConstraint("T", WebGpuSupportedFloatTypes()), \
23+
BatchNormalization<is_nhwc>);
24+
25+
// Registers BatchNormalization<is_nhwc> with the WebGPU execution provider through
// ONNX_OPERATOR_KERNEL_EX, forwarding `domain` and the single `version` unchanged.
// The type constraint "T" is restricted to WebGpuSupportedFloatTypes().
#define WEBGPU_BATCH_NORM_KERNEL(version, domain, is_nhwc) \
26+
ONNX_OPERATOR_KERNEL_EX( \
27+
BatchNormalization, \
28+
domain, \
29+
version, \
30+
kWebGpuExecutionProvider, \
31+
(*KernelDefBuilder::Create()) \
32+
.TypeConstraint("T", WebGpuSupportedFloatTypes()), \
33+
BatchNormalization<is_nhwc>);
34+
35+
// kOnnxDomain registrations: is_nhwc = false, so the kernel is BatchNormalization<false>.
// Version ranges 7-8, 9-13 and 14-14 use the versioned macro; version 15 uses the unversioned one.
WEBGPU_BATCH_NORM_VERSIONED_KERNEL(7, 8, kOnnxDomain, false)
36+
WEBGPU_BATCH_NORM_VERSIONED_KERNEL(9, 13, kOnnxDomain, false)
37+
WEBGPU_BATCH_NORM_VERSIONED_KERNEL(14, 14, kOnnxDomain, false)
38+
WEBGPU_BATCH_NORM_KERNEL(15, kOnnxDomain, false)
39+
40+
// kMSInternalNHWCDomain registrations: same version ranges, is_nhwc = true (BatchNormalization<true>).
WEBGPU_BATCH_NORM_VERSIONED_KERNEL(7, 8, kMSInternalNHWCDomain, true)
41+
WEBGPU_BATCH_NORM_VERSIONED_KERNEL(9, 13, kMSInternalNHWCDomain, true)
42+
WEBGPU_BATCH_NORM_VERSIONED_KERNEL(14, 14, kMSInternalNHWCDomain, true)
43+
WEBGPU_BATCH_NORM_KERNEL(15, kMSInternalNHWCDomain, true)
44+
45+
// Writes the main-function body of the BatchNormalization shader into `shader`.
//
// The generated code:
//   1. derives `outputIndices` from `global_idx * components_`;
//   2. computes `cOffset`, the offset used to read scale, B, input_mean and input_var. How it is
//      computed depends on spatial_, format_ and the rank of the input;
//   3. stores (x - input_mean) * inverseSqrt(input_var + epsilon_) * scale + B into the output.
//
// epsilon_ and components_ are written into the shader text as literals.
// Always returns Status::OK().
Status BatchNormalizationProgram::GenerateShaderCode(ShaderHelper& shader) const {
  // Every variable is declared with the same usage flags.
  const auto var_usage = ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias;
  const ShaderVariableHelper& x_var = shader.AddInput("input_tensor", var_usage);
  const ShaderVariableHelper& scale_var = shader.AddInput("scale", var_usage);
  const ShaderVariableHelper& bias_var = shader.AddInput("B", var_usage);
  const ShaderVariableHelper& mean_var = shader.AddInput("input_mean", var_usage);
  const ShaderVariableHelper& variance_var = shader.AddInput("input_var", var_usage);
  const ShaderVariableHelper& y_var = shader.AddOutput("output", var_usage);

  auto& body = shader.MainFunctionBody();
  body << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")
       << " let idx = global_idx * " << components_ << ";\n"
       << " var outputIndices = " << y_var.OffsetToIndices("idx") << ";\n";

  if (!spatial_) {
    if (format_ == DataLayout::NCHW) {
      // Zero the leading index and reuse the flattened remainder as the parameter offset.
      body << " " << y_var.IndicesSet("outputIndices", "0", "0") << "\n"
           << " let cOffset = " << y_var.IndicesToOffset("outputIndices") << ";\n";
    } else {
      // update C channel
      body << " var cIndices = scale_indices_t(0);\n"
           << " cIndices[0] = outputIndices[" << x_var.Rank() - 1 << "];\n";
      // update D1 x ... x Dn channels
      for (int axis = 1; axis < scale_var.Rank(); axis++) {
        body << " cIndices[" << axis << "] = outputIndices[" << axis << "];\n";
      }
      body << " let cOffset = " << scale_var.IndicesToOffset("cIndices") << ";\n";
    }
  } else if (x_var.Rank() == 1) {
    body << " let cOffset = 0u;\n";
  } else if (format_ == DataLayout::NHWC) {
    // Channel is the last axis; divide by the component count used for `idx`.
    body << " let cOffset = outputIndices[" << x_var.Rank() - 1 << "] / " << components_ << ";\n";
  } else {
    body << " let cOffset = outputIndices[1];\n";
  }

  body << " let scale = " << scale_var.GetByOffset("cOffset") << ";\n"
       << " let B = " << bias_var.GetByOffset("cOffset") << ";\n"
       << " let input_mean = " << mean_var.GetByOffset("cOffset") << ";\n"
       << " let input_var = " << variance_var.GetByOffset("cOffset") << ";\n"
       << " let x = " << x_var.GetByOffset("global_idx") << ";\n"
       << " let value = (x - input_mean) * inverseSqrt(input_var + " << epsilon_ << ") * scale + B;\n"
       << " " << y_var.SetByOffset("global_idx", "value") << "\n";

  return Status::OK();
}
92+
93+
template <bool is_nhwc>
94+
Status BatchNormalization<is_nhwc>::ComputeInternal(ComputeContext& context) const {
95+
if (training_mode_) {
96+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "BatchNormalization trainingMode is not supported yet.");
97+
}
98+
99+
if (context.InputCount() != 5) {
100+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "BatchNormalization requires 5 inputs.");
101+
}
102+
103+
const auto* input_tensor = context.Input(0);
104+
const TensorShape& input_shape = input_tensor->Shape();
105+
size_t input_rank = input_shape.NumDimensions();
106+
const int components = spatial_ ? ((input_shape[input_rank - 1] % 4 == 0) ? 4 : ((input_shape[input_rank - 1] % 2 == 0) ? 2 : 1)) : 1;
107+
108+
auto output_dims = input_shape.AsShapeVector();
109+
TensorShape output_shape(output_dims);
110+
auto* output_tensor = context.Output(0, output_shape);
111+
int64_t output_size = output_tensor->Shape().Size() / static_cast<int64_t>(components);
112+
113+
if (output_size == 0) {
114+
return Status::OK();
115+
}
116+
117+
const auto* scale = context.Input<Tensor>(1);
118+
const auto* B = context.Input<Tensor>(2);
119+
const auto* input_mean = context.Input<Tensor>(3);
120+
const auto* input_var = context.Input<Tensor>(4);
121+
122+
ORT_RETURN_IF_ERROR(BatchNormHelper::ValidateInputs(input_tensor, scale, B, input_mean, input_var, spatial_ == 1, format_ == DataLayout::NHWC));
123+
124+
BatchNormalizationProgram program{epsilon_, spatial_, format_, static_cast<int64_t>(components)};
125+
program
126+
.AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank},
127+
{scale, ProgramTensorMetadataDependency::TypeAndRank},
128+
{B, ProgramTensorMetadataDependency::TypeAndRank},
129+
{input_mean, ProgramTensorMetadataDependency::TypeAndRank},
130+
{input_var, ProgramTensorMetadataDependency::TypeAndRank}})
131+
.AddOutputs({output_tensor})
132+
.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
133+
.AddUniformVariables({{static_cast<uint32_t>(output_size)}});
134+
return context.RunProgram(program);
135+
}
136+
137+
} // namespace webgpu
138+
} // namespace onnxruntime
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "core/providers/webgpu/webgpu_kernel.h"
7+
#include "core/providers/webgpu/program.h"
8+
9+
namespace onnxruntime {
10+
namespace webgpu {
11+
12+
// WebGPU program that generates the BatchNormalization shader.
//
// The constructor arguments are stored as-is and read by GenerateShaderCode. The program
// declares a single uniform, `output_size` (Uint32).
class BatchNormalizationProgram final : public Program<BatchNormalizationProgram> {
 public:
  // epsilon:    added to the variance inside inverseSqrt() in the generated shader.
  // spatial:    non-zero selects the spatial code path in GenerateShaderCode.
  // format:     data layout used to locate the channel index.
  // components: multiplier applied to global_idx when computing the output indices.
  BatchNormalizationProgram(float epsilon, int64_t spatial, DataLayout format, int64_t components)
      : Program{"BatchNormalization"},
        epsilon_(epsilon),
        spatial_(spatial),
        format_(format),
        components_(components) {}

  // Emits the shader body; defined in the matching .cc file.
  Status GenerateShaderCode(ShaderHelper& sh) const override;

  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32});

 private:
  float epsilon_;
  int64_t spatial_;
  DataLayout format_;
  int64_t components_;
};
30+
31+
template <bool is_nhwc>
32+
class BatchNormalization final : public WebGpuKernel {
33+
public:
34+
BatchNormalization(const OpKernelInfo& info) : WebGpuKernel(info) {
35+
epsilon_ = info.GetAttrOrDefault<float>("epsilon", 1e-5f);
36+
momentum_ = info.GetAttrOrDefault<float>("momentum", 0.9f);
37+
spatial_ = info.GetAttrOrDefault<int64_t>("spatial", 1);
38+
training_mode_ = info.GetAttrOrDefault<int64_t>("training_mode", 0);
39+
// NCHW for ai.onnx domain, NHWC for com.ms.internal.nhwc domain
40+
format_ = is_nhwc ? DataLayout::NHWC : DataLayout::NCHW;
41+
}
42+
43+
Status ComputeInternal(ComputeContext& context) const override;
44+
45+
private:
46+
float epsilon_;
47+
float momentum_;
48+
int64_t spatial_;
49+
int64_t training_mode_;
50+
DataLayout format_;
51+
};
52+
53+
} // namespace webgpu
54+
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -696,14 +696,14 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
696696
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, If)>,
697697
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, If)>,
698698

699-
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 8, BatchNormalization)>,
700-
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 13, BatchNormalization)>,
701-
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, 14, BatchNormalization)>,
702-
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 15, BatchNormalization)>,
703-
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 7, 8, BatchNormalization)>,
704-
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 9, 13, BatchNormalization)>,
705-
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 14, 14, BatchNormalization)>,
706-
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 15, BatchNormalization)>,
699+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 8, BatchNormalization)>,
700+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 13, BatchNormalization)>,
701+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, 14, BatchNormalization)>,
702+
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 15, BatchNormalization)>,
703+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 7, 8, BatchNormalization)>,
704+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 9, 13, BatchNormalization)>,
705+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 14, 14, BatchNormalization)>,
706+
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 15, BatchNormalization)>,
707707
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 13, CumSum)>,
708708
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, CumSum)>,
709709
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 12, uint8_t, DequantizeLinear)>,

onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -924,7 +924,8 @@ TEST(BatchNormTest, ForwardTrainingTestWithSavedOutputsOpset9) {
924924
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
925925
// TODO(mtavenrath) flakiness of running_mean for CUDA has been fixed, the delta of running_var is still ~0.1
926926
{kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider,
927-
kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider});
927+
kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider,
928+
kWebGpuExecutionProvider});
928929
}
929930

930931
TEST(BatchNormTest, ForwardTrainingTestOpset14) {
@@ -953,7 +954,8 @@ TEST(BatchNormTest, ForwardTrainingTestOpset14) {
953954
// exclude TRT and OpenVINO for same reasons as seen in TestBatchNorm()
954955
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
955956
{kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider,
956-
kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider});
957+
kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider,
958+
kWebGpuExecutionProvider});
957959
}
958960

959961
TEST(BatchNormTest, ForwardTrainingTestOpset15) {
@@ -982,7 +984,8 @@ TEST(BatchNormTest, ForwardTrainingTestOpset15) {
982984
// Same exclusions as the opset 14 test
983985
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
984986
{kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider,
985-
kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider});
987+
kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider,
988+
kWebGpuExecutionProvider});
986989
}
987990
#endif // BATCHNORM_INCLUDE_TRAINING_SUPPORT
988991

0 commit comments

Comments
 (0)