@@ -44,36 +44,19 @@ Status LayerNorm<T, U, V, simplified>::ComputeInternal(OpKernelContext* ctx) con
4444 auto bias_data = (simplified || (nullptr == bias)) ? nullptr : reinterpret_cast <const CudaV*>(bias->Data <V>());
4545
4646 const TensorShape& x_shape = X->Shape ();
47- auto x_num_dims = x_shape.NumDimensions ();
48- const int64_t axis = HandleNegativeAxis (axis_, x_num_dims);
47+ const int64_t axis = HandleNegativeAxis (axis_, x_shape.NumDimensions ());
4948
5049 int n1 = gsl::narrow<int >(x_shape.SizeToDimension (axis));
5150 int n2 = gsl::narrow<int >(x_shape.SizeFromDimension (axis));
5251
5352 const auto scale_size = scale->Shape ().Size ();
5453 const auto bias_size = (bias_data) ? bias->Shape ().Size () : 0 ;
55-
56- int broadcast = 0 ;
5754 if (n2 == 1 || scale_size != n2 || (bias_data && bias_size != n2)) {
58- // Handle a special case for MMDit where scale and bias need broadcast.
59- // X shape is (B, S, D), scale and bias shape is (B, 1, D), and we store S as broadcast stride.
60- if (x_num_dims == 3 && axis == 2 && n2 > 1 &&
61- scale->Shape ().NumDimensions () == x_num_dims &&
62- scale->Shape ().GetDims ()[0 ] == x_shape.GetDims ()[0 ] &&
63- scale->Shape ().GetDims ()[1 ] == 1 &&
64- scale->Shape ().GetDims ()[2 ] == x_shape.GetDims ()[2 ] &&
65- bias->Shape ().NumDimensions () == x_num_dims &&
66- bias->Shape ().GetDims ()[0 ] == x_shape.GetDims ()[0 ] &&
67- bias->Shape ().GetDims ()[1 ] == 1 &&
68- bias->Shape ().GetDims ()[2 ] == x_shape.GetDims ()[2 ]) {
69- broadcast = static_cast <int >(x_shape.GetDims ()[1 ]);
70- } else {
71- return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT,
72- " Size of X.shape()[axis:] == " , n2,
73- " . Size of scale and bias (if provided) must match this "
74- " and the size must not be 1. Got scale size of " ,
75- scale_size, " and bias size of " , bias_size);
76- }
55+ return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT,
56+ " Size of X.shape()[axis:] == " , n2,
57+ " . Size of scale and bias (if provided) must match this "
58+ " and the size must not be 1. Got scale size of " ,
59+ scale_size, " and bias size of " , bias_size);
7760 }
7861
7962 // Outputs
@@ -82,7 +65,7 @@ Status LayerNorm<T, U, V, simplified>::ComputeInternal(OpKernelContext* ctx) con
8265
8366 // Mean and variance
8467 std::vector<int64_t > mean_inv_std_var_dim;
85- for (int i = 0 ; i < static_cast <int >(x_num_dims ); ++i) {
68+ for (int i = 0 ; i < static_cast <int >(x_shape. NumDimensions () ); ++i) {
8669 if (i < axis) {
8770 mean_inv_std_var_dim.emplace_back (x_shape.GetDims ()[i]);
8871 } else {
@@ -111,7 +94,7 @@ Status LayerNorm<T, U, V, simplified>::ComputeInternal(OpKernelContext* ctx) con
11194 }
11295
11396 HostApplyLayerNorm<CudaT, CudaU, CudaV, simplified>(GetDeviceProp (), Stream (ctx), Y_data, mean_data, inv_var_data,
114- X_data, n1, n2, epsilon_, scale_data, bias_data, broadcast );
97+ X_data, n1, n2, epsilon_, scale_data, bias_data);
11598 CUDA_RETURN_IF_ERROR (cudaGetLastError ());
11699 return Status::OK ();
117100}
0 commit comments