 
 namespace phi {
 
+template <typename T, typename TW, typename Context>  // TW for scale and bias
+void LayerNormGradImpl(const Context& ctx,
+                       const DenseTensor& x,
+                       const paddle::optional<DenseTensor>& scale,
+                       const paddle::optional<DenseTensor>& bias,
+                       const DenseTensor& mean,
+                       const DenseTensor& variance,
+                       const DenseTensor& out_grad,
+                       float epsilon,
+                       int begin_norm_axis,
+                       DenseTensor* x_grad,
+                       DenseTensor* scale_grad,
+                       DenseTensor* bias_grad) {
+  const auto* scale_ptr = scale.get_ptr();
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  using XPUTypeTW = typename XPUTypeTrait<TW>::Type;
+  const auto& x_dims = x.dims();
+  auto matrix_dim = common::flatten_to_2d(x_dims, begin_norm_axis);
+  int64_t left = matrix_dim[0];
+  int64_t right = matrix_dim[1];
+  const auto* x_data = x.data<T>();
+  const auto* out_grad_data = out_grad.data<T>();
+  const auto* mean_data = mean.data<float>();
+  const auto* variance_data = variance.data<float>();
+
+  xpu::ctx_guard RAII_GUARD(ctx.x_context());
+
+  T* x_grad_data = nullptr;
+  const TW* scale_data = nullptr;
+  TW* scale_grad_data = nullptr;
+  TW* bias_grad_data = nullptr;
+  if (x_grad != nullptr) {
+    ctx.template Alloc<T>(x_grad);
+    x_grad_data = x_grad->data<T>();
+  }
+  if (scale_ptr != nullptr) {
+    scale_data = scale_ptr->data<TW>();
+    if (scale_grad != nullptr) {
+      ctx.template Alloc<TW>(scale_grad);
+      scale_grad_data = scale_grad->data<TW>();
+    }
+  }
+  if (bias_grad != nullptr) {
+    ctx.template Alloc<TW>(bias_grad);
+    bias_grad_data = bias_grad->data<TW>();
+  }
+  int r = xpu::layer_norm_grad(ctx.x_context(),
+                               reinterpret_cast<const XPUType*>(x_data),
+                               reinterpret_cast<const XPUType*>(out_grad_data),
+                               reinterpret_cast<XPUType*>(x_grad_data),
+                               left,
+                               right,
+                               epsilon,
+                               reinterpret_cast<const XPUTypeTW*>(scale_data),
+                               mean_data,
+                               variance_data,
+                               reinterpret_cast<XPUTypeTW*>(scale_grad_data),
+                               reinterpret_cast<XPUTypeTW*>(bias_grad_data));
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm_grad");
+}
+
 template <typename T, typename Context>
 void LayerNormGradKernel(const Context& ctx,
                          const DenseTensor& x,
@@ -46,137 +107,40 @@ void LayerNormGradKernel(const Context& ctx,
     }
   }
 
-  bool is_scale_bias_same_dtype_with_x = x_dtype == scale_bias_dtype;
+  bool is_scale_bias_same_dtype_with_x = (x_dtype == scale_bias_dtype);
   if (!is_scale_bias_same_dtype_with_x) {
     PADDLE_ENFORCE_EQ(scale_bias_dtype,
                       phi::CppTypeToDataType<float>::Type(),
                       common::errors::InvalidArgument(
                           "Unsupported data type of Scale and Bias"));
   }
-  using XPUType = typename XPUTypeTrait<T>::Type;
-  const auto& x_dims = x.dims();
-  auto matrix_dim = common::flatten_to_2d(x_dims, begin_norm_axis);
-  int left = static_cast<int>(matrix_dim[0]);
-  int right = static_cast<int>(matrix_dim[1]);
-  const auto* x_data = x.data<T>();
-  const auto* out_grad_data = out_grad.data<T>();
-  const auto* mean_data = mean.data<float>();
-  const auto* variance_data = variance.data<float>();
-
-  xpu::ctx_guard RAII_GUARD(ctx.x_context());
-
-  // scale
-  const float* scale_data_fp32 = nullptr;
-  float* scale_grad_data_fp32 = nullptr;
-  const T* scale_data_T = nullptr;
-  T* scale_grad_data_T = nullptr;
-  bool need_cast_scale = false;
-  if (scale_ptr == nullptr) {
-    // no scale, do nothing
-  } else if (scale_ptr->dtype() ==
-             phi::CppTypeToDataType<phi::dtype::float16>::Type()) {
-    float* scale_data_temp =
-        RAII_GUARD.alloc_l3_or_gm<float>(scale_ptr->numel());
-    int r = xpu::cast<XPUType, float>(
-        ctx.x_context(),
-        reinterpret_cast<const XPUType*>(scale_ptr->data<T>()),
-        scale_data_temp,
-        scale_ptr->numel());
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
-    scale_data_fp32 = scale_data_temp;
-    need_cast_scale = true;
-    scale_grad_data_fp32 =
-        scale_grad == nullptr
-            ? nullptr
-            : RAII_GUARD.alloc_l3_or_gm<float>(scale_ptr->numel());
-  } else {
-    // no need to cast
-    if (is_scale_bias_same_dtype_with_x) {
-      scale_data_T = scale_ptr->data<T>();
-      scale_grad_data_T =
-          scale_grad == nullptr ? nullptr : ctx.template Alloc<T>(scale_grad);
-    } else {
-      scale_data_fp32 = scale_ptr->data<float>();
-      scale_grad_data_fp32 = scale_grad == nullptr
-                                 ? nullptr
-                                 : ctx.template Alloc<float>(scale_grad);
-    }
-  }
 
-  // bias
-  float* bias_grad_data_fp32 = nullptr;
-  T* bias_grad_data_T = nullptr;
-  bool need_cast_bias = false;
-  if (bias_ptr == nullptr) {
-    // no bias, do nothing
-  } else if (bias_ptr->dtype() ==
-             phi::CppTypeToDataType<phi::dtype::float16>::Type()) {
-    need_cast_bias = true;
-    bias_grad_data_fp32 =
-        bias_grad == nullptr
-            ? nullptr
-            : RAII_GUARD.alloc_l3_or_gm<float>(bias_ptr->numel());
+  if (is_scale_bias_same_dtype_with_x) {
+    LayerNormGradImpl<T, T, Context>(ctx,
+                                     x,
+                                     scale,
+                                     bias,
+                                     mean,
+                                     variance,
+                                     out_grad,
+                                     epsilon,
+                                     begin_norm_axis,
+                                     x_grad,
+                                     scale_grad,
+                                     bias_grad);
   } else {
-    // no need to cast
-    if (is_scale_bias_same_dtype_with_x) {
-      bias_grad_data_T =
-          bias_grad == nullptr ? nullptr : ctx.template Alloc<T>(bias_grad);
-    } else {
-      bias_grad_data_fp32 =
-          bias_grad == nullptr ? nullptr : ctx.template Alloc<float>(bias_grad);
-    }
-  }
-
-  auto* x_grad_data =
-      (x_grad == nullptr ? nullptr : ctx.template Alloc<T>(x_grad));
-
-  if (!is_scale_bias_same_dtype_with_x) {
-    int r =
-        xpu::layer_norm_grad(ctx.x_context(),
-                             reinterpret_cast<const XPUType*>(x_data),
-                             reinterpret_cast<const XPUType*>(out_grad_data),
-                             reinterpret_cast<XPUType*>(x_grad_data),
-                             left,
-                             right,
-                             epsilon,
-                             scale_data_fp32,
-                             mean_data,
-                             variance_data,
-                             scale_grad_data_fp32,
-                             bias_grad_data_fp32);
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm_grad");
-  } else {
-    int r =
-        xpu::layer_norm_grad(ctx.x_context(),
-                             reinterpret_cast<const XPUType*>(x_data),
-                             reinterpret_cast<const XPUType*>(out_grad_data),
-                             reinterpret_cast<XPUType*>(x_grad_data),
-                             left,
-                             right,
-                             epsilon,
-                             reinterpret_cast<const XPUType*>(scale_data_T),
-                             mean_data,
-                             variance_data,
-                             reinterpret_cast<XPUType*>(scale_grad_data_T),
-                             reinterpret_cast<XPUType*>(bias_grad_data_T));
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm_grad");
-  }
-
-  if (need_cast_scale) {
-    int r = xpu::cast<float, XPUType>(
-        ctx.x_context(),
-        scale_grad_data_fp32,
-        reinterpret_cast<XPUType*>(ctx.template Alloc<T>(scale_grad)),
-        scale.get_ptr()->numel());
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
-  }
-  if (need_cast_bias) {
-    int r = xpu::cast<float, XPUType>(
-        ctx.x_context(),
-        bias_grad_data_fp32,
-        reinterpret_cast<XPUType*>(ctx.template Alloc<T>(bias_grad)),
-        bias.get_ptr()->numel());
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
+    LayerNormGradImpl<T, float, Context>(ctx,
+                                         x,
+                                         scale,
+                                         bias,
+                                         mean,
+                                         variance,
+                                         out_grad,
+                                         epsilon,
+                                         begin_norm_axis,
+                                         x_grad,
+                                         scale_grad,
+                                         bias_grad);
   }
 }
 }  // namespace phi
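
The diff folds the old fp16/fp32 special cases into one templated impl whose weight element type TW is fixed at compile time and chosen once at the call site from a runtime dtype check. For readers unfamiliar with this dispatch style, here is a minimal standalone sketch of the same pattern; every name and type in it is an illustrative placeholder (plain std::vector instead of DenseTensor, no XPU/XDNN calls), not Paddle API.

#include <cstdio>
#include <vector>

enum class DType { kFP32, kFP16 };

// TW is the element type of scale/bias; T is the element type of x.
// Once TW is fixed at compile time, the impl needs no runtime casting.
template <typename T, typename TW>
void LayerNormGradImplSketch(const std::vector<T>& x,
                             const std::vector<TW>& scale) {
  std::printf("x has %zu elems, scale has %zu elems\n", x.size(), scale.size());
}

template <typename T>
void LayerNormGradKernelSketch(const std::vector<T>& x,
                               DType x_dtype,
                               DType scale_bias_dtype) {
  if (scale_bias_dtype == x_dtype) {
    // scale/bias share x's element type: instantiate the impl with TW == T.
    LayerNormGradImplSketch<T, T>(x, std::vector<T>(4));
  } else {
    // otherwise scale/bias stay in fp32: instantiate the impl with TW == float.
    LayerNormGradImplSketch<T, float>(x, std::vector<float>(4));
  }
}

int main() {
  LayerNormGradKernelSketch<float>(std::vector<float>(8), DType::kFP32,
                                   DType::kFP32);
  return 0;
}

The payoff visible in the diff is that each instantiation handles exactly one scale/bias element type, so the explicit xpu::cast round trips of the old code path are no longer needed.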