
Commit 586e758

Copilot and justinchuby committed:
Implement ONNX GroupNormalization-21 with proper stash_type and BFloat16 support

Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
1 parent db72187 · commit 586e758

File tree — 2 files changed (+26 −9 lines)

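For context: GroupNormalization splits the C channels of an (N, C, ...) input into num_groups groups and normalizes each group with its own mean and variance. In opset 21 the spec also adds a stash_type attribute (default 1), which asks for the intermediate statistics to be computed in float32 even when the input is a narrower type such as BFloat16. Per group g with |g| elements, with per-channel scale and bias:

    mean_g = sum(x for x in g) / |g|
    var_g  = sum(x*x for x in g) / |g| - mean_g^2
    y      = scale[c] * (x - mean_g) / sqrt(var_g + epsilon) + bias[c]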

onnxruntime/core/providers/cpu/cpu_execution_provider.cc

Lines changed: 4 additions & 0 deletions
@@ -1102,6 +1102,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, Op
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 20, float, GroupNormalization);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 20, double, GroupNormalization);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 20, MLFloat16, GroupNormalization);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 20, BFloat16, GroupNormalization);

 // Opset 19
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 20, Size);
@@ -1248,6 +1249,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, float, GroupNormalization);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, double, GroupNormalization);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, MLFloat16, GroupNormalization);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, BFloat16, GroupNormalization);

 // Opset 22
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, Acos);
@@ -3032,6 +3034,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 20, float, GroupNormalization)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 20, double, GroupNormalization)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 20, MLFloat16, GroupNormalization)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 20, BFloat16, GroupNormalization)>,

       // Opset 19
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 20, Size)>,
@@ -3222,6 +3225,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, float, GroupNormalization)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, double, GroupNormalization)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, MLFloat16, GroupNormalization)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, BFloat16, GroupNormalization)>,

       // Opset 22
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, Cos)>,

onnxruntime/core/providers/cpu/nn/group_norm.cc

Lines changed: 22 additions & 9 deletions
@@ -26,10 +26,12 @@ namespace onnxruntime {
 REGISTER_ONNX_KERNEL_TYPED_VERSIONED(float)
 REGISTER_ONNX_KERNEL_TYPED_VERSIONED(double)
 REGISTER_ONNX_KERNEL_TYPED_VERSIONED(MLFloat16)
+REGISTER_ONNX_KERNEL_TYPED_VERSIONED(BFloat16)

 REGISTER_ONNX_KERNEL_TYPED_21(float)
 REGISTER_ONNX_KERNEL_TYPED_21(double)
 REGISTER_ONNX_KERNEL_TYPED_21(MLFloat16)
+REGISTER_ONNX_KERNEL_TYPED_21(BFloat16)

 GroupNorm::GroupNorm(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel_info) {
   ORT_ENFORCE(op_kernel_info.GetAttr("epsilon", &epsilon_).IsOK());
@@ -86,25 +88,29 @@ Status GroupNorm::ComputeImpl(OpKernelContext* context, const Tensor* X, const T
         const int64_t group_start_channel = group_idx * channels_per_group;
         const int64_t group_end_channel = group_start_channel + channels_per_group;

-        // Calculate mean and variance for this group
+        // Stage 1: Calculate mean and variance in the precision specified by stash_type
+        // According to ONNX spec, when stash_type=1, cast to float32 for computation
         double sum = 0.0;
         double sum_sq = 0.0;
         const int64_t group_size = channels_per_group * spatial_size;

+        // Stage 1: Compute mean and variance (using float32 precision when stash_type=1)
         for (int64_t c = group_start_channel; c < group_end_channel; ++c) {
           const T* channel_data = x_data + batch_idx * C * spatial_size + c * spatial_size;
           for (int64_t s = 0; s < spatial_size; ++s) {
-            const double val = static_cast<double>(channel_data[s]);
-            sum += val;
-            sum_sq += val * val;
+            // Cast to float for precision as per stash_type=1 specification
+            const float val = static_cast<float>(channel_data[s]);
+            sum += static_cast<double>(val);
+            sum_sq += static_cast<double>(val * val);
           }
         }

         const double mean = sum / group_size;
         const double variance = sum_sq / group_size - mean * mean;
         const double inv_std = 1.0 / std::sqrt(variance + static_cast<double>(epsilon_));

-        // Apply normalization: y = scale * (x - mean) / std + bias
+        // Stage 2: Apply normalization with scale and bias (in original precision)
+        // y = scale * (x - mean) / sqrt(variance + epsilon) + bias
         for (int64_t c = group_start_channel; c < group_end_channel; ++c) {
           const T* channel_x_data = x_data + batch_idx * C * spatial_size + c * spatial_size;
           T* channel_y_data = y_data + batch_idx * C * spatial_size + c * spatial_size;
@@ -113,9 +119,13 @@ Status GroupNorm::ComputeImpl(OpKernelContext* context, const Tensor* X, const T
           const T bias_val = bias_data[c];

           for (int64_t s = 0; s < spatial_size; ++s) {
-            const double normalized = (static_cast<double>(channel_x_data[s]) - mean) * inv_std;
-            const double result = normalized * static_cast<double>(scale_val) + static_cast<double>(bias_val);
-            channel_y_data[s] = static_cast<T>(static_cast<float>(result));
+            // Normalize using float32 precision as per stash_type=1
+            const float x_float = static_cast<float>(channel_x_data[s]);
+            const float normalized = (x_float - static_cast<float>(mean)) * static_cast<float>(inv_std);
+
+            // Apply scale and bias in original type precision
+            const float result = normalized * static_cast<float>(scale_val) + static_cast<float>(bias_val);
+            channel_y_data[s] = static_cast<T>(result);
           }
         }
       },
@@ -133,9 +143,12 @@ Status GroupNorm::ComputeHelper(OpKernelContext* context, const Tensor* X, const
     return ComputeImpl<double>(context, X, scale, bias);
   } else if (element_type == DataTypeImpl::GetType<MLFloat16>()) {
     return ComputeImpl<MLFloat16>(context, X, scale, bias);
+  } else if (element_type == DataTypeImpl::GetType<BFloat16>()) {
+    return ComputeImpl<BFloat16>(context, X, scale, bias);
   }

-  return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "GroupNorm only supports float, double, and float16 data types");
+  return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                         "GroupNorm only supports float, double, float16, and bfloat16 data types");
 }

 }  // namespace onnxruntime
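To make the two-stage flow above concrete, here is a minimal standalone sketch of the same computation over plain float buffers. It is an illustration of the algorithm, not onnxruntime code; group_norm_nchw and all parameter names are invented for this example, and it skips the template and ThreadPool machinery the real kernel uses.

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Per group g: y = scale[c] * (x - mean_g) / sqrt(var_g + eps) + bias[c].
void group_norm_nchw(const std::vector<float>& x, std::vector<float>& y,
                     int64_t N, int64_t C, int64_t spatial, int64_t num_groups,
                     const std::vector<float>& scale, const std::vector<float>& bias,
                     float epsilon) {
  const int64_t channels_per_group = C / num_groups;
  for (int64_t n = 0; n < N; ++n) {
    for (int64_t g = 0; g < num_groups; ++g) {
      const int64_t c0 = g * channels_per_group;
      const int64_t c1 = c0 + channels_per_group;
      // Stage 1: accumulate sums in wide precision (the commit casts each
      // element through float per stash_type=1, then accumulates in double).
      double sum = 0.0, sum_sq = 0.0;
      for (int64_t c = c0; c < c1; ++c) {
        const float* ch = x.data() + (n * C + c) * spatial;
        for (int64_t s = 0; s < spatial; ++s) {
          sum += ch[s];
          sum_sq += static_cast<double>(ch[s]) * ch[s];
        }
      }
      const double count = static_cast<double>(channels_per_group * spatial);
      const double mean = sum / count;
      const double inv_std = 1.0 / std::sqrt(sum_sq / count - mean * mean + epsilon);
      // Stage 2: normalize, then apply per-channel scale and bias.
      for (int64_t c = c0; c < c1; ++c) {
        const float* xc = x.data() + (n * C + c) * spatial;
        float* yc = y.data() + (n * C + c) * spatial;
        for (int64_t s = 0; s < spatial; ++s) {
          const float norm = static_cast<float>((xc[s] - mean) * inv_std);
          yc[s] = norm * scale[c] + bias[c];
        }
      }
    }
  }
}

int main() {
  // 1 batch, 4 channels, 2 spatial elements, 2 groups of 2 channels each.
  std::vector<float> x = {1, 2, 3, 4, 5, 6, 7, 8};
  std::vector<float> y(x.size());
  std::vector<float> scale(4, 1.0f), bias(4, 0.0f);
  group_norm_nchw(x, y, 1, 4, 2, 2, scale, bias, 1e-5f);
  for (float v : y) std::printf("%f\n", v);  // each group comes out zero-mean, unit-variance
  return 0;
}

The real kernel additionally casts each T element through float before accumulating, which is what gives BFloat16 and MLFloat16 inputs the float32 semantics that stash_type=1 asks for while the double accumulator only adds headroom.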
