1+ // Copyright (c) Microsoft Corporation. All rights reserved.
2+ // Licensed under the MIT License.
3+
4+ #include " core/common/inlined_containers.h"
5+ #include " core/providers/webgpu/nn/batch_norm.h"
6+ #include " core/providers/cpu/nn/batch_norm_helper.h"
7+ #include " core/providers/cpu/tensor/utils.h"
8+ #include " core/providers/webgpu/shader_helper.h"
9+ #include " core/providers/webgpu/webgpu_supported_types.h"
10+
11+ namespace onnxruntime {
12+ namespace webgpu {
13+
14+ #define WEBGPU_BATCH_NORM_VERSIONED_KERNEL (start, end, domain, is_nhwc ) \
15+ ONNX_OPERATOR_VERSIONED_KERNEL_EX ( \
16+ BatchNormalization, \
17+ domain, \
18+ start, \
19+ end, \
20+ kWebGpuExecutionProvider , \
21+ (*KernelDefBuilder::Create ()) \
22+ .TypeConstraint(" T" , WebGpuSupportedFloatTypes()), \
23+ BatchNormalization<is_nhwc>);
24+
25+ #define WEBGPU_BATCH_NORM_KERNEL (version, domain, is_nhwc ) \
26+ ONNX_OPERATOR_KERNEL_EX ( \
27+ BatchNormalization, \
28+ domain, \
29+ version, \
30+ kWebGpuExecutionProvider , \
31+ (*KernelDefBuilder::Create ()) \
32+ .TypeConstraint(" T" , WebGpuSupportedFloatTypes()), \
33+ BatchNormalization<is_nhwc>);
34+
35+ WEBGPU_BATCH_NORM_VERSIONED_KERNEL (7 , 8 , kOnnxDomain , false )
36+ WEBGPU_BATCH_NORM_VERSIONED_KERNEL(9 , 13 , kOnnxDomain , false )
37+ WEBGPU_BATCH_NORM_VERSIONED_KERNEL(14 , 14 , kOnnxDomain , false )
38+ WEBGPU_BATCH_NORM_KERNEL(15 , kOnnxDomain , false )
39+
40+ WEBGPU_BATCH_NORM_VERSIONED_KERNEL(7 , 8 , kMSInternalNHWCDomain , true )
41+ WEBGPU_BATCH_NORM_VERSIONED_KERNEL(9 , 13 , kMSInternalNHWCDomain , true )
42+ WEBGPU_BATCH_NORM_VERSIONED_KERNEL(14 , 14 , kMSInternalNHWCDomain , true )
43+ WEBGPU_BATCH_NORM_KERNEL(15 , kMSInternalNHWCDomain , true )
44+
45+ Status BatchNormalizationProgram::GenerateShaderCode(ShaderHelper& shader) const {
46+ const ShaderVariableHelper& input_tensor = shader.AddInput (" input_tensor" , ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
47+ const ShaderVariableHelper& scale = shader.AddInput (" scale" , ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
48+ const ShaderVariableHelper& B = shader.AddInput (" B" , ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
49+ const ShaderVariableHelper& input_mean = shader.AddInput (" input_mean" , ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
50+ const ShaderVariableHelper& input_var = shader.AddInput (" input_var" , ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
51+ const ShaderVariableHelper& output = shader.AddOutput (" output" , ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
52+
53+ shader.MainFunctionBody () << shader.GuardAgainstOutOfBoundsWorkgroupSizes (" uniforms.output_size" )
54+ << " let idx = global_idx * " << components_ << " ;\n "
55+ << " var outputIndices = " << output.OffsetToIndices (" idx" ) << " ;\n " ;
56+ if (spatial_) {
57+ if (input_tensor.Rank () == 1 ) {
58+ shader.MainFunctionBody () << " let cOffset = 0u;\n " ;
59+ } else {
60+ if (format_ == DataLayout::NHWC) {
61+ shader.MainFunctionBody () << " let cOffset = outputIndices[" << input_tensor.Rank () - 1 << " ] / " << components_ << " ;\n " ;
62+ } else {
63+ shader.MainFunctionBody () << " let cOffset = outputIndices[1];\n " ;
64+ }
65+ }
66+ } else {
67+ if (format_ == DataLayout::NCHW) {
68+ shader.MainFunctionBody () << " " << output.IndicesSet (" outputIndices" , " 0" , " 0" ) << " \n "
69+ << " let cOffset = " << output.IndicesToOffset (" outputIndices" ) << " ;\n " ;
70+ } else {
71+ // update C channel
72+ shader.MainFunctionBody () << " var cIndices = scale_indices_t(0);\n "
73+ << " cIndices[0] = outputIndices[" << input_tensor.Rank () - 1 << " ];\n " ;
74+ // update D1 x ... x Dn channels
75+ for (int i = 1 ; i < scale.Rank (); i++) {
76+ shader.MainFunctionBody () << " cIndices[" << i << " ] = outputIndices[" << i << " ];\n " ;
77+ }
78+ shader.MainFunctionBody () << " let cOffset = " << scale.IndicesToOffset (" cIndices" ) << " ;\n " ;
79+ }
80+ }
81+
82+ shader.MainFunctionBody () << " let scale = " << scale.GetByOffset (" cOffset" ) << " ;\n "
83+ << " let B = " << B.GetByOffset (" cOffset" ) << " ;\n "
84+ << " let input_mean = " << input_mean.GetByOffset (" cOffset" ) << " ;\n "
85+ << " let input_var = " << input_var.GetByOffset (" cOffset" ) << " ;\n "
86+ << " let x = " << input_tensor.GetByOffset (" global_idx" ) << " ;\n "
87+ << " let value = (x - input_mean) * inverseSqrt(input_var + " << epsilon_ << " ) * scale + B;\n "
88+ << " " << output.SetByOffset (" global_idx" , " value" ) << " \n " ;
89+
90+ return Status::OK ();
91+ }
92+
93+ template <bool is_nhwc>
94+ Status BatchNormalization<is_nhwc>::ComputeInternal(ComputeContext& context) const {
95+ if (training_mode_) {
96+ return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT, " BatchNormalization trainingMode is not supported yet." );
97+ }
98+
99+ if (context.InputCount () != 5 ) {
100+ return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT, " BatchNormalization requires 5 inputs." );
101+ }
102+
103+ const auto * input_tensor = context.Input (0 );
104+ const TensorShape& input_shape = input_tensor->Shape ();
105+ size_t input_rank = input_shape.NumDimensions ();
106+ const int components = spatial_ ? ((input_shape[input_rank - 1 ] % 4 == 0 ) ? 4 : ((input_shape[input_rank - 1 ] % 2 == 0 ) ? 2 : 1 )) : 1 ;
107+
108+ auto output_dims = input_shape.AsShapeVector ();
109+ TensorShape output_shape (output_dims);
110+ auto * output_tensor = context.Output (0 , output_shape);
111+ int64_t output_size = output_tensor->Shape ().Size () / static_cast <int64_t >(components);
112+
113+ if (output_size == 0 ) {
114+ return Status::OK ();
115+ }
116+
117+ const auto * scale = context.Input <Tensor>(1 );
118+ const auto * B = context.Input <Tensor>(2 );
119+ const auto * input_mean = context.Input <Tensor>(3 );
120+ const auto * input_var = context.Input <Tensor>(4 );
121+
122+ ORT_RETURN_IF_ERROR (BatchNormHelper::ValidateInputs (input_tensor, scale, B, input_mean, input_var, spatial_ == 1 , format_ == DataLayout::NHWC));
123+
124+ BatchNormalizationProgram program{epsilon_, spatial_, format_, static_cast <int64_t >(components)};
125+ program
126+ .AddInputs ({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank},
127+ {scale, ProgramTensorMetadataDependency::TypeAndRank},
128+ {B, ProgramTensorMetadataDependency::TypeAndRank},
129+ {input_mean, ProgramTensorMetadataDependency::TypeAndRank},
130+ {input_var, ProgramTensorMetadataDependency::TypeAndRank}})
131+ .AddOutputs ({output_tensor})
132+ .SetDispatchGroupSize ((output_size + WORKGROUP_SIZE - 1 ) / WORKGROUP_SIZE)
133+ .AddUniformVariables ({{static_cast <uint32_t >(output_size)}});
134+ return context.RunProgram (program);
135+ }
136+
137+ } // namespace webgpu
138+ } // namespace onnxruntime
0 commit comments