Commit 40be182

[XPU] improve performance of layernorm (PaddlePaddle#72478)

1 parent 7fc1a28 · commit 40be182

2 files changed: +132 -208 lines
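Both files apply the same refactor: the duplicated fp32/fp16 code paths and their explicit cast round-trips are collapsed into one implementation templated on the scale/bias element type TW, selected by a single runtime dtype check. The following standalone sketch shows that dispatch shape; it is not Paddle code, and float16, DType, and the *Sketch functions are illustrative stand-ins.

#include <cstdio>

// Stand-in for phi::dtype::float16; only its size matters here.
struct float16 { unsigned short bits; };

enum class DType { kFloat32, kFloat16 };

// One implementation templated on the weight (scale/bias) element type TW,
// mirroring the shape of LayerNormKernelImpl / LayerNormGradImpl below.
template <typename T, typename TW>
void LayerNormImplSketch() {
  // The real kernels reinterpret_cast to XPUType / XPUTypeTW and call
  // xpu::layer_norm or xpu::layer_norm_grad at this point.
  std::printf("x dtype: %zu bytes, scale/bias dtype: %zu bytes\n",
              sizeof(T), sizeof(TW));
}

// A runtime dtype check picks the instantiation, replacing the old
// cast-to-fp32-and-back branches.
template <typename T>
void LayerNormKernelSketch(DType x_dtype, DType scale_bias_dtype) {
  if (x_dtype == scale_bias_dtype) {
    LayerNormImplSketch<T, T>();      // scale/bias share x's element type
  } else {
    LayerNormImplSketch<T, float>();  // scale/bias stay in fp32
  }
}

int main() {
  LayerNormKernelSketch<float16>(DType::kFloat16, DType::kFloat32);
  LayerNormKernelSketch<float>(DType::kFloat32, DType::kFloat32);
  return 0;
}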

paddle/phi/kernels/xpu/layer_norm_grad_kernel.cc (+87 -123)
@@ -19,6 +19,67 @@
 
 namespace phi {
 
+template <typename T, typename TW, typename Context>  // TW for scale and bias
+void LayerNormGradImpl(const Context& ctx,
+                       const DenseTensor& x,
+                       const paddle::optional<DenseTensor>& scale,
+                       const paddle::optional<DenseTensor>& bias,
+                       const DenseTensor& mean,
+                       const DenseTensor& variance,
+                       const DenseTensor& out_grad,
+                       float epsilon,
+                       int begin_norm_axis,
+                       DenseTensor* x_grad,
+                       DenseTensor* scale_grad,
+                       DenseTensor* bias_grad) {
+  const auto* scale_ptr = scale.get_ptr();
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  using XPUTypeTW = typename XPUTypeTrait<TW>::Type;
+  const auto& x_dims = x.dims();
+  auto matrix_dim = common::flatten_to_2d(x_dims, begin_norm_axis);
+  int64_t left = matrix_dim[0];
+  int64_t right = matrix_dim[1];
+  const auto* x_data = x.data<T>();
+  const auto* out_grad_data = out_grad.data<T>();
+  const auto* mean_data = mean.data<float>();
+  const auto* variance_data = variance.data<float>();
+
+  xpu::ctx_guard RAII_GUARD(ctx.x_context());
+
+  T* x_grad_data = nullptr;
+  const TW* scale_data = nullptr;
+  TW* scale_grad_data = nullptr;
+  TW* bias_grad_data = nullptr;
+  if (x_grad != nullptr) {
+    ctx.template Alloc<T>(x_grad);
+    x_grad_data = x_grad->data<T>();
+  }
+  if (scale_ptr != nullptr) {
+    scale_data = scale_ptr->data<TW>();
+    if (scale_grad != nullptr) {
+      ctx.template Alloc<TW>(scale_grad);
+      scale_grad_data = scale_grad->data<TW>();
+    }
+  }
+  if (bias_grad != nullptr) {
+    ctx.template Alloc<TW>(bias_grad);
+    bias_grad_data = bias_grad->data<TW>();
+  }
+  int r = xpu::layer_norm_grad(ctx.x_context(),
+                               reinterpret_cast<const XPUType*>(x_data),
+                               reinterpret_cast<const XPUType*>(out_grad_data),
+                               reinterpret_cast<XPUType*>(x_grad_data),
+                               left,
+                               right,
+                               epsilon,
+                               reinterpret_cast<const XPUTypeTW*>(scale_data),
+                               mean_data,
+                               variance_data,
+                               reinterpret_cast<XPUTypeTW*>(scale_grad_data),
+                               reinterpret_cast<XPUTypeTW*>(bias_grad_data));
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm_grad");
+}
+
 template <typename T, typename Context>
 void LayerNormGradKernel(const Context& ctx,
                          const DenseTensor& x,
@@ -46,137 +107,40 @@ void LayerNormGradKernel(const Context& ctx,
     }
   }
 
-  bool is_scale_bias_same_dtype_with_x = x_dtype == scale_bias_dtype;
+  bool is_scale_bias_same_dtype_with_x = (x_dtype == scale_bias_dtype);
   if (!is_scale_bias_same_dtype_with_x) {
     PADDLE_ENFORCE_EQ(scale_bias_dtype,
                       phi::CppTypeToDataType<float>::Type(),
                       common::errors::InvalidArgument(
                           "Unsupported data type of Scale and Bias"));
   }
-  using XPUType = typename XPUTypeTrait<T>::Type;
-  const auto& x_dims = x.dims();
-  auto matrix_dim = common::flatten_to_2d(x_dims, begin_norm_axis);
-  int left = static_cast<int>(matrix_dim[0]);
-  int right = static_cast<int>(matrix_dim[1]);
-  const auto* x_data = x.data<T>();
-  const auto* out_grad_data = out_grad.data<T>();
-  const auto* mean_data = mean.data<float>();
-  const auto* variance_data = variance.data<float>();
-
-  xpu::ctx_guard RAII_GUARD(ctx.x_context());
-
-  // scale
-  const float* scale_data_fp32 = nullptr;
-  float* scale_grad_data_fp32 = nullptr;
-  const T* scale_data_T = nullptr;
-  T* scale_grad_data_T = nullptr;
-  bool need_cast_scale = false;
-  if (scale_ptr == nullptr) {
-    // no scale, do nothing
-  } else if (scale_ptr->dtype() ==
-             phi::CppTypeToDataType<phi::dtype::float16>::Type()) {
-    float* scale_data_temp =
-        RAII_GUARD.alloc_l3_or_gm<float>(scale_ptr->numel());
-    int r = xpu::cast<XPUType, float>(
-        ctx.x_context(),
-        reinterpret_cast<const XPUType*>(scale_ptr->data<T>()),
-        scale_data_temp,
-        scale_ptr->numel());
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
-    scale_data_fp32 = scale_data_temp;
-    need_cast_scale = true;
-    scale_grad_data_fp32 =
-        scale_grad == nullptr
-            ? nullptr
-            : RAII_GUARD.alloc_l3_or_gm<float>(scale_ptr->numel());
-  } else {
-    // no need to cast
-    if (is_scale_bias_same_dtype_with_x) {
-      scale_data_T = scale_ptr->data<T>();
-      scale_grad_data_T =
-          scale_grad == nullptr ? nullptr : ctx.template Alloc<T>(scale_grad);
-    } else {
-      scale_data_fp32 = scale_ptr->data<float>();
-      scale_grad_data_fp32 = scale_grad == nullptr
-                                 ? nullptr
-                                 : ctx.template Alloc<float>(scale_grad);
-    }
-  }
 
-  // bias
-  float* bias_grad_data_fp32 = nullptr;
-  T* bias_grad_data_T = nullptr;
-  bool need_cast_bias = false;
-  if (bias_ptr == nullptr) {
-    // no bias, do nothing
-  } else if (bias_ptr->dtype() ==
-             phi::CppTypeToDataType<phi::dtype::float16>::Type()) {
-    need_cast_bias = true;
-    bias_grad_data_fp32 =
-        bias_grad == nullptr
-            ? nullptr
-            : RAII_GUARD.alloc_l3_or_gm<float>(bias_ptr->numel());
+  if (is_scale_bias_same_dtype_with_x) {
+    LayerNormGradImpl<T, T, Context>(ctx,
+                                     x,
+                                     scale,
+                                     bias,
+                                     mean,
+                                     variance,
+                                     out_grad,
+                                     epsilon,
+                                     begin_norm_axis,
+                                     x_grad,
+                                     scale_grad,
+                                     bias_grad);
   } else {
-    // no need to cast
-    if (is_scale_bias_same_dtype_with_x) {
-      bias_grad_data_T =
-          bias_grad == nullptr ? nullptr : ctx.template Alloc<T>(bias_grad);
-    } else {
-      bias_grad_data_fp32 =
-          bias_grad == nullptr ? nullptr : ctx.template Alloc<float>(bias_grad);
-    }
-  }
-
-  auto* x_grad_data =
-      (x_grad == nullptr ? nullptr : ctx.template Alloc<T>(x_grad));
-
-  if (!is_scale_bias_same_dtype_with_x) {
-    int r =
-        xpu::layer_norm_grad(ctx.x_context(),
-                             reinterpret_cast<const XPUType*>(x_data),
-                             reinterpret_cast<const XPUType*>(out_grad_data),
-                             reinterpret_cast<XPUType*>(x_grad_data),
-                             left,
-                             right,
-                             epsilon,
-                             scale_data_fp32,
-                             mean_data,
-                             variance_data,
-                             scale_grad_data_fp32,
-                             bias_grad_data_fp32);
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm_grad");
-  } else {
-    int r =
-        xpu::layer_norm_grad(ctx.x_context(),
-                             reinterpret_cast<const XPUType*>(x_data),
-                             reinterpret_cast<const XPUType*>(out_grad_data),
-                             reinterpret_cast<XPUType*>(x_grad_data),
-                             left,
-                             right,
-                             epsilon,
-                             reinterpret_cast<const XPUType*>(scale_data_T),
-                             mean_data,
-                             variance_data,
-                             reinterpret_cast<XPUType*>(scale_grad_data_T),
-                             reinterpret_cast<XPUType*>(bias_grad_data_T));
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm_grad");
-  }
-
-  if (need_cast_scale) {
-    int r = xpu::cast<float, XPUType>(
-        ctx.x_context(),
-        scale_grad_data_fp32,
-        reinterpret_cast<XPUType*>(ctx.template Alloc<T>(scale_grad)),
-        scale.get_ptr()->numel());
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
-  }
-  if (need_cast_bias) {
-    int r = xpu::cast<float, XPUType>(
-        ctx.x_context(),
-        bias_grad_data_fp32,
-        reinterpret_cast<XPUType*>(ctx.template Alloc<T>(bias_grad)),
-        bias.get_ptr()->numel());
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
+    LayerNormGradImpl<T, float, Context>(ctx,
+                                         x,
+                                         scale,
+                                         bias,
+                                         mean,
+                                         variance,
+                                         out_grad,
+                                         epsilon,
+                                         begin_norm_axis,
+                                         x_grad,
+                                         scale_grad,
+                                         bias_grad);
   }
 }
 } // namespace phi
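Relative to the old grad path, LayerNormGradImpl drops up to three extra kernel launches in the fp16 scale/bias case (casting scale to fp32 on the way in, and casting the scale and bias gradients back on the way out) along with their temporary L3/GM buffers, passing TW-typed pointers straight to xpu::layer_norm_grad. This relies on the underlying XDNN primitive accepting a weight type distinct from the input type, which is what the XPUTypeTW casts assume. The new code also preserves the old contract that any of the three gradient outputs may be absent. Below is a minimal sketch of that allocate-only-if-requested pattern; Tensor, alloc, and the stub function are hypothetical stand-ins, not the phi API.

#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for phi::DenseTensor plus ctx.Alloc.
struct Tensor {
  std::vector<float> buf;
  float* alloc(std::size_t n) { buf.assign(n, 0.f); return buf.data(); }
};

// Mimics the XDNN convention the kernel relies on: a null output
// pointer means "skip computing this gradient".
void layer_norm_grad_stub(float* dx, float* dscale, float* dbias) {
  std::printf("dx=%s dscale=%s dbias=%s\n", dx ? "yes" : "no",
              dscale ? "yes" : "no", dbias ? "yes" : "no");
}

// Allocate each output only if the caller asked for it, as
// LayerNormGradImpl does with x_grad / scale_grad / bias_grad.
void run(Tensor* x_grad, Tensor* scale_grad, Tensor* bias_grad,
         std::size_t n) {
  float* dx = x_grad ? x_grad->alloc(n) : nullptr;
  float* dscale = scale_grad ? scale_grad->alloc(n) : nullptr;
  float* dbias = bias_grad ? bias_grad->alloc(n) : nullptr;
  layer_norm_grad_stub(dx, dscale, dbias);
}

int main() {
  Tensor xg, sg;
  run(&xg, &sg, /*bias_grad=*/nullptr, 8);  // no bias gradient requested
  return 0;
}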

paddle/phi/kernels/xpu/layer_norm_kernel.cc (+45 -85)
@@ -19,6 +19,44 @@
 
 namespace phi {
 
+template <typename T, typename TW, typename Context>
+void LayerNormKernelImpl(const Context& ctx,
+                         const DenseTensor& x,
+                         const paddle::optional<DenseTensor>& scale,
+                         const paddle::optional<DenseTensor>& bias,
+                         float epsilon,
+                         int begin_norm_axis,
+                         DenseTensor* out,
+                         DenseTensor* mean,
+                         DenseTensor* variance) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  using XPUTypeTW = typename XPUTypeTrait<TW>::Type;
+  const auto& x_dims = x.dims();
+  auto matrix_dim = common::flatten_to_2d(x_dims, begin_norm_axis);
+  int64_t left = matrix_dim[0];
+  int64_t right = matrix_dim[1];
+
+  const auto* x_data = x.data<T>();
+  const auto* scale_data = scale.get_ptr() ? scale->data<TW>() : nullptr;
+  const auto* bias_data = bias.get_ptr() ? bias->data<TW>() : nullptr;
+  xpu::ctx_guard RAII_GUARD(ctx.x_context());
+  auto* out_data = ctx.template Alloc<T>(out);
+  auto* mean_data = ctx.template Alloc<float>(mean);
+  auto* variance_data = ctx.template Alloc<float>(variance);
+
+  int r = xpu::layer_norm(ctx.x_context(),
+                          reinterpret_cast<const XPUType*>(x_data),
+                          reinterpret_cast<XPUType*>(out_data),
+                          left,
+                          right,
+                          epsilon,
+                          reinterpret_cast<const XPUTypeTW*>(scale_data),
+                          reinterpret_cast<const XPUTypeTW*>(bias_data),
+                          mean_data,
+                          variance_data);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm");
+}
+
 template <typename T, typename Context>
 void LayerNormKernel(const Context& ctx,
                      const DenseTensor& x,
@@ -31,8 +69,6 @@ void LayerNormKernel(const Context& ctx,
                      DenseTensor* variance) {
   bool valid_scale = (scale.get_ptr() != nullptr);
   bool valid_bias = (bias.get_ptr() != nullptr);
-  auto* void_scale_data = valid_scale ? scale->data() : nullptr;
-  auto* void_bias_data = valid_bias ? bias->data() : nullptr;
 
   auto x_dtype = x.dtype();
   phi::DataType scale_bias_dtype;
@@ -49,96 +85,20 @@ void LayerNormKernel(const Context& ctx,
     scale_bias_dtype = valid_bias ? bias->dtype() : x_dtype;
   }
 
-  bool is_scale_bias_same_dtype_with_x = x_dtype == scale_bias_dtype;
+  bool is_scale_bias_same_dtype_with_x = (x_dtype == scale_bias_dtype);
   if (!is_scale_bias_same_dtype_with_x) {
     PADDLE_ENFORCE_EQ(scale_bias_dtype,
-                      phi::CppTypeToDataType<float>::Type(),
+                      phi::DataType::FLOAT32,
                       common::errors::InvalidArgument(
                           "Unsupported data type of Scale and Bias"));
   }
 
-  using XPUType = typename XPUTypeTrait<T>::Type;
-  const auto& x_dims = x.dims();
-  auto matrix_dim = common::flatten_to_2d(x_dims, begin_norm_axis);
-  int left = static_cast<int>(matrix_dim[0]);
-  int right = static_cast<int>(matrix_dim[1]);
-  const auto* x_data = x.data<T>();
-
-  xpu::ctx_guard RAII_GUARD(ctx.x_context());
-
-  // scale
-  const float* scale_data_fp32 = nullptr;
-  const auto* scale_ptr = scale.get_ptr();
-  if (scale_ptr == nullptr) {
-    // no scale, do nothing
-  } else if (scale_ptr->dtype() ==
-             phi::CppTypeToDataType<phi::dtype::float16>::Type()) {
-    float* scale_data_temp =
-        RAII_GUARD.alloc_l3_or_gm<float>(scale_ptr->numel());
-    int r = xpu::cast<XPUType, float>(
-        ctx.x_context(),
-        reinterpret_cast<const XPUType*>(scale_ptr->data<T>()),
-        scale_data_temp,
-        scale_ptr->numel());
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
-    scale_data_fp32 = scale_data_temp;
-  } else {
-    // no need to cast
-    if (!is_scale_bias_same_dtype_with_x) {
-      scale_data_fp32 = scale_ptr->data<float>();
-    }
-  }
-
-  // bias
-  const float* bias_data_fp32 = nullptr;
-  const auto* bias_ptr = bias.get_ptr();
-  if (bias_ptr == nullptr) {
-    // no bias, do nothing
-  } else if (bias_ptr->dtype() ==
-             phi::CppTypeToDataType<phi::dtype::float16>::Type()) {
-    float* bias_data_temp = RAII_GUARD.alloc_l3_or_gm<float>(bias_ptr->numel());
-    int r = xpu::cast<XPUType, float>(
-        ctx.x_context(),
-        reinterpret_cast<const XPUType*>(bias_ptr->data<T>()),
-        bias_data_temp,
-        bias_ptr->numel());
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
-    bias_data_fp32 = bias_data_temp;
-  } else {
-    // no need to cast
-    if (!is_scale_bias_same_dtype_with_x) {
-      bias_data_fp32 = bias_ptr->data<float>();
-    }
-  }
-
-  auto* out_data = ctx.template Alloc<T>(out);
-  auto* mean_data = ctx.template Alloc<float>(mean);
-  auto* variance_data = ctx.template Alloc<float>(variance);
-
-  if (!is_scale_bias_same_dtype_with_x) {
-    int r = xpu::layer_norm(ctx.x_context(),
-                            reinterpret_cast<const XPUType*>(x_data),
-                            reinterpret_cast<XPUType*>(out_data),
-                            left,
-                            right,
-                            epsilon,
-                            scale_data_fp32,
-                            bias_data_fp32,
-                            mean_data,
-                            variance_data);
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm");
+  if (is_scale_bias_same_dtype_with_x) {
+    LayerNormKernelImpl<T, T, Context>(
+        ctx, x, scale, bias, epsilon, begin_norm_axis, out, mean, variance);
   } else {
-    int r = xpu::layer_norm(ctx.x_context(),
-                            reinterpret_cast<const XPUType*>(x_data),
-                            reinterpret_cast<XPUType*>(out_data),
-                            left,
-                            right,
-                            epsilon,
-                            reinterpret_cast<const XPUType*>(void_scale_data),
-                            reinterpret_cast<const XPUType*>(void_bias_data),
-                            mean_data,
-                            variance_data),
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm");
+    LayerNormKernelImpl<T, float, Context>(
+        ctx, x, scale, bias, epsilon, begin_norm_axis, out, mean, variance);
  }
 }
 } // namespace phi
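One quiet fix rides along with the refactor: both new Impl functions keep left and right as int64_t, where the old code truncated them through static_cast<int>, so very large flattened shapes no longer risk overflow. Both kernels derive that left-by-right view via common::flatten_to_2d(x_dims, begin_norm_axis): dims before the axis collapse into the row count, the rest into the normalized width. The sketch below illustrates the observed semantics; it is not the phi implementation.

#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

// Collapse dims[0..begin_norm_axis) into `left` (rows normalized
// independently) and dims[begin_norm_axis..) into `right` (the
// normalized width), as both XPU layer_norm kernels consume them.
std::pair<int64_t, int64_t> flatten_to_2d_sketch(
    const std::vector<int64_t>& dims, int begin_norm_axis) {
  int64_t left = 1;
  int64_t right = 1;
  for (int i = 0; i < static_cast<int>(dims.size()); ++i) {
    (i < begin_norm_axis ? left : right) *= dims[i];
  }
  return {left, right};
}

int main() {
  // e.g. x with shape [batch=4, seq=16, hidden=256] and begin_norm_axis=2:
  auto lr = flatten_to_2d_sketch({4, 16, 256}, 2);
  // -> 64 rows, each normalized over 256 elements
  std::printf("left=%lld right=%lld\n",
              static_cast<long long>(lr.first),
              static_cast<long long>(lr.second));
  return 0;
}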
