 // Licensed under the MIT License.

 #include "layer_norm_impl.h"
+#include "layer_norm_helper.h"

 #include "core/common/safeint.h"
 #include "core/framework/tensor.h"
@@ -24,6 +25,7 @@ void ComputeJob(
     const T* bias_data,
     const ptrdiff_t task_idx,
     const int64_t norm_size,
+    const int64_t broadcast_param,
     const float* scale_float_ptr,
     const float* bias_float_ptr,
     float epsilon,
@@ -55,13 +57,16 @@ void ComputeJob(
     mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
   }

-  for (int64_t h = 0; h < norm_size; h++) {
+  // Compute the offset of gamma and beta to support broadcasting.
+  int64_t i = LAYER_NORM_SCALE_BIAS_OFFSET(broadcast_param, task_idx, norm_size);
+
+  for (int64_t h = 0; h < norm_size; h++, i++) {
     if (simplified) {
-      p_output[h] = p_output[h] / mean_square * scale_data[h];
+      p_output[h] = p_output[h] / mean_square * scale_data[i];
     } else if (nullptr == bias_data) {
-      p_output[h] = (p_output[h] - mean) / mean_square * scale_data[h];
+      p_output[h] = (p_output[h] - mean) / mean_square * scale_data[i];
     } else {
-      p_output[h] = (p_output[h] - mean) / mean_square * scale_data[h] + bias_data[h];
+      p_output[h] = (p_output[h] - mean) / mean_square * scale_data[i] + bias_data[i];
     }
   }

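Note: `layer_norm_helper.h` is not part of this diff, so the definition of `LAYER_NORM_SCALE_BIAS_OFFSET` is an assumption. The sketch below is consistent with how the macro is used here, with `broadcast_param` encoding how rows of X map onto rows of scale/bias:

```cpp
// Hypothetical definition, inferred from usage above -- not taken from this diff.
// For X of shape (B, S, ...) flattened to num_rows = B * S rows of norm_size each:
//   broadcast_param == 0  -> scale/bias hold a single row shared by all rows of X
//   broadcast_param == 1  -> scale/bias hold one row per row of X
//   broadcast_param == S  -> scale/bias shape (B, 1, ...): row index is row_idx / S
//   broadcast_param == -S -> scale/bias shape (1, S, ...): row index is row_idx % S
#define LAYER_NORM_SCALE_BIAS_OFFSET(broadcast_param, row_idx, norm_size)       \
  ((broadcast_param) == 0                                                       \
       ? 0                                                                      \
       : (norm_size) * ((broadcast_param) > 0 ? (row_idx) / (broadcast_param)   \
                                              : (row_idx) % (-(broadcast_param))))
```
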
@@ -82,6 +87,7 @@ void ComputeJob(
     const MLFloat16* bias_data,
     const ptrdiff_t task_idx,
     const int64_t norm_size,
+    const int64_t broadcast_param,
     const float* scale_float_ptr,
     const float* bias_float_ptr,
     float epsilon,
@@ -120,13 +126,16 @@ void ComputeJob(
     mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
   }

-  for (size_t h = 0; h < num_elems; h++) {
+  // Compute the offset of gamma and beta to support broadcasting.
+  int64_t i = LAYER_NORM_SCALE_BIAS_OFFSET(broadcast_param, task_idx, norm_size);
+
+  for (size_t h = 0; h < num_elems; h++, i++) {
     if (simplified) {
-      output_float_ptr[h] = output_float_ptr[h] / mean_square * scale_float_ptr[h];
+      output_float_ptr[h] = output_float_ptr[h] / mean_square * scale_float_ptr[i];
     } else if (nullptr == bias_float_ptr) {
-      output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[h];
+      output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[i];
     } else {
-      output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[h] + bias_float_ptr[h];
+      output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[i] + bias_float_ptr[i];
     }
   }

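Under those assumed macro semantics, a quick sanity check of the indexing:

```cpp
// Worked example (using the assumed macro above): X has shape (B, S, H) = (2, 3, 4),
// scale has shape (B, 1, H), so norm_size = 4 and broadcast_param = S = 3.
// Row 4 of X is (batch 1, seq 1); its scale row is the batch-1 row:
//   offset = norm_size * (row_idx / broadcast_param) = 4 * (4 / 3) = 4
// so the loop reads scale_data[4..7].
static_assert(LAYER_NORM_SCALE_BIAS_OFFSET(3, 4, 4) == 4, "batch-broadcast offset");
// With scale shape (1, S, H) instead, broadcast_param = -3 and the rows cycle:
static_assert(LAYER_NORM_SCALE_BIAS_OFFSET(-3, 4, 4) == 4 * (4 % 3), "seq-broadcast offset");
```
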
@@ -161,9 +170,7 @@ LayerNormImpl::LayerNormImpl(const OpKernelInfo& op_kernel_info, bool simplified
       simplified_{simplified},
       contrib_op_{contrib_op},
       prepacked_scale_fp32_data_(nullptr),
-      prepacked_scale_fp32_size_(0),
-      prepacked_bias_fp32_data_(nullptr),
-      prepacked_bias_fp32_size_(0) {
+      prepacked_bias_fp32_data_(nullptr) {
   ORT_ENFORCE(op_kernel_info.GetAttr("axis", &axis_).IsOK());
   ORT_ENFORCE(op_kernel_info.GetAttr<float>("epsilon", &epsilon_).IsOK());
 }
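The matching header change is outside this diff; presumably the two `size_t` counters become `TensorShape` members (names taken from their uses below), which default-construct to an empty shape and so no longer need explicit initialization here:

```cpp
// Assumed declarations in layer_norm_impl.h (header not shown in this diff).
// TensorShape default-constructs to an empty shape, which is why the
// initializer list above no longer sets a size to 0.
IAllocatorUniquePtr<float> prepacked_scale_fp32_data_;
TensorShape prepacked_scale_fp32_shape_;  // replaces size_t prepacked_scale_fp32_size_
IAllocatorUniquePtr<float> prepacked_bias_fp32_data_;
TensorShape prepacked_bias_fp32_shape_;   // replaces size_t prepacked_bias_fp32_size_
```
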
@@ -179,8 +186,8 @@ Status LayerNormImpl::ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, flo
   const T* bias_data = (simplified || nullptr == bias) ? nullptr : bias->Data<T>();

   const TensorShape& x_shape = X->Shape();
-  size_t scale_size = scale ? static_cast<size_t>(scale->Shape().Size()) : prepacked_scale_fp32_size_;
-  size_t bias_size = bias ? static_cast<size_t>(bias->Shape().Size()) : prepacked_bias_fp32_size_;
+  const TensorShape& scale_shape = scale ? scale->Shape() : prepacked_scale_fp32_shape_;
+  const TensorShape& bias_shape = bias ? bias->Shape() : prepacked_bias_fp32_shape_;
   Tensor* Y = p_ctx->Output(0, x_shape);
   T* Y_data = Y->MutableData<T>();

@@ -215,7 +222,7 @@ Status LayerNormImpl::ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, flo

   AllocatorPtr alloc;
   ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc));
-  return ComputeWithoutContext<T, U>(X_data, x_shape, scale_data, scale_size, bias_data, bias_size, Y_data, mean_data,
+  return ComputeWithoutContext<T, U>(X_data, x_shape, scale_data, scale_shape, bias_data, bias_shape, Y_data, mean_data,
                                      inv_std_dev_data, thread_pool, axis, epsilon, simplified, alloc);
 }

@@ -234,10 +241,10 @@ Status LayerNormImpl::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr
234241
   is_packed = false;
   if (input_idx == 1) {  // scale
-    prepacked_scale_fp32_size_ = static_cast<size_t>(tensor.Shape().Size());
+    prepacked_scale_fp32_shape_ = tensor.Shape();
     ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_scale_fp32_data_, is_packed);
   } else if (input_idx == 2) {  // bias
-    prepacked_bias_fp32_size_ = static_cast<size_t>(tensor.Shape().Size());
+    prepacked_bias_fp32_shape_ = tensor.Shape();
     ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_bias_fp32_data_, is_packed);
   }
243250
@@ -249,9 +256,9 @@ Status LayerNormImpl::ComputeWithoutContext(
     const T* X_data,
     const TensorShape& x_shape,
     const T* scale_data,
-    size_t scale_size,
+    const TensorShape& scale_shape,
     const T* bias_data,
-    size_t bias_size,
+    const TensorShape& bias_shape,
     T* Y_data,
     U* mean_data,
     U* inv_std_dev_data,
@@ -260,35 +267,28 @@ Status LayerNormImpl::ComputeWithoutContext(
     float epsilon,
     bool simplified,
     AllocatorPtr alloc) const {
-  int64_t norm_count = x_shape.SizeToDimension(onnxruntime::narrow<size_t>(axis));
-  int64_t norm_size = x_shape.SizeFromDimension(onnxruntime::narrow<size_t>(axis));
-
-  if (static_cast<int64_t>(scale_size) != norm_size || (bias_data && static_cast<int64_t>(bias_size) != norm_size)) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                           "Size of X.shape()[axis:] == ", norm_size,
-                           ". Size of scale and bias (if provided) must match this. Got scale size of ",
-                           scale_size, " and bias size of ", bias_size);
-  }
+  LayerNormParams params;
+  ORT_RETURN_IF_ERROR(LayerNormHelper::CheckInputs(x_shape, scale_shape, bias_shape, bias_data != nullptr, axis, params));

   IAllocatorUniquePtr<float> scale_fp32;
   IAllocatorUniquePtr<float> bias_fp32;
   if constexpr (std::is_same_v<T, MLFloat16>) {
     if (prepacked_scale_fp32_data_ == nullptr) {
-      const size_t num_elems = static_cast<size_t>(norm_size);
+      const size_t num_elems = static_cast<size_t>(params.scale_size);
       scale_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
       MlasConvertHalfToFloatBuffer(scale_data, scale_fp32.get(), num_elems);
     }
     if (prepacked_bias_fp32_data_ == nullptr && bias_data) {
-      const size_t num_elems = static_cast<size_t>(norm_size);
+      const size_t num_elems = static_cast<size_t>(params.bias_size);
       bias_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
       MlasConvertHalfToFloatBuffer(bias_data, bias_fp32.get(), num_elems);
     }
   }

   concurrency::ThreadPool::TryBatchParallelFor(
-      thread_pool, static_cast<int32_t>(norm_count),
+      thread_pool, static_cast<int32_t>(params.num_rows),
       [&](ptrdiff_t task_idx) {
-        ComputeJob(X_data, scale_data, bias_data, task_idx, norm_size,
+        ComputeJob(X_data, scale_data, bias_data, task_idx, params.norm_size, params.broadcast_param,
                    prepacked_scale_fp32_data_ ? prepacked_scale_fp32_data_.get() : scale_fp32.get(),
                    prepacked_bias_fp32_data_ ? prepacked_bias_fp32_data_.get() : bias_fp32.get(),
                    epsilon, simplified, Y_data, mean_data, inv_std_dev_data, alloc);
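
`LayerNormParams` and `LayerNormHelper::CheckInputs` come from the new `layer_norm_helper.h`, which this diff does not include. A minimal sketch, assuming `CheckInputs` folds the old size validation together with derivation of the broadcast parameter; field names are taken from their uses above:

```cpp
// Sketch only -- assumed shape of the new helper, not the actual file.
struct LayerNormParams {
  int64_t num_rows;         // product of X dims before axis (was norm_count)
  int64_t norm_size;        // product of X dims from axis onward
  int64_t scale_size;       // element count of scale
  int64_t bias_size;        // element count of bias (0 when absent)
  int64_t broadcast_param;  // 0 when scale/bias match norm_size exactly
};

class LayerNormHelper {
 public:
  static Status CheckInputs(const TensorShape& x_shape,
                            const TensorShape& scale_shape,
                            const TensorShape& bias_shape,
                            bool has_bias,
                            int64_t axis,
                            LayerNormParams& params) {
    params.num_rows = x_shape.SizeToDimension(onnxruntime::narrow<size_t>(axis));
    params.norm_size = x_shape.SizeFromDimension(onnxruntime::narrow<size_t>(axis));
    params.scale_size = scale_shape.Size();
    params.bias_size = bias_shape.Size();
    params.broadcast_param = 0;
    if (params.scale_size != params.norm_size ||
        (has_bias && params.bias_size != params.norm_size)) {
      // Derive broadcast_param from how scale/bias rows repeat across X's rows,
      // or return an INVALID_ARGUMENT status for shapes that cannot broadcast.
      // (The exact derivation is not recoverable from this diff.)
    }
    return Status::OK();
  }
};
```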