
Commit a47b6af

undo layer norm kernel
1 parent 913c6ed commit a47b6af

5 files changed: +13 -48 lines changed


onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc

Lines changed: 0 additions & 1 deletion
@@ -101,7 +101,6 @@ Status SkipLayerNorm<T, Simplified>::ComputeInternal(OpKernelContext* ctx) const
       (double)epsilon_,  // epsilon
       reinterpret_cast<const CudaT*>(gamma->Data<T>()),  // gamma
       (beta != nullptr) ? reinterpret_cast<const CudaT*>(beta->Data<T>()) : nullptr,  // beta
-      0,  // broadcast stride for gamma/beta
       reinterpret_cast<const CudaT*>(skip->Data<T>()),  // skip or residual to add
       (bias != nullptr) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr,  // bias to add
       sum_output != nullptr ? reinterpret_cast<CudaT*>(sum_output->MutableData<T>()) : nullptr);
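The call above drives a fused skip + layer-norm kernel. As a reference for what it computes, here is a minimal CPU sketch assuming the usual SkipLayerNorm semantics of summing input, residual, and optional bias before normalizing; the function and variable names below are illustrative, not ORT code:

#include <cmath>
#include <vector>

// Reference-only sketch: s = x + skip (+ bias), then y = gamma * (s - mean) / sqrt(var + eps) + beta,
// computed independently for one row of length d.
std::vector<float> SkipLayerNormRow(const std::vector<float>& x,
                                    const std::vector<float>& skip,
                                    const std::vector<float>& gamma,
                                    const std::vector<float>& beta,
                                    float eps) {
  const size_t d = x.size();
  std::vector<float> s(d), y(d);
  float mean = 0.0f;
  for (size_t i = 0; i < d; ++i) {
    s[i] = x[i] + skip[i];  // residual add; an optional bias would be added here as well
    mean += s[i];
  }
  mean /= static_cast<float>(d);
  float var = 0.0f;
  for (size_t i = 0; i < d; ++i) {
    var += (s[i] - mean) * (s[i] - mean);
  }
  var /= static_cast<float>(d);
  const float inv_std = 1.0f / std::sqrt(var + eps);
  for (size_t i = 0; i < d; ++i) {
    y[i] = gamma[i] * (s[i] - mean) * inv_std + beta[i];
  }
  return y;
}

Each row of length d is normalized independently, and gamma/beta here are plain 1-D vectors of length d, which is exactly the case the kernel returns to once the broadcast stride argument is removed.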

onnxruntime/core/providers/cuda/nn/layer_norm.cc

Lines changed: 8 additions & 25 deletions
@@ -44,36 +44,19 @@ Status LayerNorm<T, U, V, simplified>::ComputeInternal(OpKernelContext* ctx) con
   auto bias_data = (simplified || (nullptr == bias)) ? nullptr : reinterpret_cast<const CudaV*>(bias->Data<V>());
 
   const TensorShape& x_shape = X->Shape();
-  auto x_num_dims = x_shape.NumDimensions();
-  const int64_t axis = HandleNegativeAxis(axis_, x_num_dims);
+  const int64_t axis = HandleNegativeAxis(axis_, x_shape.NumDimensions());
 
   int n1 = gsl::narrow<int>(x_shape.SizeToDimension(axis));
   int n2 = gsl::narrow<int>(x_shape.SizeFromDimension(axis));
 
   const auto scale_size = scale->Shape().Size();
   const auto bias_size = (bias_data) ? bias->Shape().Size() : 0;
-
-  int broadcast = 0;
   if (n2 == 1 || scale_size != n2 || (bias_data && bias_size != n2)) {
-    // Handle a special case for MMDit where scale and bias need broadcast.
-    // X shape is (B, S, D), scale and bias shape is (B, 1, D), and we store S as broadcast stride.
-    if (x_num_dims == 3 && axis == 2 && n2 > 1 &&
-        scale->Shape().NumDimensions() == x_num_dims &&
-        scale->Shape().GetDims()[0] == x_shape.GetDims()[0] &&
-        scale->Shape().GetDims()[1] == 1 &&
-        scale->Shape().GetDims()[2] == x_shape.GetDims()[2] &&
-        bias->Shape().NumDimensions() == x_num_dims &&
-        bias->Shape().GetDims()[0] == x_shape.GetDims()[0] &&
-        bias->Shape().GetDims()[1] == 1 &&
-        bias->Shape().GetDims()[2] == x_shape.GetDims()[2]) {
-      broadcast = static_cast<int>(x_shape.GetDims()[1]);
-    } else {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                             "Size of X.shape()[axis:] == ", n2,
-                             ". Size of scale and bias (if provided) must match this "
-                             "and the size must not be 1. Got scale size of ",
-                             scale_size, " and bias size of ", bias_size);
-    }
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "Size of X.shape()[axis:] == ", n2,
+                           ". Size of scale and bias (if provided) must match this "
+                           "and the size must not be 1. Got scale size of ",
+                           scale_size, " and bias size of ", bias_size);
   }
 
   // Outputs
@@ -82,7 +65,7 @@ Status LayerNorm<T, U, V, simplified>::ComputeInternal(OpKernelContext* ctx) con
 
   // Mean and variance
   std::vector<int64_t> mean_inv_std_var_dim;
-  for (int i = 0; i < static_cast<int>(x_num_dims); ++i) {
+  for (int i = 0; i < static_cast<int>(x_shape.NumDimensions()); ++i) {
     if (i < axis) {
       mean_inv_std_var_dim.emplace_back(x_shape.GetDims()[i]);
     } else {
@@ -111,7 +94,7 @@ Status LayerNorm<T, U, V, simplified>::ComputeInternal(OpKernelContext* ctx) con
   }
 
   HostApplyLayerNorm<CudaT, CudaU, CudaV, simplified>(GetDeviceProp(), Stream(ctx), Y_data, mean_data, inv_var_data,
-                                                      X_data, n1, n2, epsilon_, scale_data, bias_data, broadcast);
+                                                      X_data, n1, n2, epsilon_, scale_data, bias_data);
   CUDA_RETURN_IF_ERROR(cudaGetLastError());
   return Status::OK();
 }
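The block deleted above detected the MMDiT-style case where X has shape (B, S, D) while scale and bias have shape (B, 1, D), and recorded S as the broadcast stride handed down to the kernel. A standalone sketch of that detection, under the shape assumption stated in the deleted comment (the helper below is illustrative, not ORT code):

#include <cstdint>
#include <vector>

// Returns the broadcast stride S when X is (B, S, D) and scale/bias are (B, 1, D),
// i.e. how many consecutive rows of the flattened (B*S, D) input share one row of
// gamma/beta; returns 0 when the shapes do not match this special case.
int64_t DetectGammaBetaBroadcast(const std::vector<int64_t>& x_dims,
                                 const std::vector<int64_t>& scale_dims,
                                 const std::vector<int64_t>& bias_dims) {
  if (x_dims.size() != 3 || scale_dims.size() != 3 || bias_dims.size() != 3) return 0;
  auto matches = [&](const std::vector<int64_t>& p) {
    return p[0] == x_dims[0] && p[1] == 1 && p[2] == x_dims[2];
  };
  return (matches(scale_dims) && matches(bias_dims)) ? x_dims[1] : 0;
}

With this commit the provider no longer computes such a stride, so a (B, 1, D) scale/bias once again falls into the INVALID_ARGUMENT branch.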

onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu

Lines changed: 5 additions & 12 deletions
@@ -334,7 +334,6 @@ __global__ void cuApplyLayerNorm(
     const U epsilon,
     const V* __restrict__ gamma,
     const V* __restrict__ beta,
-    int broadcast,
     const T* __restrict__ skip,
     const T* __restrict__ bias,
     T* __restrict__ skip_input_bias_add_output) {
@@ -367,13 +366,8 @@ __global__ void cuApplyLayerNorm(
         curr += static_cast<U>(skip_vals[i]);
       }
 
-      // onnx operator LayerNormalization support broadcast.
-      // gamma and beta should be unidirectional broadcastable to tensor x.
-      // Here we support a special case for transformer models that x is (B, S, D) and gamma/beta is (B, 1, D)
-      int index = (broadcast > 0) ? ((i1 / broadcast) * n2 + i) : i;
-      U gamma_i = (gamma != nullptr) ? (U)gamma[index] : (U)1;
-      U beta_i = (beta != nullptr) ? (U)beta[index] : (U)0;
-
+      U gamma_i = (gamma != nullptr) ? (U)gamma[i] : (U)1;
+      U beta_i = (beta != nullptr) ? (U)beta[i] : (U)0;
       if (simplified) {
         ovals[i] = static_cast<V>(gamma_i * c_inv_std_dev * curr);
       } else {
@@ -415,7 +409,6 @@ void HostApplyLayerNorm(
     double epsilon,
     const V* gamma,
     const V* beta,
-    int broadcast,
     const T* skip,
     const T* bias,
     T* skip_input_bias_add_output) {
@@ -449,15 +442,15 @@ void HostApplyLayerNorm(
       input,
       n1, n2,
       U(epsilon),
-      gamma, beta, broadcast,
+      gamma, beta,
       skip, bias, skip_input_bias_add_output);
 }
 
 #define LAYERNORM_LINEAR_IMPL(T, U, V, simplified)                                                                  \
   template void HostApplyLayerNorm<T, U, V, simplified>(const cudaDeviceProp& prop, cudaStream_t stream, V* output, \
                                                          U* mean, U* inv_std_dev, const T* input, int n1, int n2,   \
-                                                         double epsilon, const V* gamma, const V* beta, int broadcast, \
-                                                         const T* skip, const T* bias, T* skip_input_bias_add_output);
+                                                         double epsilon, const V* gamma, const V* beta, const T* skip, \
+                                                         const T* bias, T* skip_input_bias_add_output);
 
 LAYERNORM_LINEAR_IMPL(float, float, float, true)
 LAYERNORM_LINEAR_IMPL(half, float, half, true)
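The lines deleted from cuApplyLayerNorm read gamma/beta at (i1 / broadcast) * n2 + i instead of i. A host-side sketch of that index arithmetic, assuming the input is flattened to (n1, n2) with n1 = B * S and broadcast = S (illustrative helper, not ORT code):

// For row i1 of the flattened (n1, n2) input, the batch index is b = i1 / broadcast,
// and per-batch gamma/beta of shape (B, 1, D) is read at b * n2 + i.
// broadcast == 0 meant ordinary 1-D gamma/beta indexed by the column i alone.
inline int GammaBetaIndex(int i1, int i, int n2, int broadcast) {
  return (broadcast > 0) ? (i1 / broadcast) * n2 + i : i;
}

After this commit the kernel always uses the plain index i, matching 1-D gamma/beta of length n2.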

onnxruntime/core/providers/cuda/nn/layer_norm_impl.h

Lines changed: 0 additions & 1 deletion
@@ -41,7 +41,6 @@ void HostApplyLayerNorm(
     double epsilon,
     const V* gamma,
     const V* beta,
-    int broadcast = 0,  // broadcast stride for gamma/beta
     const T* skip = nullptr,
     const T* bias = nullptr,
     T* skip_input_bias_add_output = nullptr);

onnxruntime/python/tools/transformers/onnx_model_mmdit.py

Lines changed: 0 additions & 9 deletions
@@ -87,14 +87,6 @@ def _optimize(self, options: Optional[FusionOptions] = None, progress_bar=None):
         if progress_bar:
             progress_bar.update(1)
 
-        # TODO: SkipLayerNormalization does not support broadcast yet.
-        # if (options is None) or options.enable_skip_layer_norm:
-        #     self.fuse_skip_simplified_layer_norm()
-        #     self.fuse_skip_layer_norm()
-        # if (options is None) or options.enable_bias_skip_layer_norm:
-        #     # Fuse SkipLayerNormalization and Add Bias before it.
-        #     self.fuse_add_bias_skip_layer_norm()
-
         self.postprocess()
         if progress_bar:
             progress_bar.update(1)
@@ -110,7 +102,6 @@ def get_fused_operator_statistics(self):
             "FastGelu",
             "MultiHeadAttention",
             "LayerNormalization",
-            # "SkipLayerNormalization",
             "SimplifiedLayerNormalization",
         ]
 
