
Commit 5c5a8ce

Add full broadcasting support to LayerNormalization and RMSNormalization (#26613)
### Description

This PR adds full, spec-compliant broadcasting support to both LayerNormalization and RMSNormalization.

Previously, onnxruntime supported only a partial set of broadcasting cases (based on the logic introduced in PR #23297). That implementation handled several cases but did not cover all valid broadcasting scenarios. This PR introduces a complete generic broadcasting path, following the [ONNX specification rules](https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md). The previous implementation is preserved as a fast path and is still used whenever the Scale/Bias shapes match directly.

Main changes:

- Extended the broadcasting logic in layer_norm_helper.h and layer_norm_impl.cc.
- Added full support for all valid broadcasting configurations of Scale and Bias.
- Preserved the previous partial logic as a fast path for exact-match cases.
- Added comprehensive tests to layer_norm_op_test.cc and rms_norm_op_test.cc.

### Motivation and Context

Before this fix, some valid ONNX broadcasting shapes were rejected by LayerNormalization and RMSNormalization. This PR brings the operators into full alignment with the ONNX specification and fixes models that previously failed due to incomplete broadcasting support.

Fixes #26432
Fixes #18184
1 parent 3c06f58 commit 5c5a8ce
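To make the broadcasting rule concrete: scale and bias shapes are right-aligned against X's shape, and each dimension must either equal the corresponding X dimension or be 1. Below is a minimal standalone sketch of that check (the `IsBroadcastableTo` helper is a hypothetical name for illustration, not part of the onnxruntime API):

```cpp
#include <cstdint>
#include <vector>

// NumPy/ONNX multidirectional broadcasting check: right-align `shape`
// against `target`, then require each dimension to be 1 or an exact match.
bool IsBroadcastableTo(const std::vector<int64_t>& shape,
                       const std::vector<int64_t>& target) {
  if (shape.size() > target.size()) return false;
  const size_t offset = target.size() - shape.size();
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] != 1 && shape[i] != target[i + offset]) return false;
  }
  return true;
}

// Example with X.shape = (2, 3, 4, 8):
//   scale (4, 1) -> broadcastable (a case the old partial logic could reject)
//   scale (1, 8) -> broadcastable
//   scale (3, 8) -> rejected (3 is neither 1 nor 4)
```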

File tree: 6 files changed (+1486 −40 lines)

onnxruntime/core/providers/cpu/nn/layer_norm_helper.h

Lines changed: 158 additions & 21 deletions
```diff
@@ -3,17 +3,19 @@
 
 #pragma once
 
+#include <algorithm>
 #include "core/framework/tensor_shape.h"
 #include "core/common/status.h"
 #include "core/common/narrow.h"
+#include "core/common/inlined_containers.h"
+#include "core/providers/cpu/nn/layer_norm_macro.h"
 
 namespace onnxruntime {
 
 constexpr const char* kLayerNormInputShapeMismatchError =
-    "Size of scale and bias (if provided) must match X.shape[axis:], "
-    "or scale and bias (with same shape) can be broadcasted to X when axis is 2.";
+    "Scale and (optional) bias must match X.shape[axis:] or be NumPy-broadcastable to it.";
 
-constexpr const char* kLayerNormInvalidSize = "Size of X.shape[axis:] must be larger than 1, got ";
+constexpr const char* kLayerNormInvalidSize = "Size of X.shape[axis:] must be at least 1, got ";
 
 constexpr int64_t kLayerNormInvalidInput = -1;
 
```

```diff
@@ -23,22 +25,29 @@ struct LayerNormParams {
   int64_t scale_size;
   int64_t bias_size;
   int64_t broadcast_param;
+  bool use_generic_broadcast{false};  // true: full NumPy-style broadcast; false: legacy broadcast_param path
+  onnxruntime::InlinedVector<int64_t, 8> x_dims;
+  onnxruntime::InlinedVector<int64_t, 8> x_inner_dims;  // X.shape[axis:]
+  onnxruntime::InlinedVector<int64_t, 8> scale_dims;
+  onnxruntime::InlinedVector<int64_t, 8> bias_dims;
+  onnxruntime::InlinedVector<int64_t, 8> scale_strides;
+  onnxruntime::InlinedVector<int64_t, 8> bias_strides;
+  int64_t axis{0};
+  int64_t last_rank{0};
+  onnxruntime::InlinedVector<int64_t, 8> scale_inner_inc;  // scale strides for inner dims [axis..]
+  onnxruntime::InlinedVector<int64_t, 8> bias_inner_inc;   // bias strides for inner dims [axis..]
+  onnxruntime::InlinedVector<int64_t, 8> x_outer_strides;  // X strides for outer dims [0..axis-1]
 };
 
-// We support broadcasting for axis=2, where the first two dimensions are rows, and the rest are columns.
+// Fast-path broadcasting for axis = 2:
 // When X shape is (B, S, ...), and x_row (index of one row in X) is in the range of [0, B * S).
-// We support scale and bias shape like below:
+// We support the following scale/bias shapes in this path:
 // When scale and bias shape is (1, 1, ...) or (...), value of broadcast_param is 0.
 // When scale and bias shape is (B, 1, ...), value of broadcast_param is S.
 // When scale and bias shape is (B, S, ...), value of broadcast_param is 1.
 // When scale and bias shape is (1, S, ...), value of broadcast_param is -S.
-
-// Below is a macro to compute the offset for scale and bias data for a row of X.
-#ifndef LAYER_NORM_SCALE_BIAS_OFFSET
-#define LAYER_NORM_SCALE_BIAS_OFFSET(broadcast_param, x_row, norm_size) \
-  ((broadcast_param == 0) ? 0 \
-                          : norm_size * (broadcast_param > 0 ? x_row / broadcast_param : x_row % (-broadcast_param)))
-#endif
+// For all other NumPy-broadcastable shapes we fall back to the generic
+// broadcasting path (use_generic_broadcast = true) and ignore broadcast_param.
 
 class LayerNormHelper {
  public:
```
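The fast path keys everything off `broadcast_param`; the per-row offset arithmetic it implies is the same as in the removed `LAYER_NORM_SCALE_BIAS_OFFSET` macro above. A small worked sketch (`ScaleBiasOffset` is a hypothetical name for illustration, not the onnxruntime kernel code):

```cpp
#include <cstdint>

// Offset into the scale/bias buffer for one row of X, following the
// broadcast_param encoding documented above (same arithmetic as the
// removed LAYER_NORM_SCALE_BIAS_OFFSET macro).
int64_t ScaleBiasOffset(int64_t broadcast_param, int64_t x_row, int64_t norm_size) {
  if (broadcast_param == 0) return 0;  // (1, 1, ...): all rows share one copy
  return norm_size * (broadcast_param > 0
                          ? x_row / broadcast_param       // (B, 1, ...): param = S; (B, S, ...): param = 1
                          : x_row % (-broadcast_param));  // (1, S, ...): param = -S
}

// Worked example: X = (B, S, H) = (2, 3, 4) with axis = 2, so norm_size = 4
// and rows are indexed 0..5. With scale shape (B, 1, H), broadcast_param = S = 3:
// rows 0..2 (batch 0) read scale at offset 0; rows 3..5 (batch 1) at offset 4.
```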
```diff
@@ -48,30 +57,158 @@ class LayerNormHelper {
                             bool has_bias,
                             int64_t axis,
                             LayerNormParams& params) {
+    // Initialize basic layout parameters: how many rows we have and how many elements
+    // are normalized per row, as well as the total scale/bias sizes.
     params.num_rows = x_shape.SizeToDimension(onnxruntime::narrow<size_t>(axis));
     params.norm_size = x_shape.SizeFromDimension(onnxruntime::narrow<size_t>(axis));
     params.scale_size = scale_shape.Size();
-    params.bias_size = bias_shape.Size();
+    params.bias_size = has_bias ? bias_shape.Size() : 0;
+
     params.broadcast_param = 0;
+    params.axis = axis;
 
-    if (params.norm_size <= 1) {
+    // Allow norm_size == 1 (scalar normalization is valid according to ONNX spec).
+    if (params.norm_size < 1) {
       params.broadcast_param = kLayerNormInvalidInput;
       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, kLayerNormInvalidSize, params.norm_size);
     } else if (params.scale_size != params.norm_size || (has_bias && params.bias_size != params.scale_size)) {
       params.broadcast_param = GetBroadcastParam(x_shape, scale_shape, has_bias ? &bias_shape : nullptr, axis);
-      if (params.broadcast_param == kLayerNormInvalidInput) {
-        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                               kLayerNormInputShapeMismatchError,
-                               " X.shape=", x_shape,
-                               " scale.shape=", scale_shape,
-                               " bias.shape=", bias_shape,
-                               " and axis=", axis);
+      // Try to encode simple (B, S, ...) layouts into broadcast_param so that the
+      // fast-path can be used. If this fails, broadcast_param will be set to
+      // kLayerNormInvalidInput and we may fall back to generic broadcasting later.
+    }
+    const size_t xr = x_shape.NumDimensions();
+    const size_t sr = scale_shape.NumDimensions();
+    const size_t br = has_bias ? bias_shape.NumDimensions() : 0;
+
+    if (sr > xr || (has_bias && br > xr)) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             kLayerNormInputShapeMismatchError,
+                             " Scale/Bias rank cannot exceed Input rank.");
+    }
+
+    params.x_dims.clear();
+    params.x_dims.reserve(xr);
+    for (size_t i = 0; i < xr; ++i) {
+      params.x_dims.push_back(x_shape.GetDims()[i]);
+    }
+
+    // Right-align scale and bias shapes
+    params.scale_dims.clear();
+    params.scale_dims.resize(xr, 1);
+    for (size_t i = 0; i < sr; ++i) {
+      params.scale_dims[xr - 1 - i] = scale_shape.GetDims()[sr - 1 - i];
+    }
+
+    params.bias_dims.clear();
+    if (has_bias) {
+      params.bias_dims.resize(xr, 1);
+      for (size_t i = 0; i < br; ++i) {
+        params.bias_dims[xr - 1 - i] = bias_shape.GetDims()[br - 1 - i];
+      }
+    }
+
+    // Validate broadcastability
+    const bool sc_ok = IsNumpyBroadcastable(params.scale_dims, params.x_dims);
+    const bool bi_ok = !has_bias || IsNumpyBroadcastable(params.bias_dims, params.x_dims);
+    if (!sc_ok || !bi_ok) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             kLayerNormInputShapeMismatchError,
+                             " X.shape=", x_shape,
+                             " scale.shape=", scale_shape,
+                             " bias.shape=", bias_shape,
+                             " and axis=", axis);
+    }
+
+    // Compute strides for scale/bias once
+    params.scale_strides = MakeStrides(params.scale_dims);
+    params.bias_strides.clear();
+    if (has_bias) {
+      params.bias_strides = MakeStrides(params.bias_dims);
+    }
+
+    // Detect dependency on outer dimensions [0..axis-1]
+    bool outer_dep = false;
+    for (int64_t i = 0; i < axis; ++i) {
+      const size_t idx = static_cast<size_t>(i);
+      if (params.scale_strides[idx] != 0 ||
+          (has_bias && params.bias_strides[idx] != 0)) {
+        outer_dep = true;
+        break;
+      }
+    }
+
+    // Decide if we need the generic NumPy-style broadcasting path
+    params.use_generic_broadcast = outer_dep || (params.broadcast_param == kLayerNormInvalidInput);
+
+    if (params.use_generic_broadcast) {
+      // Cache inner dims X.shape[axis:]
+      params.last_rank = onnxruntime::narrow<int64_t>(xr) - axis;
+      params.x_inner_dims.clear();
+      params.x_inner_dims.reserve(params.last_rank > 0 ? static_cast<size_t>(params.last_rank) : 0);
+      for (size_t i = static_cast<size_t>(axis); i < xr; ++i) {
+        params.x_inner_dims.push_back(params.x_dims[i]);
+      }
+
+      // Precompute inner increments for scale/bias over [axis..]
+      params.scale_inner_inc.clear();
+      params.bias_inner_inc.clear();
+      for (size_t i = static_cast<size_t>(axis); i < xr; ++i) {
+        params.scale_inner_inc.push_back(params.scale_strides[i]);
+        if (has_bias) {
+          params.bias_inner_inc.push_back(params.bias_strides[i]);
+        }
+      }
+
+      // X outer strides [0..axis-1], used only in generic path
+      params.x_outer_strides.clear();
+      params.x_outer_strides.resize(static_cast<size_t>(axis), 1);
+      if (axis > 1) {
+        for (int64_t d = axis - 2; d >= 0; --d) {
+          const size_t du = static_cast<size_t>(d);
+          params.x_outer_strides[du] =
+              params.x_outer_strides[du + 1] * params.x_dims[du + 1];
+        }
       }
+    } else {
+      // Fast-path: we don't need inner/outer increments
+      params.last_rank = 0;
+      params.x_inner_dims.clear();
+      params.scale_inner_inc.clear();
+      params.bias_inner_inc.clear();
+      params.x_outer_strides.clear();
     }
+
     return Status::OK();
   }
 
  private:
+  static bool IsNumpyBroadcastable(gsl::span<const int64_t> a,
+                                   gsl::span<const int64_t> b) {
+    ORT_ENFORCE(a.size() == b.size());
+    for (size_t k = 0; k < a.size(); ++k) {
+      const int64_t ak = a[k];
+      const int64_t bk = b[k];
+      if (!(ak == 1 || ak == bk)) {
+        return false;
+      }
+    }
+    return true;
+  }
+  static InlinedVector<int64_t, 8> MakeStrides(const InlinedVector<int64_t, 8>& dims) {
+    InlinedVector<int64_t, 8> strides(dims.size(), 0);
+    if (dims.empty()) return strides;
+
+    int64_t running = 1;
+    for (ptrdiff_t i = dims.size() - 1; i >= 0; --i) {
+      size_t idx = static_cast<size_t>(i);
+      strides[idx] = (dims[idx] == 1) ? 0 : running;
+      running *= std::max<int64_t>(1, dims[idx]);
+    }
+
+    return strides;
+  }
+
   static int64_t GetBroadcastParam(const TensorShape& x_shape,
                                    const TensorShape& scale_shape,
                                    const TensorShape* bias_shape,
```
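The zero-stride convention in `MakeStrides` is what drives the generic path: a broadcast dimension (size 1) gets stride 0, so advancing an index along it keeps re-reading the same element. A standalone sketch of that scheme (`MakeStridesSketch` is a hypothetical name mirroring the logic above for illustration):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Row-major strides with the broadcast twist: size-1 dims get stride 0,
// so an index walk along them stays on the same element.
std::vector<int64_t> MakeStridesSketch(const std::vector<int64_t>& dims) {
  std::vector<int64_t> strides(dims.size(), 0);
  int64_t running = 1;
  for (int i = static_cast<int>(dims.size()) - 1; i >= 0; --i) {
    strides[i] = (dims[i] == 1) ? 0 : running;
    running *= (dims[i] < 1) ? 1 : dims[i];  // same guard as std::max<int64_t>(1, ...) above
  }
  return strides;
}

int main() {
  // A scale of shape (3, 1, 5) right-aligned against a rank-4 X becomes (1, 3, 1, 5).
  for (int64_t s : MakeStridesSketch({1, 3, 1, 5})) {
    std::printf("%lld ", static_cast<long long>(s));  // prints: 0 5 0 1
  }
  std::printf("\n");
  return 0;
}
```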
