// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/cpu/nn/group_norm.h"

#include <cmath>

#include "core/providers/common.h"
#include "core/util/math.h"
#include "core/util/math_cpuonly.h"
#include "core/platform/threadpool.h"

namespace onnxruntime {

// Opset 18-20 registrations (without stash_type)
#define REGISTER_ONNX_KERNEL_TYPED_VERSIONED(T)                                                         \
  ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(GroupNormalization, 18, 20, T,                               \
                                           KernelDefBuilder()                                           \
                                               .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()),  \
                                           GroupNorm);

// Opset 21+ registrations (with stash_type)
#define REGISTER_ONNX_KERNEL_TYPED_21(T)                                                      \
  ONNX_CPU_OPERATOR_TYPED_KERNEL(GroupNormalization, 21, T,                                   \
                                 KernelDefBuilder()                                           \
                                     .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()),  \
                                 GroupNorm);

REGISTER_ONNX_KERNEL_TYPED_VERSIONED(float)
REGISTER_ONNX_KERNEL_TYPED_VERSIONED(double)
REGISTER_ONNX_KERNEL_TYPED_VERSIONED(MLFloat16)

REGISTER_ONNX_KERNEL_TYPED_21(float)
REGISTER_ONNX_KERNEL_TYPED_21(double)
REGISTER_ONNX_KERNEL_TYPED_21(MLFloat16)

GroupNorm::GroupNorm(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel_info) {
  ORT_ENFORCE(op_kernel_info.GetAttr("epsilon", &epsilon_).IsOK());
  ORT_ENFORCE(op_kernel_info.GetAttr("num_groups", &num_groups_).IsOK());

  // stash_type is optional in opset 21; default to 1 (float32).
  if (!op_kernel_info.GetAttr("stash_type", &stash_type_).IsOK()) {
    stash_type_ = 1;
  }
}

Status GroupNorm::Compute(OpKernelContext* context) const {
  const Tensor* X = context->Input<Tensor>(0);
  const Tensor* scale = context->Input<Tensor>(1);
  const Tensor* bias = context->Input<Tensor>(2);

  ORT_RETURN_IF_ERROR(ComputeHelper(context, X, scale, bias));
  return Status::OK();
}

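// For each (batch, group) pair, compute the mean and variance over
// channels_per_group * spatial_size elements, then normalize and apply the
// per-channel scale and bias. Statistics are accumulated in double regardless
// of stash_type_.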
template <typename T>
Status GroupNorm::ComputeImpl(OpKernelContext* context, const Tensor* X, const Tensor* scale,
                              const Tensor* bias) const {
  const auto& x_shape = X->Shape();
  const int64_t N = x_shape[0];  // batch size
  const int64_t C = x_shape[1];  // channels

  // Validate that the channel count is divisible by num_groups.
  ORT_RETURN_IF_NOT(C % num_groups_ == 0, "Number of channels must be divisible by num_groups");

  const int64_t channels_per_group = C / num_groups_;

  // Spatial size is the product of all dimensions after the batch and channel dims (H*W*...).
  int64_t spatial_size = 1;
  for (size_t i = 2; i < x_shape.NumDimensions(); ++i) {
    spatial_size *= x_shape[i];
  }
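  // For example, an NCHW input of shape {2, 6, 4, 4} with num_groups = 3 gives
  // channels_per_group = 2 and spatial_size = 16.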

  Tensor* Y = context->Output(0, x_shape);

  const T* x_data = X->Data<T>();
  const T* scale_data = scale->Data<T>();
  const T* bias_data = bias->Data<T>();
  T* y_data = Y->MutableData<T>();

  // Process each batch and group.
  concurrency::ThreadPool* tp = context->GetOperatorThreadPool();

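  // Each (batch, group) pair is independent, so the work is parallelized over
  // N * num_groups items.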
  concurrency::ThreadPool::TryBatchParallelFor(
      tp, static_cast<ptrdiff_t>(N * num_groups_),
      [&](ptrdiff_t idx) {
        const int64_t batch_idx = idx / num_groups_;
        const int64_t group_idx = idx % num_groups_;

        const int64_t group_start_channel = group_idx * channels_per_group;
        const int64_t group_end_channel = group_start_channel + channels_per_group;

        // Calculate the mean and variance for this group.
        double sum = 0.0;
        double sum_sq = 0.0;
        const int64_t group_size = channels_per_group * spatial_size;

        for (int64_t c = group_start_channel; c < group_end_channel; ++c) {
          const T* channel_data = x_data + batch_idx * C * spatial_size + c * spatial_size;
          for (int64_t s = 0; s < spatial_size; ++s) {
            const double val = static_cast<double>(channel_data[s]);
            sum += val;
            sum_sq += val * val;
          }
        }

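        // Population mean and variance in a single pass: Var(x) = E[x^2] - (E[x])^2.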
        const double mean = sum / group_size;
        const double variance = sum_sq / group_size - mean * mean;
        const double inv_std = 1.0 / std::sqrt(variance + static_cast<double>(epsilon_));

        // Apply normalization: y = scale * (x - mean) / std + bias
        for (int64_t c = group_start_channel; c < group_end_channel; ++c) {
          const T* channel_x_data = x_data + batch_idx * C * spatial_size + c * spatial_size;
          T* channel_y_data = y_data + batch_idx * C * spatial_size + c * spatial_size;

          const T scale_val = scale_data[c];
          const T bias_val = bias_data[c];

          for (int64_t s = 0; s < spatial_size; ++s) {
            const double normalized = (static_cast<double>(channel_x_data[s]) - mean) * inv_std;
            const double result = normalized * static_cast<double>(scale_val) + static_cast<double>(bias_val);
            channel_y_data[s] = static_cast<T>(static_cast<float>(result));
          }
        }
      },
      0);

  return Status::OK();
}

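// Dispatch to the typed implementation based on the input element type.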
Status GroupNorm::ComputeHelper(OpKernelContext* context, const Tensor* X, const Tensor* scale,
                                const Tensor* bias) const {
  const auto element_type = X->DataType();

  if (element_type == DataTypeImpl::GetType<float>()) {
    return ComputeImpl<float>(context, X, scale, bias);
  } else if (element_type == DataTypeImpl::GetType<double>()) {
    return ComputeImpl<double>(context, X, scale, bias);
  } else if (element_type == DataTypeImpl::GetType<MLFloat16>()) {
    return ComputeImpl<MLFloat16>(context, X, scale, bias);
  }

  return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                         "GroupNorm only supports float, double, and float16 data types");
}

}  // namespace onnxruntime