|
18 | 18 |
|
19 | 19 | #include "paddle/phi/backends/xpu/xpu_context.h" |
20 | 20 | #include "paddle/phi/core/kernel_registry.h" |
| 21 | +#ifdef PADDLE_WITH_XPU_FFT |
| 22 | +#include "fft/cuComplex.h" |
| 23 | +#include "paddle/phi/kernels/complex_kernel.h" |
| 24 | +#include "paddle/phi/kernels/expand_kernel.h" |
| 25 | +#include "paddle/phi/kernels/funcs/common_infer_shape_functions.h" |
| 26 | +namespace xfft_internal::xpu { |
| 27 | +int RemainderFunctor(int N, float2* input_x, float2* input_y, float2* output); |
| 28 | +} |
| 29 | +#endif |
21 | 30 |
|
22 | 31 | namespace phi { |
23 | 32 |
|
@@ -75,6 +84,69 @@ void ElementwisePowKernel(const Context& dev_ctx, |
75 | 84 | ElementwisePowRawKernel<T>(dev_ctx, x, y, axis, out); |
76 | 85 | } |
77 | 86 |
|
#ifdef PADDLE_WITH_XPU_FFT
// Explicit specialization of RemainderKernel for complex<float> on XPU.
// Broadcasts both operands to their common shape, then delegates the
// element-wise remainder to the XPU FFT library's RemainderFunctor.
template <>
void RemainderKernel<phi::dtype::complex<float>, XPUContext>(
    const XPUContext& dev_ctx,
    const DenseTensor& x,
    const DenseTensor& y,
    DenseTensor* out) {
  using T = phi::dtype::complex<float>;

  // Common broadcast shape of the two operands.
  const auto out_dims = phi::funcs::BroadcastTwoDims(x.dims(), y.dims());
  const std::vector<int64_t> target_shape = phi::vectorize(out_dims);

  // Expands a complex tensor to `dst`'s shape. There is no complex Expand on
  // XPU here, so the real and imaginary planes are expanded separately as
  // float tensors and recombined via ComplexKernel.
  auto expand_complex = [](const XPUContext& ctx,
                           const DenseTensor& src,
                           const std::vector<int64_t>& shape,
                           DenseTensor* dst) {
    DenseTensor re_expanded;
    DenseTensor im_expanded;
    re_expanded.Resize(dst->dims());
    im_expanded.Resize(dst->dims());
    ctx.template Alloc<float>(&re_expanded);
    ctx.template Alloc<float>(&im_expanded);
    const DenseTensor re = Real<T, XPUContext>(ctx, src);
    const DenseTensor im = Imag<T, XPUContext>(ctx, src);
    ExpandKernel<float, XPUContext>(
        ctx, re, phi::IntArray(shape), &re_expanded);
    ExpandKernel<float, XPUContext>(
        ctx, im, phi::IntArray(shape), &im_expanded);
    phi::ComplexKernel<float>(ctx, re_expanded, im_expanded, dst);
  };

  // Materialize each operand at the broadcast shape. When an operand already
  // has that shape its buffer is used directly; the const_cast is required
  // because RemainderFunctor takes non-const input pointers.
  DenseTensor x_bcast;
  DenseTensor y_bcast;
  T* x_ptr = nullptr;
  T* y_ptr = nullptr;

  if (x.dims() == out_dims) {
    x_ptr = const_cast<T*>(x.data<T>());
  } else {
    x_bcast.Resize(out_dims);
    dev_ctx.template Alloc<T>(&x_bcast);
    expand_complex(dev_ctx, x, target_shape, &x_bcast);
    x_ptr = x_bcast.data<T>();
  }

  if (y.dims() == out_dims) {
    y_ptr = const_cast<T*>(y.data<T>());
  } else {
    y_bcast.Resize(out_dims);
    dev_ctx.template Alloc<T>(&y_bcast);
    expand_complex(dev_ctx, y, target_shape, &y_bcast);
    y_ptr = y_bcast.data<T>();
  }

  dev_ctx.template Alloc<T>(out);
  // cuFloatComplex is layout-compatible with phi::dtype::complex<float>
  // (two packed floats), matching the float2-based functor declaration.
  const int r = xfft_internal::xpu::RemainderFunctor(
      out->numel(),
      reinterpret_cast<cuFloatComplex*>(x_ptr),
      reinterpret_cast<cuFloatComplex*>(y_ptr),
      reinterpret_cast<cuFloatComplex*>(out->data<T>()));
  PADDLE_ENFORCE_XPU_SUCCESS(r);
}
#endif
| 149 | + |
78 | 150 | } // namespace phi |
79 | 151 |
|
80 | 152 | PD_REGISTER_KERNEL(floor_divide, |
@@ -110,8 +182,12 @@ PD_REGISTER_KERNEL(remainder, |
110 | 182 | phi::RemainderKernel, |
111 | 183 | float, |
112 | 184 | phi::dtype::float16, |
| 185 | +#ifdef PADDLE_WITH_XPU_FFT |
| 186 | + phi::dtype::complex<float>, |
| 187 | +#endif |
113 | 188 | int32_t, |
114 | | - int64_t) {} |
| 189 | + int64_t) { |
| 190 | +} |
115 | 191 | PD_REGISTER_KERNEL(elementwise_pow, |
116 | 192 | XPU, |
117 | 193 | ALL_LAYOUT, |
|
0 commit comments