Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions paddle/phi/kernels/gpu/layer_norm_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ static inline LayerNormGadKernelVariant LayerNormGradKernelDispatch(
const DenseTensor* scale,
const DenseTensor* bias) {
#if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) && !defined(_WIN32)
if (FLAGS_use_accuracy_compatible_kernel) {
return LayerNormGadKernelVariant::GENERIC;
}
if (scale != nullptr && bias != nullptr &&
input_type != paddle::DataType::FLOAT32 && hidden_size != 4096 &&
hidden_size > 1024 && hidden_size <= 10240 &&
Expand Down
40 changes: 3 additions & 37 deletions paddle/phi/kernels/gpu/rms_norm_cuda_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -731,37 +731,6 @@ void LayerNormFwdCompatKernel(

auto stream = dev_ctx.stream();

// if (!FLAGS_use_accuracy_compatible_kernel && rows <= 1024 &&
// (cols / rows >= 32)) {
// constexpr int num_vec_elems2 = 8;
// constexpr int alignment2 = num_vec_elems2 * sizeof(T);
// bool can_vec_X2 = can_vectorize(x_data, alignment2);
// bool can_vec_Y2 = can_vectorize(y_data, alignment2);
// bool can_vec_gamma2 = can_vectorize(gamma_data, alignment2);
// bool can_vec_beta2 = can_vectorize(beta_data, alignment2);
// bool is_supported_type2 = (std::is_same<T, phi::dtype::float16>::value ||
// std::is_same<T, phi::dtype::bfloat16>::value);
// if (is_supported_type2 &&
// cols <=
// static_cast<int64_t>(1ULL << std::numeric_limits<float>::digits)
// &&
// cols % num_vec_elems2 == 0 && can_vec_X2 && can_vec_Y2 &&
// can_vec_gamma2 && can_vec_beta2) {
// launch_vectorized_layer_norm_kernel_driver<T, T_ACC, 8>(
// cols,
// rows,
// static_cast<T_ACC>(epsilon),
// x_data,
// gamma_data,
// beta_data,
// y_data,
// mean_data,
// var_data,
// stream);
// return;
// }
// }

// Check vectorization conditions for vec_size=4
constexpr int num_vec_elems = 4;
constexpr int alignment = num_vec_elems * sizeof(T);
Expand Down Expand Up @@ -1555,6 +1524,7 @@ __device__ __inline__ void layer_norm_compute_gI(const T* __restrict__ dY,
}

stats_x1 = BlockReduceSum(stats_x1, buf);
__syncthreads();
stats_x2 = BlockReduceSum(stats_x2, buf);
if (threadIdx.x == 0) {
buf[0] = stats_x1;
Expand Down Expand Up @@ -1658,6 +1628,7 @@ __global__ void layer_norm_grad_input_kernel_vectorized(

// Reduction in Shared Memory
stats_x1 = BlockReduceSum(stats_x1, reduce_buf);
__syncthreads();
stats_x2 = BlockReduceSum(stats_x2, reduce_buf);
if (threadIdx.x == 0) {
reduce_buf[0] = stats_x1;
Expand Down Expand Up @@ -2084,12 +2055,7 @@ void LayerNormBwdCompatKernel(
constexpr int num_threads = 128;
constexpr int nshared = (num_threads / kWarpSize) * sizeof(T_ACC);

if (!FLAGS_use_accuracy_compatible_kernel && is_supported_type2 &&
bAlignedBuffers2 && (N % 8 == 0 && M <= 1024 && (N / M >= 32))) {
layer_norm_grad_input_kernel_vectorized<T, T_ACC, 8>
<<<blocks, num_threads, nshared, stream>>>(
dY_data, X_data, mean_data, rstd_data, gamma_data, dX_data, N);
} else if (is_supported_type && bAlignedBuffers && bVectorSizeMultiple) {
if (is_supported_type && bAlignedBuffers && bVectorSizeMultiple) {
layer_norm_grad_input_kernel_vectorized<T, T_ACC, kVecSize>
<<<blocks, num_threads, nshared, stream>>>(
dY_data, X_data, mean_data, rstd_data, gamma_data, dX_data, N);
Expand Down
Loading