Skip to content

Commit d82eb97

Browse files
committed
align api
1 parent 5ce27ed commit d82eb97

File tree

4 files changed

+87
-18
lines changed

4 files changed

+87
-18
lines changed

paddle/phi/kernels/cpu/addcmul_kernel.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,5 @@
1919
#include "paddle/phi/kernels/impl/addcmul_kernel_impl.h"
2020

2121
PD_REGISTER_KERNEL(
22-
addcmul, CPU, ALL_LAYOUT, phi::AddcmulKernel, float, double) {}
22+
addcmul, CPU, ALL_LAYOUT, phi::AddcmulKernel, float, double, int, int64_t) {
23+
}

paddle/phi/kernels/gpu/addcmul_kernel.cu

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,6 @@ PD_REGISTER_KERNEL(addcmul,
2525
float,
2626
double,
2727
phi::dtype::float16,
28-
phi::dtype::bfloat16) {}
28+
phi::dtype::bfloat16,
29+
int,
30+
int64_t) {}

python/paddle/_paddle_docs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4823,11 +4823,11 @@ def i1e(
48234823
48244824
Args:
48254825
input (Tensor): The input tensor to be added to the final result.
4826-
Its data type should be float16, float32, float64.
4826+
Its data type should be int32, int64, float16, float32, float64.
48274827
tensor1 (Tensor): The first tensor for element-wise multiplication.
4828-
Its data type should be float16, float32, float64.
4828+
Its data type should be int32, int64, float16, float32, float64.
48294829
tensor2 (Tensor): The second tensor for element-wise multiplication.
4830-
Its data type should be float16, float32, float64.
4830+
Its data type should be int32, int64, float16, float32, float64.
48314831
value (float, optional): The scalar multiplier for tensor1 * tensor2. Default: 1.
48324832
name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
48334833
out (Tensor|None, optional): The output tensor. Default: None.

test/legacy_test/test_addcmul_op.py

Lines changed: 79 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ class TestAddcmulOp(OpTest):
3232

3333
def setUp(self):
3434
self.op_type = "addcmul"
35-
self.prim_op_type = "comp"
3635
self.python_api = paddle.addcmul
3736
self.public_python_api = paddle.addcmul
3837
self.init_dtype_type()
@@ -142,6 +141,58 @@ def init_shapes_and_data(self):
142141
self.attrs = {'value': 0.5}
143142

144143

144+
class TestAddcmulOp_Int32(OpTest):
145+
"""Test int32 dtype - aligned with PyTorch integer type support"""
146+
147+
def setUp(self):
148+
self.op_type = "addcmul"
149+
self.python_api = paddle.addcmul
150+
self.public_python_api = paddle.addcmul
151+
self.dtype = np.int32
152+
153+
input_np = np.random.randint(1, 10, (10, 20)).astype(self.dtype)
154+
tensor1_np = np.random.randint(1, 10, (10, 20)).astype(self.dtype)
155+
tensor2_np = np.random.randint(1, 10, (10, 20)).astype(self.dtype)
156+
value = 2
157+
158+
self.inputs = {
159+
'input': input_np,
160+
'tensor1': tensor1_np,
161+
'tensor2': tensor2_np,
162+
}
163+
self.attrs = {'value': value}
164+
self.outputs = {'out': input_np + value * tensor1_np * tensor2_np}
165+
166+
def test_check_output(self):
167+
self.check_output(check_pir=True)
168+
169+
170+
class TestAddcmulOp_Int64(OpTest):
171+
"""Test int64 dtype - aligned with PyTorch integer type support"""
172+
173+
def setUp(self):
174+
self.op_type = "addcmul"
175+
self.python_api = paddle.addcmul
176+
self.public_python_api = paddle.addcmul
177+
self.dtype = np.int64
178+
179+
input_np = np.random.randint(1, 10, (10, 20)).astype(self.dtype)
180+
tensor1_np = np.random.randint(1, 10, (10, 20)).astype(self.dtype)
181+
tensor2_np = np.random.randint(1, 10, (10, 20)).astype(self.dtype)
182+
value = 3
183+
184+
self.inputs = {
185+
'input': input_np,
186+
'tensor1': tensor1_np,
187+
'tensor2': tensor2_np,
188+
}
189+
self.attrs = {'value': value}
190+
self.outputs = {'out': input_np + value * tensor1_np * tensor2_np}
191+
192+
def test_check_output(self):
193+
self.check_output(check_pir=True)
194+
195+
145196
@unittest.skipIf(
146197
not core.is_compiled_with_cuda(),
147198
"core is not compiled with CUDA",
@@ -171,7 +222,6 @@ class TestAddcmulBF16Op(OpTest):
171222

172223
def setUp(self):
173224
self.op_type = "addcmul"
174-
self.prim_op_type = "comp"
175225
self.python_api = paddle.addcmul
176226
self.public_python_api = paddle.addcmul
177227
self.dtype = np.uint16
@@ -209,7 +259,6 @@ class TestAddcmulBroadcast2D(OpTest):
209259

210260
def setUp(self):
211261
self.op_type = "addcmul"
212-
self.prim_op_type = "comp"
213262
self.python_api = paddle.addcmul
214263
self.public_python_api = paddle.addcmul
215264
self.dtype = np.float64
@@ -246,7 +295,6 @@ class TestAddcmulBroadcast3D(OpTest):
246295

247296
def setUp(self):
248297
self.op_type = "addcmul"
249-
self.prim_op_type = "comp"
250298
self.python_api = paddle.addcmul
251299
self.public_python_api = paddle.addcmul
252300
self.dtype = np.float64
@@ -284,11 +332,9 @@ class TestAddcmulOpError(unittest.TestCase):
284332
def test_type_errors(self):
285333
paddle.enable_static()
286334
with program_guard(Program(), Program()):
287-
input = paddle.static.data(
288-
name='input', shape=[4, 4], dtype="int32"
289-
)
290-
x3 = paddle.static.data(name='x3', shape=[4, 4], dtype="int32")
291-
x4 = paddle.static.data(name='x4', shape=[4, 4], dtype="int32")
335+
input = paddle.static.data(name='input', shape=[4, 4], dtype="bool")
336+
x3 = paddle.static.data(name='x3', shape=[4, 4], dtype="bool")
337+
x4 = paddle.static.data(name='x4', shape=[4, 4], dtype="bool")
292338
self.assertRaises(TypeError, paddle.addcmul, input, x3, x4)
293339
paddle.disable_static()
294340

@@ -334,8 +380,8 @@ def test_dygraph_api(self):
334380
def test_static_api(self):
335381
"""Test static graph API"""
336382
paddle.enable_static()
337-
main = paddle.static.Program()
338383
startup = paddle.static.Program()
384+
main = paddle.static.Program()
339385
with base.program_guard(main, startup):
340386
x = paddle.static.data(name="x", shape=self.shape, dtype=self.dtype)
341387
t1 = paddle.static.data(
@@ -346,9 +392,10 @@ def test_static_api(self):
346392
)
347393
out = paddle.addcmul(x, t1, t2, value=0.5)
348394

349-
exe = base.Executor(paddle.CPUPlace())
395+
place = paddle.CPUPlace()
396+
exe = base.Executor(place)
350397
result = exe.run(
351-
main,
398+
base.default_main_program(),
352399
feed={
353400
"x": self.np_input,
354401
"tensor1": self.np_tensor1,
@@ -374,6 +421,22 @@ def test_out_parameter(self):
374421
np.testing.assert_allclose(out.numpy(), expected, rtol=1e-5)
375422
paddle.enable_static()
376423

424+
def test_int_dtype(self):
425+
"""Test integer type support (aligned with PyTorch)"""
426+
paddle.disable_static()
427+
np_input = np.array([[1, 2], [3, 4]], dtype='int32')
428+
np_t1 = np.array([[2, 3], [4, 5]], dtype='int32')
429+
np_t2 = np.array([[1, 1], [2, 2]], dtype='int32')
430+
431+
x = paddle.to_tensor(np_input)
432+
t1 = paddle.to_tensor(np_t1)
433+
t2 = paddle.to_tensor(np_t2)
434+
435+
out = paddle.addcmul(x, t1, t2, value=2)
436+
expected = np_input + 2 * np_t1 * np_t2
437+
np.testing.assert_array_equal(out.numpy(), expected)
438+
paddle.enable_static()
439+
377440

378441
class TestAddcmulGradEmptyTensor(unittest.TestCase):
379442
"""Test gradient with empty tensors - covers numel==0 branch"""
@@ -513,6 +576,8 @@ class TestAddcmulCINNSymbolic(unittest.TestCase):
513576
def setUp(self):
514577
if not core.is_compiled_with_cuda():
515578
self.skipTest("CINN requires CUDA")
579+
if not core.is_compiled_with_cinn():
580+
self.skipTest("CINN is not compiled")
516581
paddle.disable_static()
517582
paddle.seed(2024)
518583

@@ -648,6 +713,8 @@ class TestAddcmulCINNGrad(unittest.TestCase):
648713
def setUp(self):
649714
if not core.is_compiled_with_cuda():
650715
self.skipTest("CINN requires CUDA")
716+
if not core.is_compiled_with_cinn():
717+
self.skipTest("CINN is not compiled")
651718
paddle.disable_static()
652719
paddle.seed(2024)
653720

@@ -720,7 +787,6 @@ def addcmul_loss(x, t1, t2):
720787

721788
# ============================================================
722789
# OpTest broadcast tests for high ranks (backward reduction)
723-
# These do NOT set prim_op_type to force C++ kernel backward path
724790
# ============================================================
725791

726792

0 commit comments

Comments (0)