Skip to content

Commit 64a826b

Browse files
authored
[XPU] Support complex for remainder (PaddlePaddle#73818)
* add complex * fix bug * Empty commit to trigger CI
1 parent b497bd9 commit 64a826b

File tree

4 files changed

+116
-3
lines changed

4 files changed

+116
-3
lines changed

cmake/external/xpu.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ else()
5555
endif()
5656

5757
if(NOT DEFINED XPU_FFT_BASE_DATE)
58-
set(XPU_FFT_BASE_DATE "20250630")
58+
set(XPU_FFT_BASE_DATE "20250704")
5959
endif()
6060

6161
set(XPU_XRE_BASE_URL

paddle/phi/backends/xpu/xpu3_op_list.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,9 @@ XPUOpMap& get_kl3_ops() {
490490
{"elementwise_mod",
491491
XPUKernelSet({phi::DataType::FLOAT32,
492492
phi::DataType::FLOAT16,
493+
#ifdef PADDLE_WITH_XPU_FFT
494+
phi::DataType::COMPLEX64,
495+
#endif
493496
phi::DataType::INT64,
494497
phi::DataType::INT32})},
495498
{"embedding_with_eltwise_add_xpu",

paddle/phi/kernels/xpu/elementwise_kernel.cc

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,15 @@
1818

1919
#include "paddle/phi/backends/xpu/xpu_context.h"
2020
#include "paddle/phi/core/kernel_registry.h"
21+
#ifdef PADDLE_WITH_XPU_FFT
22+
#include "fft/cuComplex.h"
23+
#include "paddle/phi/kernels/complex_kernel.h"
24+
#include "paddle/phi/kernels/expand_kernel.h"
25+
#include "paddle/phi/kernels/funcs/common_infer_shape_functions.h"
26+
namespace xfft_internal::xpu {
27+
int RemainderFunctor(int N, float2* input_x, float2* input_y, float2* output);
28+
}
29+
#endif
2130

2231
namespace phi {
2332

@@ -75,6 +84,69 @@ void ElementwisePowKernel(const Context& dev_ctx,
7584
ElementwisePowRawKernel<T>(dev_ctx, x, y, axis, out);
7685
}
7786

87+
#ifdef PADDLE_WITH_XPU_FFT
88+
template <>
89+
void RemainderKernel<phi::dtype::complex<float>, XPUContext>(
90+
const XPUContext& dev_ctx,
91+
const DenseTensor& x,
92+
const DenseTensor& y,
93+
DenseTensor* out) {
94+
using T = phi::dtype::complex<float>;
95+
auto x_dims = x.dims();
96+
auto y_dims = y.dims();
97+
auto out_dims = phi::funcs::BroadcastTwoDims(x_dims, y_dims);
98+
std::vector<int64_t> out_dims_vec = phi::vectorize(out_dims);
99+
100+
auto complex_expand = [](const XPUContext& dev_ctx,
101+
const DenseTensor& x,
102+
const std::vector<int64_t>& out_dims_vec,
103+
DenseTensor* out) {
104+
DenseTensor real_out, imag_out;
105+
real_out.Resize(out->dims());
106+
imag_out.Resize(out->dims());
107+
dev_ctx.template Alloc<float>(&real_out);
108+
dev_ctx.template Alloc<float>(&imag_out);
109+
const DenseTensor real = Real<T, XPUContext>(dev_ctx, x);
110+
const DenseTensor imag = Imag<T, XPUContext>(dev_ctx, x);
111+
ExpandKernel<float, XPUContext>(
112+
dev_ctx, real, phi::IntArray(out_dims_vec), &real_out);
113+
ExpandKernel<float, XPUContext>(
114+
dev_ctx, imag, phi::IntArray(out_dims_vec), &imag_out);
115+
phi::ComplexKernel<float>(dev_ctx, real_out, imag_out, out);
116+
};
117+
118+
DenseTensor broadcasted_x, broadcasted_y;
119+
T* x_data = nullptr;
120+
T* y_data = nullptr;
121+
122+
if (x_dims == out_dims) {
123+
x_data = const_cast<T*>(x.data<T>());
124+
} else {
125+
broadcasted_x.Resize(out_dims);
126+
dev_ctx.template Alloc<T>(&broadcasted_x);
127+
complex_expand(dev_ctx, x, out_dims_vec, &broadcasted_x);
128+
x_data = broadcasted_x.data<T>();
129+
}
130+
131+
if (y_dims == out_dims) {
132+
y_data = const_cast<T*>(y.data<T>());
133+
} else {
134+
broadcasted_y.Resize(out_dims);
135+
dev_ctx.template Alloc<T>(&broadcasted_y);
136+
complex_expand(dev_ctx, y, out_dims_vec, &broadcasted_y);
137+
y_data = broadcasted_y.data<T>();
138+
}
139+
140+
dev_ctx.template Alloc<T>(out);
141+
int r = xfft_internal::xpu::RemainderFunctor(
142+
out->numel(),
143+
reinterpret_cast<cuFloatComplex*>(x_data),
144+
reinterpret_cast<cuFloatComplex*>(y_data),
145+
reinterpret_cast<cuFloatComplex*>(out->data<T>()));
146+
PADDLE_ENFORCE_XPU_SUCCESS(r);
147+
}
148+
#endif
149+
78150
} // namespace phi
79151

80152
PD_REGISTER_KERNEL(floor_divide,
@@ -110,8 +182,12 @@ PD_REGISTER_KERNEL(remainder,
110182
phi::RemainderKernel,
111183
float,
112184
phi::dtype::float16,
185+
#ifdef PADDLE_WITH_XPU_FFT
186+
phi::dtype::complex<float>,
187+
#endif
113188
int32_t,
114-
int64_t) {}
189+
int64_t) {
190+
}
115191
PD_REGISTER_KERNEL(elementwise_pow,
116192
XPU,
117193
ALL_LAYOUT,

test/xpu/test_elementwise_mod_op_xpu.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
)
2323
from op_test import OpTest
2424
from op_test_xpu import XPUOpTest
25+
from utils import dygraph_guard
2526

2627
import paddle
2728
from paddle import base
@@ -109,8 +110,41 @@ def test_dygraph(self):
109110

110111

111112
support_types = get_xpu_op_support_types('elementwise_mod')
112-
for stype in support_types:
113+
real_types = [t for t in support_types if t != 'complex64']
114+
for stype in real_types:
113115
create_test_class(globals(), XPUTestElementwiseModOp, stype)
114116

117+
if 'complex64' in support_types:
118+
119+
class TestElementwiseModOpComplex64(unittest.TestCase):
120+
def test_check_output(self):
121+
with dygraph_guard():
122+
dtype = "complex64"
123+
a = np.array([6 + 4j]).astype(dtype)
124+
b = np.array([3 + 5j]).astype(dtype)
125+
res = np.array([-2 + 2j]).astype(dtype)
126+
127+
res_pd = paddle.remainder(
128+
paddle.to_tensor(a), paddle.to_tensor(b)
129+
)
130+
np.testing.assert_allclose(res, res_pd.numpy())
131+
132+
dtype = "complex64"
133+
a = np.array([6 + 4j]).astype(dtype)
134+
b = np.array([3 + 5j]).astype(dtype)
135+
res = np.array([-2 + 2j]).astype(dtype)
136+
137+
res_pd = paddle.remainder(
138+
paddle.to_tensor(a), paddle.to_tensor(b)
139+
)
140+
np.testing.assert_allclose(res, res_pd.numpy())
141+
142+
with base.device_guard("xpu"):
143+
res_pd = paddle.remainder(
144+
paddle.to_tensor(a), paddle.to_tensor(b)
145+
)
146+
np.testing.assert_allclose(res, res_pd.numpy())
147+
148+
115149
if __name__ == '__main__':
116150
unittest.main()

0 commit comments

Comments
 (0)