@@ -26,10 +26,12 @@ namespace onnxruntime {
 REGISTER_ONNX_KERNEL_TYPED_VERSIONED(float)
 REGISTER_ONNX_KERNEL_TYPED_VERSIONED(double)
 REGISTER_ONNX_KERNEL_TYPED_VERSIONED(MLFloat16)
+REGISTER_ONNX_KERNEL_TYPED_VERSIONED(BFloat16)
 
 REGISTER_ONNX_KERNEL_TYPED_21(float)
 REGISTER_ONNX_KERNEL_TYPED_21(double)
 REGISTER_ONNX_KERNEL_TYPED_21(MLFloat16)
+REGISTER_ONNX_KERNEL_TYPED_21(BFloat16)
 
 GroupNorm::GroupNorm(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel_info) {
   ORT_ENFORCE(op_kernel_info.GetAttr("epsilon", &epsilon_).IsOK());
@@ -86,25 +88,29 @@ Status GroupNorm::ComputeImpl(OpKernelContext* context, const Tensor* X, const T
         const int64_t group_start_channel = group_idx * channels_per_group;
         const int64_t group_end_channel = group_start_channel + channels_per_group;
 
-        // Calculate mean and variance for this group
+        // Stage 1: Calculate mean and variance in the precision specified by stash_type
+        // According to the ONNX spec, when stash_type=1, cast to float32 for computation
         double sum = 0.0;
         double sum_sq = 0.0;
         const int64_t group_size = channels_per_group * spatial_size;
 
+        // Accumulate in double so the reduction itself does not add rounding error
         for (int64_t c = group_start_channel; c < group_end_channel; ++c) {
           const T* channel_data = x_data + batch_idx * C * spatial_size + c * spatial_size;
           for (int64_t s = 0; s < spatial_size; ++s) {
-            const double val = static_cast<double>(channel_data[s]);
-            sum += val;
-            sum_sq += val * val;
+            // Round through float32 first so the statistics match stash_type=1 semantics
+            const float val = static_cast<float>(channel_data[s]);
+            sum += static_cast<double>(val);
+            sum_sq += static_cast<double>(val * val);
           }
         }
 
         const double mean = sum / group_size;
         const double variance = sum_sq / group_size - mean * mean;
         const double inv_std = 1.0 / std::sqrt(variance + static_cast<double>(epsilon_));
 
-        // Apply normalization: y = scale * (x - mean) / std + bias
+        // Stage 2: Normalize in float32, then cast the result back to the input type
+        // y = scale * (x - mean) / sqrt(variance + epsilon) + bias
         for (int64_t c = group_start_channel; c < group_end_channel; ++c) {
           const T* channel_x_data = x_data + batch_idx * C * spatial_size + c * spatial_size;
           T* channel_y_data = y_data + batch_idx * C * spatial_size + c * spatial_size;
@@ -113,9 +119,13 @@ Status GroupNorm::ComputeImpl(OpKernelContext* context, const Tensor* X, const T
           const T bias_val = bias_data[c];
 
           for (int64_t s = 0; s < spatial_size; ++s) {
-            const double normalized = (static_cast<double>(channel_x_data[s]) - mean) * inv_std;
-            const double result = normalized * static_cast<double>(scale_val) + static_cast<double>(bias_val);
-            channel_y_data[s] = static_cast<T>(static_cast<float>(result));
+            // Normalize in float32, as per stash_type=1
+            const float x_float = static_cast<float>(channel_x_data[s]);
+            const float normalized = (x_float - static_cast<float>(mean)) * static_cast<float>(inv_std);
+
+            // Apply scale and bias in float32, then cast back to T
+            const float result = normalized * static_cast<float>(scale_val) + static_cast<float>(bias_val);
+            channel_y_data[s] = static_cast<T>(result);
           }
         }
       },
@@ -133,9 +143,12 @@ Status GroupNorm::ComputeHelper(OpKernelContext* context, const Tensor* X, const
     return ComputeImpl<double>(context, X, scale, bias);
   } else if (element_type == DataTypeImpl::GetType<MLFloat16>()) {
     return ComputeImpl<MLFloat16>(context, X, scale, bias);
+  } else if (element_type == DataTypeImpl::GetType<BFloat16>()) {
+    return ComputeImpl<BFloat16>(context, X, scale, bias);
   }
 
-  return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "GroupNorm only supports float, double, and float16 data types");
+  return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                         "GroupNorm only supports float, double, float16, and bfloat16 data types");
 }
 
 } // namespace onnxruntime
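
To make the two-stage stash_type=1 semantics concrete, here is a minimal standalone sketch of the same arithmetic in plain C++. It is not code from this change: group_norm_1d, its single-batch/single-group shape, and the sample values are illustrative assumptions only.

// Sketch of the two-stage GroupNorm computation for one group of float data.
// Stage 1 gathers statistics (rounded through float32, accumulated in double);
// Stage 2 normalizes in float32, mirroring the kernel above.
#include <cmath>
#include <cstdio>
#include <vector>

void group_norm_1d(const std::vector<float>& x, float scale, float bias,
                   float epsilon, std::vector<float>& y) {
  double sum = 0.0, sum_sq = 0.0;
  for (float v : x) {
    const float val = v;  // for wider input types, this is where the float32 cast happens
    sum += static_cast<double>(val);
    sum_sq += static_cast<double>(val * val);
  }
  const double mean = sum / x.size();
  const double variance = sum_sq / x.size() - mean * mean;
  const double inv_std = 1.0 / std::sqrt(variance + static_cast<double>(epsilon));

  y.resize(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    const float normalized = (x[i] - static_cast<float>(mean)) * static_cast<float>(inv_std);
    y[i] = normalized * scale + bias;
  }
}

int main() {
  std::vector<float> x{1.0f, 2.0f, 3.0f, 4.0f}, y;
  group_norm_1d(x, /*scale=*/1.0f, /*bias=*/0.0f, /*epsilon=*/1e-5f, y);
  for (float v : y) std::printf("%f\n", v);  // ~ -1.342, -0.447, 0.447, 1.342
}

With scale=1 and bias=0 the output is the zero-mean, unit-variance sequence one expects from a single group, which is what the kernel produces per group.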
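The new BFloat16 path is exactly why the statistics are kept in float32/double: bfloat16 stores only 7 explicit mantissa bits, roughly 3 decimal digits. Below is a hedged illustration of that precision loss using a simple truncating conversion; float_to_bf16_trunc and bf16_to_float are hypothetical helpers for demonstration, not ORT's BFloat16 class, which may round to nearest instead.

// Illustrative bfloat16 <-> float32 round-trip, showing why per-element data is
// too coarse to accumulate group statistics in directly.
#include <cstdint>
#include <cstdio>
#include <cstring>

static uint16_t float_to_bf16_trunc(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);  // keep sign, exponent, top 7 mantissa bits
}

static float bf16_to_float(uint16_t b) {
  const uint32_t bits = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(bits));
  return f;
}

int main() {
  const float x = 3.14159265f;
  const float roundtrip = bf16_to_float(float_to_bf16_trunc(x));
  std::printf("%f -> %f\n", x, roundtrip);  // 3.141593 -> 3.140625: ~3 digits survive
}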