Skip to content

Commit db72187

Browse files
Copilot and justinchuby committed
Changes before error encountered
Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
1 parent adee270 commit db72187

File tree

3 files changed

+189
-0
lines changed

3 files changed

+189
-0
lines changed

onnxruntime/core/providers/cpu/cpu_execution_provider.cc

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,6 +1098,11 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, Op
10981098

10991099
#endif
11001100

1101+
// GroupNormalization (opset 18 and 21)
1102+
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 20, float, GroupNormalization);
1103+
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 20, double, GroupNormalization);
1104+
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 20, MLFloat16, GroupNormalization);
1105+
11011106
// Opset 19
11021107
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 20, Size);
11031108
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 21, AveragePool);
@@ -1239,6 +1244,11 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn
12391244
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, uint8_t, QLinearMatMul);
12401245
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, int8_t, QLinearMatMul);
12411246

1247+
// GroupNormalization opset 21 (with stash_type attribute)
1248+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, float, GroupNormalization);
1249+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, double, GroupNormalization);
1250+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, MLFloat16, GroupNormalization);
1251+
12421252
// Opset 22
12431253
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, Acos);
12441254
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, Cos);
@@ -3018,6 +3028,11 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
30183028
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, OptionalGetElement)>,
30193029
#endif
30203030

3031+
// GroupNormalization (opset 18-20)
3032+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 20, float, GroupNormalization)>,
3033+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 20, double, GroupNormalization)>,
3034+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 20, MLFloat16, GroupNormalization)>,
3035+
30213036
// Opset 19
30223037
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 20, Size)>,
30233038
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 21, AveragePool)>,
@@ -3203,6 +3218,11 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
32033218
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, int8_t,
32043219
QLinearMatMul)>,
32053220

3221+
// GroupNormalization (opset 21 with stash_type attribute)
3222+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, float, GroupNormalization)>,
3223+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, double, GroupNormalization)>,
3224+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, MLFloat16, GroupNormalization)>,
3225+
32063226
// Opset 22
32073227
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, Cos)>,
32083228
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, Tan)>,
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "core/providers/cpu/nn/group_norm.h"

#include <cmath>
#include <cstddef>
#include <type_traits>

