
#pragma once

+ #include <algorithm>
#include "core/framework/tensor_shape.h"
#include "core/common/status.h"
#include "core/common/narrow.h"
+ #include "core/common/inlined_containers.h"
+ #include "core/providers/cpu/nn/layer_norm_macro.h"

namespace onnxruntime {

constexpr const char* kLayerNormInputShapeMismatchError =
-     "Size of scale and bias (if provided) must match X.shape[axis:], "
-     "or scale and bias (with same shape) can be broadcasted to X when axis is 2.";
+     "Scale and (optional) bias must match X.shape[axis:] or be NumPy-broadcastable to it.";

- constexpr const char* kLayerNormInvalidSize = "Size of X.shape[axis:] must be larger than 1, got ";
+ constexpr const char* kLayerNormInvalidSize = "Size of X.shape[axis:] must be at least 1, got ";

constexpr int64_t kLayerNormInvalidInput = -1;

@@ -23,22 +25,29 @@ struct LayerNormParams {
  int64_t scale_size;
  int64_t bias_size;
  int64_t broadcast_param;
+   bool use_generic_broadcast{false};  // true: full NumPy-style broadcast; false: legacy broadcast_param path
+   onnxruntime::InlinedVector<int64_t, 8> x_dims;
+   onnxruntime::InlinedVector<int64_t, 8> x_inner_dims;  // X.shape[axis:]
+   onnxruntime::InlinedVector<int64_t, 8> scale_dims;
+   onnxruntime::InlinedVector<int64_t, 8> bias_dims;
+   onnxruntime::InlinedVector<int64_t, 8> scale_strides;
+   onnxruntime::InlinedVector<int64_t, 8> bias_strides;
+   int64_t axis{0};
+   int64_t last_rank{0};
+   onnxruntime::InlinedVector<int64_t, 8> scale_inner_inc;  // scale strides for inner dims [axis..]
+   onnxruntime::InlinedVector<int64_t, 8> bias_inner_inc;   // bias strides for inner dims [axis..]
+   onnxruntime::InlinedVector<int64_t, 8> x_outer_strides;  // X strides for outer dims [0..axis-1]
};

- // We support broadcasting for axis=2, where the first two dimensions are rows, and the rest are columns.
+ // Fast-path broadcasting for axis = 2:
// When X shape is (B, S, ...), x_row (the index of one row in X) is in the range [0, B * S).
- // We support scale and bias shape like below:
+ // We support the following scale/bias shapes in this path:
// When scale and bias shape is (1, 1, ...) or (...), value of broadcast_param is 0.
// When scale and bias shape is (B, 1, ...), value of broadcast_param is S.
// When scale and bias shape is (B, S, ...), value of broadcast_param is 1.
// When scale and bias shape is (1, S, ...), value of broadcast_param is -S.
-
- // Below is a macro to compute the offset for scale and bias data for a row of X.
- #ifndef LAYER_NORM_SCALE_BIAS_OFFSET
- #define LAYER_NORM_SCALE_BIAS_OFFSET(broadcast_param, x_row, norm_size) \
-   ((broadcast_param == 0) ? 0                                           \
-                           : norm_size * (broadcast_param > 0 ? x_row / broadcast_param : x_row % (-broadcast_param)))
- #endif
+ // For all other NumPy-broadcastable shapes we fall back to the generic
+ // broadcasting path (use_generic_broadcast = true) and ignore broadcast_param.
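+ // For example, with X.shape = (2, 3, 8) and axis = 2 there are B * S = 6 rows;
+ // a scale of shape (2, 1, 8) matches the (B, 1, ...) case, so broadcast_param = S = 3
+ // and rows 0..2 use the first scale block while rows 3..5 use the second.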

class LayerNormHelper {
 public:
@@ -48,30 +57,158 @@ class LayerNormHelper {
                            bool has_bias,
                            int64_t axis,
                            LayerNormParams& params) {
+     // Initialize basic layout parameters: how many rows we have and how many elements
+     // are normalized per row, as well as the total scale/bias sizes.
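+     // For example, X.shape = (2, 3, 8) with axis = 2 gives num_rows = 2 * 3 = 6
+     // and norm_size = 8.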
    params.num_rows = x_shape.SizeToDimension(onnxruntime::narrow<size_t>(axis));
    params.norm_size = x_shape.SizeFromDimension(onnxruntime::narrow<size_t>(axis));
    params.scale_size = scale_shape.Size();
-     params.bias_size = bias_shape.Size();
+     params.bias_size = has_bias ? bias_shape.Size() : 0;
+
    params.broadcast_param = 0;
+     params.axis = axis;

-     if (params.norm_size <= 1) {
+     // Allow norm_size == 1 (scalar normalization is valid according to the ONNX spec).
+     if (params.norm_size < 1) {
      params.broadcast_param = kLayerNormInvalidInput;
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, kLayerNormInvalidSize, params.norm_size);
    } else if (params.scale_size != params.norm_size || (has_bias && params.bias_size != params.scale_size)) {
      params.broadcast_param = GetBroadcastParam(x_shape, scale_shape, has_bias ? &bias_shape : nullptr, axis);
-       if (params.broadcast_param == kLayerNormInvalidInput) {
-         return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                                kLayerNormInputShapeMismatchError,
-                                " X.shape=", x_shape,
-                                " scale.shape=", scale_shape,
-                                " bias.shape=", bias_shape,
-                                " and axis=", axis);
+       // Try to encode simple (B, S, ...) layouts into broadcast_param so that the
+       // fast path can be used. If this fails, broadcast_param will be set to
+       // kLayerNormInvalidInput and we may fall back to generic broadcasting later.
+     }
+     const size_t xr = x_shape.NumDimensions();
+     const size_t sr = scale_shape.NumDimensions();
+     const size_t br = has_bias ? bias_shape.NumDimensions() : 0;
+
+     if (sr > xr || (has_bias && br > xr)) {
+       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                              kLayerNormInputShapeMismatchError,
+                              " Scale/Bias rank cannot exceed Input rank.");
+     }
+
+     params.x_dims.clear();
+     params.x_dims.reserve(xr);
+     for (size_t i = 0; i < xr; ++i) {
+       params.x_dims.push_back(x_shape.GetDims()[i]);
+     }
+
+     // Right-align scale and bias shapes (missing leading dims are treated as 1)
+     params.scale_dims.clear();
+     params.scale_dims.resize(xr, 1);
+     for (size_t i = 0; i < sr; ++i) {
+       params.scale_dims[xr - 1 - i] = scale_shape.GetDims()[sr - 1 - i];
+     }
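+     // For example, with xr = 4 and scale.shape = (3, 8), the right-aligned
+     // scale_dims become (1, 1, 3, 8).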
+
+     params.bias_dims.clear();
+     if (has_bias) {
+       params.bias_dims.resize(xr, 1);
+       for (size_t i = 0; i < br; ++i) {
+         params.bias_dims[xr - 1 - i] = bias_shape.GetDims()[br - 1 - i];
+       }
+     }
+
+     // Validate broadcastability
+     const bool sc_ok = IsNumpyBroadcastable(params.scale_dims, params.x_dims);
+     const bool bi_ok = !has_bias || IsNumpyBroadcastable(params.bias_dims, params.x_dims);
+     if (!sc_ok || !bi_ok) {
+       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                              kLayerNormInputShapeMismatchError,
+                              " X.shape=", x_shape,
+                              " scale.shape=", scale_shape,
+                              " bias.shape=", bias_shape,
+                              " and axis=", axis);
+     }
+
+     // Compute strides for scale/bias once
+     params.scale_strides = MakeStrides(params.scale_dims);
+     params.bias_strides.clear();
+     if (has_bias) {
+       params.bias_strides = MakeStrides(params.bias_dims);
+     }
+
+     // Detect dependency on outer dimensions [0..axis-1]
+     bool outer_dep = false;
+     for (int64_t i = 0; i < axis; ++i) {
+       const size_t idx = static_cast<size_t>(i);
+       if (params.scale_strides[idx] != 0 ||
+           (has_bias && params.bias_strides[idx] != 0)) {
+         outer_dep = true;
+         break;
+       }
+     }
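+     // For example, with axis = 2 and scale_dims = (2, 1, 8), scale_strides = (8, 0, 1);
+     // the nonzero stride on dim 0 means scale varies across the outer (batch)
+     // dimension, so outer_dep becomes true.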
+
+     // Decide if we need the generic NumPy-style broadcasting path
+     params.use_generic_broadcast = outer_dep || (params.broadcast_param == kLayerNormInvalidInput);
+
+     if (params.use_generic_broadcast) {
+       // Cache inner dims X.shape[axis:]
+       params.last_rank = onnxruntime::narrow<int64_t>(xr) - axis;
+       params.x_inner_dims.clear();
+       params.x_inner_dims.reserve(params.last_rank > 0 ? static_cast<size_t>(params.last_rank) : 0);
+       for (size_t i = static_cast<size_t>(axis); i < xr; ++i) {
+         params.x_inner_dims.push_back(params.x_dims[i]);
+       }
+
+       // Precompute inner increments for scale/bias over [axis..]
+       params.scale_inner_inc.clear();
+       params.bias_inner_inc.clear();
+       for (size_t i = static_cast<size_t>(axis); i < xr; ++i) {
+         params.scale_inner_inc.push_back(params.scale_strides[i]);
+         if (has_bias) {
+           params.bias_inner_inc.push_back(params.bias_strides[i]);
+         }
+       }
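+       // For example, with axis = 2 and scale_dims = (1, 1, 3, 8),
+       // scale_strides = (0, 0, 8, 1) and scale_inner_inc = (8, 1).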
+
+       // X outer strides [0..axis-1], used only in the generic path
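+       // For example, x_dims = (2, 3, 4, 5) with axis = 2 gives x_outer_strides = (3, 1),
+       // so outer index (i0, i1) corresponds to row i0 * 3 + i1.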
+       params.x_outer_strides.clear();
+       params.x_outer_strides.resize(static_cast<size_t>(axis), 1);
+       if (axis > 1) {
+         for (int64_t d = axis - 2; d >= 0; --d) {
+           const size_t du = static_cast<size_t>(d);
+           params.x_outer_strides[du] =
+               params.x_outer_strides[du + 1] * params.x_dims[du + 1];
+         }
      }
+     } else {
+       // Fast-path: we don't need inner/outer increments
+       params.last_rank = 0;
+       params.x_inner_dims.clear();
+       params.scale_inner_inc.clear();
+       params.bias_inner_inc.clear();
+       params.x_outer_strides.clear();
    }
+
    return Status::OK();
  }

 private:
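+   // Returns true if shape `a` can be broadcast onto shape `b` of the same rank,
+   // i.e. every a[k] is either 1 or equal to b[k].
+   // For example, (1, 3, 1) broadcasts onto (2, 3, 4), but (2, 3, 4) does not
+   // broadcast onto (1, 3, 4) because a[0] == 2 differs from b[0] == 1.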
+   static bool IsNumpyBroadcastable(gsl::span<const int64_t> a,
+                                    gsl::span<const int64_t> b) {
+     ORT_ENFORCE(a.size() == b.size());
+     for (size_t k = 0; k < a.size(); ++k) {
+       const int64_t ak = a[k];
+       const int64_t bk = b[k];
+       if (!(ak == 1 || ak == bk)) {
+         return false;
+       }
+     }
+     return true;
+   }
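+   // Computes right-to-left strides for `dims`, assigning stride 0 to any dimension
+   // of size 1 so that broadcast dimensions are not advanced.
+   // For example, dims = (1, 3, 1, 5) yields strides = (0, 5, 0, 1).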
+   static InlinedVector<int64_t, 8> MakeStrides(const InlinedVector<int64_t, 8>& dims) {
+     InlinedVector<int64_t, 8> strides(dims.size(), 0);
+     if (dims.empty()) return strides;
+
+     int64_t running = 1;
+     for (ptrdiff_t i = dims.size() - 1; i >= 0; --i) {
+       size_t idx = static_cast<size_t>(i);
+       strides[idx] = (dims[idx] == 1) ? 0 : running;
+       running *= std::max<int64_t>(1, dims[idx]);
+     }
+
+     return strides;
+   }
+
  static int64_t GetBroadcastParam(const TensorShape& x_shape,
                                   const TensorShape& scale_shape,
                                   const TensorShape* bias_shape,