@@ -57,16 +57,8 @@ void ComputeJob(
5757 mean_square = sqrt (mean_square / norm_size - mean * mean + epsilon);
5858 }
5959
60- // When X shape is (B, S, ...), and task_idx is in the range of [0, B * S).
61- // We support scale and bias shape like below:
62- // When scale and bias shape is (1, 1, ...) or (...), value of broadcast_param is 0.
63- // When scale and bias shape is (B, 1, ...), value of broadcast_param is S.
64- // When scale and bias shape is (B, S, ...), value of broadcast_param is 1.
65- // When scale and bias shape is (1, S, ...), value of broadcast_param is -S.
66- // Here we compute the initial index for scale and bias data.
67- int64_t i = (broadcast_param == 0 )
68- ? 0
69- : norm_size * (broadcast_param > 0 ? (task_idx / broadcast_param) : (task_idx % (-broadcast_param)));
60+ // Compute the offset of gamma and beta to support broadcasting.
61+ int64_t i = LAYER_NORM_SCALE_BIAS_OFFSET (broadcast_param, task_idx, norm_size);
7062
7163 for (int64_t h = 0 ; h < norm_size; h++, i++) {
7264 if (simplified) {
@@ -134,16 +126,8 @@ void ComputeJob(
134126 mean_square = sqrt (mean_square / norm_size - mean * mean + epsilon);
135127 }
136128
137- // When X shape is (B, S, ...), and task_idx is in the range of [0, B * S).
138- // We support scale and bias shape like below:
139- // When scale and bias shape is (1, 1, ...) or (...), value of broadcast_param is 0.
140- // When scale and bias shape is (B, 1, ...), value of broadcast_param is S.
141- // When scale and bias shape is (B, S, ...), value of broadcast_param is 1.
142- // When scale and bias shape is (1, S, ...), value of broadcast_param is -S.
143- // Here we compute the initial index for scale and bias data.
144- int64_t i = (broadcast_param == 0 )
145- ? 0
146- : norm_size * (broadcast_param > 0 ? (task_idx / broadcast_param) : (task_idx % (-broadcast_param)));
129+ // Compute the offset of gamma and beta to support broadcasting.
130+ int64_t i = LAYER_NORM_SCALE_BIAS_OFFSET (broadcast_param, task_idx, norm_size);
147131
148132 for (size_t h = 0 ; h < num_elems; h++, i++) {
149133 if (simplified) {
@@ -283,38 +267,28 @@ Status LayerNormImpl::ComputeWithoutContext(
283267 float epsilon,
284268 bool simplified,
285269 AllocatorPtr alloc) const {
286- int64_t norm_count = x_shape.SizeToDimension (onnxruntime::narrow<size_t >(axis));
287- int64_t norm_size = x_shape.SizeFromDimension (onnxruntime::narrow<size_t >(axis));
288-
289- int64_t scale_size = scale_shape.Size ();
290- int64_t bias_size = bias_shape.Size ();
291- int64_t broadcast_param = 0 ;
292-
293- if (norm_size <= 1 ) {
294- return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT, kLayerNormInvalidSize , norm_size);
295- } else if (static_cast <int64_t >(scale_size) != norm_size || (bias_data && static_cast <int64_t >(bias_size) != norm_size)) {
296- ORT_RETURN_IF_ERROR (LayerNormHelper::CheckBroadcast (x_shape, scale_shape, bias_shape, bias_data != nullptr , axis, broadcast_param));
297- }
270+ LayerNormParams params;
271+ ORT_RETURN_IF_ERROR (LayerNormHelper::CheckInputs (x_shape, scale_shape, bias_shape, bias_data != nullptr , axis, params));
298272
299273 IAllocatorUniquePtr<float > scale_fp32;
300274 IAllocatorUniquePtr<float > bias_fp32;
301275 if constexpr (std::is_same_v<T, MLFloat16>) {
302276 if (prepacked_scale_fp32_data_ == nullptr ) {
303- const size_t num_elems = static_cast <size_t >(scale_size);
277+ const size_t num_elems = static_cast <size_t >(params. scale_size );
304278 scale_fp32 = IAllocator::MakeUniquePtr<float >(alloc, num_elems);
305279 MlasConvertHalfToFloatBuffer (scale_data, scale_fp32.get (), num_elems);
306280 }
307281 if (prepacked_bias_fp32_data_ == nullptr && bias_data) {
308- const size_t num_elems = static_cast <size_t >(bias_size);
282+ const size_t num_elems = static_cast <size_t >(params. bias_size );
309283 bias_fp32 = IAllocator::MakeUniquePtr<float >(alloc, num_elems);
310284 MlasConvertHalfToFloatBuffer (bias_data, bias_fp32.get (), num_elems);
311285 }
312286 }
313287
314288 concurrency::ThreadPool::TryBatchParallelFor (
315- thread_pool, static_cast <int32_t >(norm_count ),
289+ thread_pool, static_cast <int32_t >(params. num_rows ),
316290 [&](ptrdiff_t task_idx) {
317- ComputeJob (X_data, scale_data, bias_data, task_idx, norm_size, broadcast_param,
291+ ComputeJob (X_data, scale_data, bias_data, task_idx, params. norm_size , params. broadcast_param ,
318292 prepacked_scale_fp32_data_ ? prepacked_scale_fp32_data_.get () : scale_fp32.get (),
319293 prepacked_bias_fp32_data_ ? prepacked_bias_fp32_data_.get () : bias_fp32.get (),
320294 epsilon, simplified, Y_data, mean_data, inv_std_dev_data, alloc);
0 commit comments