#include "core/platform/threadpool.h"
#include "core/providers/common.h"
#include "core/util/math.h"
#include "core/util/math_cpuonly.h"
9+
10+
namespace onnxruntime {
11+
12+
// Opset 18-20 registrations (without stash_type)
13+
#define REGISTER_ONNX_KERNEL_TYPED_VERSIONED(T) \
14+
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(GroupNormalization, 18, 20, T, \
15+
KernelDefBuilder() \
16+
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
17+
GroupNorm);
18+
19+
// Opset 21+ registrations (with stash_type)
20+
#define REGISTER_ONNX_KERNEL_TYPED_21(T) \
21+
ONNX_CPU_OPERATOR_TYPED_KERNEL(GroupNormalization, 21, T, \
22+
KernelDefBuilder() \
23+
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
24+
GroupNorm);
25+
26+
REGISTER_ONNX_KERNEL_TYPED_VERSIONED(float)
27+
REGISTER_ONNX_KERNEL_TYPED_VERSIONED(double)
28+
REGISTER_ONNX_KERNEL_TYPED_VERSIONED(MLFloat16)
29+
30+
REGISTER_ONNX_KERNEL_TYPED_21(float)
31+
REGISTER_ONNX_KERNEL_TYPED_21(double)
32+
REGISTER_ONNX_KERNEL_TYPED_21(MLFloat16)
33+
34+
GroupNorm::GroupNorm(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel_info) {
35+
ORT_ENFORCE(op_kernel_info.GetAttr("epsilon", &epsilon_).IsOK());
36+
ORT_ENFORCE(op_kernel_info.GetAttr("num_groups", &num_groups_).IsOK());
37+
38+
// stash_type is optional in opset 21, default to 1 (float32)
39+
if (!op_kernel_info.GetAttr("stash_type", &stash_type_).IsOK()) {
40+
stash_type_ = 1;
41+
}
42+
}
43+
44+
Status GroupNorm::Compute(OpKernelContext* context) const {
45+
const Tensor* X = context->Input<Tensor>(0);
46+
const Tensor* scale = context->Input<Tensor>(1);
47+
const Tensor* bias = context->Input<Tensor>(2);
48+
49+
ORT_RETURN_IF_ERROR(ComputeHelper(context, X, scale, bias));
50+
return Status::OK();
51+
}
52+
53+
template<typename T>
54+
Status GroupNorm::ComputeImpl(OpKernelContext* context, const Tensor* X, const Tensor* scale, const Tensor* bias) const {
55+
const auto& x_shape = X->Shape();
56+
const int64_t N = x_shape[0]; // batch size
57+
const int64_t C = x_shape[1]; // channels
58+
59+
// Validate that channels are divisible by num_groups
60+
ORT_RETURN_IF_NOT(C % num_groups_ == 0, "Number of channels must be divisible by num_groups");
61+
62+
const int64_t channels_per_group = C / num_groups_;
63+
64+
// Calculate spatial dimensions (H*W*... for everything after batch and channel dims)
65+
int64_t spatial_size = 1;
66+
for (size_t i = 2; i < x_shape.NumDimensions(); ++i) {
67+
spatial_size *= x_shape[i];
68+
}
69+
70+
Tensor* Y = context->Output(0, x_shape);
71+
72+
const T* x_data = X->Data<T>();
73+
const T* scale_data = scale->Data<T>();
74+
const T* bias_data = bias->Data<T>();
75+
T* y_data = Y->MutableData<T>();
76+
77+
// Process each batch and group
78+
concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
79+
80+
concurrency::ThreadPool::TryBatchParallelFor(
81+
tp, static_cast<int32_t>(N * num_groups_),
82+
[&](ptrdiff_t idx) {
83+
const int64_t batch_idx = idx / num_groups_;
84+
const int64_t group_idx = idx % num_groups_;
85+
86+
const int64_t group_start_channel = group_idx * channels_per_group;
87+
const int64_t group_end_channel = group_start_channel + channels_per_group;
88+
89+
// Calculate mean and variance for this group
90+
double sum = 0.0;
91+
double sum_sq = 0.0;
92+
const int64_t group_size = channels_per_group * spatial_size;
93+
94+
for (int64_t c = group_start_channel; c < group_end_channel; ++c) {
95+
const T* channel_data = x_data + batch_idx * C * spatial_size + c * spatial_size;
96+
for (int64_t s = 0; s < spatial_size; ++s) {
97+
const double val = static_cast<double>(channel_data[s]);
98+
sum += val;
99+
sum_sq += val * val;
100+
}
101+
}
102+
103+
const double mean = sum / group_size;
104+
const double variance = sum_sq / group_size - mean * mean;
105+
const double inv_std = 1.0 / std::sqrt(variance + static_cast<double>(epsilon_));
106+
107+
// Apply normalization: y = scale * (x - mean) / std + bias
108+
for (int64_t c = group_start_channel; c < group_end_channel; ++c) {
109+
const T* channel_x_data = x_data + batch_idx * C * spatial_size + c * spatial_size;
110+
T* channel_y_data = y_data + batch_idx * C * spatial_size + c * spatial_size;
111+
112+
const T scale_val = scale_data[c];
113+
const T bias_val = bias_data[c];
114+
115+
for (int64_t s = 0; s < spatial_size; ++s) {
116+
const double normalized = (static_cast<double>(channel_x_data[s]) - mean) * inv_std;
117+
const double result = normalized * static_cast<double>(scale_val) + static_cast<double>(bias_val);
118+
channel_y_data[s] = static_cast<T>(static_cast<float>(result));
119+
}
120+
}
121+
},
122+
0);
123+
124+
return Status::OK();
125+
}
126+
127+
Status GroupNorm::ComputeHelper(OpKernelContext* context, const Tensor* X, const Tensor* scale, const Tensor* bias) const {
128+
const auto element_type = X->DataType();
129+
130+
if (element_type == DataTypeImpl::GetType<float>()) {
131+
return ComputeImpl<float>(context, X, scale, bias);
132+
} else if (element_type == DataTypeImpl::GetType<double>()) {
133+
return ComputeImpl<double>(context, X, scale, bias);
134+
} else if (element_type == DataTypeImpl::GetType<MLFloat16>()) {
135+
return ComputeImpl<MLFloat16>(context, X, scale, bias);
136+
}
137+
138+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "GroupNorm only supports float, double, and float16 data types");
139+
}
140+
141+
} // namespace onnxruntime
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "core/common/common.h"
7+
#include "core/framework/op_kernel.h"
8+
#include "core/framework/tensor.h"
9+
10+
namespace onnxruntime {
11+
12+
class GroupNorm final : public OpKernel {
13+
public:
14+
GroupNorm(const OpKernelInfo& op_kernel_info);
15+
Status Compute(OpKernelContext* context) const override;
16+
17+
private:
18+
template<typename T>
19+
Status ComputeImpl(OpKernelContext* context, const Tensor* X, const Tensor* scale, const Tensor* bias) const;
20+
21+
Status ComputeHelper(OpKernelContext* context, const Tensor* X, const Tensor* scale, const Tensor* bias) const;
22+
23+
float epsilon_;
24+
int64_t num_groups_;
25+
int64_t stash_type_;
26+
};
27+
28+
} // namespace onnxruntime

0 commit comments

Comments
 (0)