Skip to content

Commit e682f7c

Browse files
committed
Add numpy_test for ldexp
1 parent 4bbd938 commit e682f7c

File tree

8 files changed

+87
-14
lines changed

8 files changed

+87
-14
lines changed

keras/src/backend/jax/numpy.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -848,6 +848,13 @@ def lcm(x1, x2):
848848
def ldexp(x1, x2):
849849
x1 = convert_to_tensor(x1)
850850
x2 = convert_to_tensor(x2)
851+
852+
if standardize_dtype(x2.dtype) not in dtypes.INT_TYPES:
853+
raise TypeError(
854+
f"ldexp exponent must be an integer type. "
855+
f"Received: x2 dtype={x2.dtype}"
856+
)
857+
851858
return jnp.ldexp(x1, x2)
852859

853860

keras/src/backend/numpy/numpy.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -776,7 +776,16 @@ def lcm(x1, x2):
776776
def ldexp(x1, x2):
777777
x1 = convert_to_tensor(x1)
778778
x2 = convert_to_tensor(x2)
779-
dtype = dtypes.result_type(x1.dtype, x2.dtype)
779+
dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
780+
781+
if standardize_dtype(x2.dtype) not in dtypes.INT_TYPES:
782+
raise TypeError(
783+
f"ldexp exponent must be an integer type. "
784+
f"Received: x2 dtype={x2.dtype}"
785+
)
786+
787+
x1 = np.asarray(x1).astype(np.float32)
788+
x2 = np.asarray(x2).astype(np.int32)
780789
return np.ldexp(x1, x2).astype(dtype)
781790

782791

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ NumpyDtypeTest::test_isin
2929
NumpyDtypeTest::test_isreal
3030
NumpyDtypeTest::test_kron
3131
NumpyDtypeTest::test_lcm
32+
NumpyDtypeTest::test_ldexp
3233
NumpyDtypeTest::test_logaddexp2
3334
NumpyDtypeTest::test_matmul_
3435
NumpyDtypeTest::test_maximum_python_types
@@ -108,6 +109,7 @@ NumpyTwoInputOpsCorrectnessTest::test_inner
108109
NumpyTwoInputOpsCorrectnessTest::test_isin
109110
NumpyTwoInputOpsCorrectnessTest::test_kron
110111
NumpyTwoInputOpsCorrectnessTest::test_lcm
112+
NumpyTwoInputOpsCorrectnessTest::test_ldexp
111113
NumpyTwoInputOpsCorrectnessTest::test_quantile
112114
NumpyTwoInputOpsCorrectnessTest::test_tensordot
113115
NumpyTwoInputOpsCorrectnessTest::test_vdot
@@ -131,11 +133,13 @@ NumpyTwoInputOpsDynamicShapeTest::test_hypot
131133
NumpyTwoInputOpsDynamicShapeTest::test_isin
132134
NumpyTwoInputOpsDynamicShapeTest::test_kron
133135
NumpyTwoInputOpsDynamicShapeTest::test_lcm
136+
NumpyTwoInputOpsDynamicShapeTest::test_ldexp
134137
NumpyTwoInputOpsStaticShapeTest::test_gcd
135138
NumpyTwoInputOpsStaticShapeTest::test_hypot
136139
NumpyTwoInputOpsStaticShapeTest::test_isin
137140
NumpyTwoInputOpsStaticShapeTest::test_kron
138141
NumpyTwoInputOpsStaticShapeTest::test_lcm
142+
NumpyTwoInputOpsStaticShapeTest::test_ldexp
139143
CoreOpsBehaviorTests::test_associative_scan_invalid_arguments
140144
CoreOpsBehaviorTests::test_scan_invalid_arguments
141145
CoreOpsCallsTests::test_associative_scan_basic_call

keras/src/backend/openvino/numpy.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1149,6 +1149,10 @@ def lcm(x1, x2):
11491149
raise NotImplementedError("`lcm` is not supported with openvino backend")
11501150

11511151

1152+
def ldexp(x1, x2):
1153+
raise NotImplementedError("`ldexp` is not supported with openvino backend")
1154+
1155+
11521156
def less(x1, x2):
11531157
element_type = None
11541158
if isinstance(x1, OpenVINOKerasTensor):

keras/src/backend/tensorflow/numpy.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1847,7 +1847,14 @@ def lcm(x1, x2):
18471847
def ldexp(x1, x2):
18481848
x1 = convert_to_tensor(x1)
18491849
x2 = convert_to_tensor(x2)
1850-
dtype = dtypes.result_type(x1.dtype, x2.dtype)
1850+
dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
1851+
1852+
if standardize_dtype(x2.dtype) not in dtypes.INT_TYPES:
1853+
raise TypeError(
1854+
f"ldexp exponent must be an integer type. "
1855+
f"Received: x2 dtype={x2.dtype}"
1856+
)
1857+
18511858
x1 = tf.cast(x1, tf.float32)
18521859
x2 = tf.cast(x2, tf.float32)
18531860
return tf.cast(x1 * tf.pow(2.0, x2), dtype)

keras/src/backend/torch/numpy.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -978,7 +978,15 @@ def lcm(x1, x2):
978978
def ldexp(x1, x2):
979979
x1 = convert_to_tensor(x1)
980980
x2 = convert_to_tensor(x2)
981-
return torch.ldexp(x1, x2)
981+
dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
982+
983+
if standardize_dtype(x2.dtype) not in dtypes.INT_TYPES:
984+
raise TypeError(
985+
f"ldexp exponent must be an integer type. "
986+
f"Received: x2 dtype={x2.dtype}"
987+
)
988+
989+
return cast(torch.ldexp(x1, x2), dtype)
982990

983991

984992
def less(x1, x2):

keras/src/ops/numpy.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4070,7 +4070,7 @@ def compute_output_spec(self, x1, x2):
40704070

40714071
x1_type = backend.standardize_dtype(getattr(x1, "dtype", type(x1)))
40724072
x2_type = backend.standardize_dtype(getattr(x2, "dtype", type(x2)))
4073-
dtype = dtypes.result_type(x1_type, x2_type)
4073+
dtype = dtypes.result_type(x1_type, x2_type, float)
40744074
return KerasTensor(output_shape, dtype=dtype)
40754075

40764076

@@ -4081,19 +4081,12 @@ def ldexp(x1, x2):
40814081
This function computes:
40824082
ldexp(x1, x2) = x1 * 2**x2
40834083
4084-
Notes:
4085-
- TensorFlow does *not* provide a built-in `tf.math.ldexp`.
4086-
- The `backend.numpy.ldexp` implementation in TF NumPy is currently
4087-
unreliable due to a naming-collision bug and should not be used.
4088-
- This implementation provides correct broadcasting, dtype inference,
4089-
and gradient support.
4090-
40914084
Args:
4092-
x1: Floating-point input tensor.
4093-
x2: Integer or floating-point exponent tensor.
4085+
x1: Float input tensor.
4086+
x2: Integer exponent tensor.
40944087
40954088
Returns:
4096-
Output tensor, element-wise equal to `x1 * 2**x2`.
4089+
Output tensor
40974090
40984091
Example:
40994092
>>> x1 = keras.ops.convert_to_tensor([0.75, 1.5])

keras/src/ops/numpy_test.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,11 @@ def test_lcm(self):
254254
y = KerasTensor((2, None))
255255
self.assertEqual(knp.lcm(x, y).shape, (2, 3))
256256

257+
def test_ldexp(self):
258+
x = KerasTensor((None, 3))
259+
y = KerasTensor((1, 3))
260+
self.assertEqual(knp.ldexp(x, y).shape, (None, 3))
261+
257262
def test_less(self):
258263
x = KerasTensor((None, 3))
259264
y = KerasTensor((2, None))
@@ -837,6 +842,15 @@ def test_lcm(self):
837842
y = KerasTensor((2, 3))
838843
self.assertEqual(knp.lcm(x, y).shape, (2, 3))
839844

845+
def test_ldexp(self):
846+
x = KerasTensor((2, 3))
847+
y = KerasTensor((2, 3))
848+
self.assertEqual(knp.ldexp(x, y).shape, (2, 3))
849+
850+
x = KerasTensor((2, 3))
851+
y = KerasTensor((1, 3))
852+
self.assertEqual(knp.ldexp(x, y).shape, (2, 3))
853+
840854
def test_less(self):
841855
x = KerasTensor((2, 3))
842856
y = KerasTensor((2, 3))
@@ -3114,6 +3128,12 @@ def test_lcm(self):
31143128
self.assertAllClose(knp.lcm(x, y), np.lcm(x, y))
31153129
self.assertAllClose(knp.Lcm()(x, y), np.lcm(x, y))
31163130

3131+
def test_ldexp(self):
3132+
x = np.array([[1, 2, 3], [3, 2, 1]])
3133+
y = np.array([[4, 5, 6], [3, 2, 1]])
3134+
self.assertAllClose(knp.ldexp(x, y), np.ldexp(x, y))
3135+
self.assertAllClose(knp.Ldexp()(x, y), np.ldexp(x, y))
3136+
31173137
def test_less(self):
31183138
x = np.array([[1, 2, 3], [3, 2, 1]])
31193139
y = np.array([[4, 5, 6], [3, 2, 1]])
@@ -7884,6 +7904,27 @@ def test_lcm(self, dtypes):
78847904
expected_dtype,
78857905
)
78867906

7907+
@parameterized.named_parameters(
7908+
named_product(dtypes=list(itertools.product(ALL_DTYPES, INT_DTYPES)))
7909+
)
7910+
def test_ldexp(self, dtypes):
7911+
import jax.numpy as jnp
7912+
7913+
dtype1, dtype2 = dtypes
7914+
x1 = knp.ones((), dtype=dtype1)
7915+
x2 = knp.ones((), dtype=dtype2)
7916+
x1_jax = jnp.ones((), dtype=dtype1)
7917+
x2_jax = jnp.ones((), dtype=dtype2)
7918+
expected_dtype = standardize_dtype(jnp.ldexp(x1_jax, x2_jax).dtype)
7919+
7920+
self.assertEqual(
7921+
standardize_dtype(knp.ldexp(x1, x2).dtype), expected_dtype
7922+
)
7923+
self.assertEqual(
7924+
standardize_dtype(knp.Ldexp().symbolic_call(x1, x2).dtype),
7925+
expected_dtype,
7926+
)
7927+
78877928
@parameterized.named_parameters(
78887929
named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))
78897930
)

0 commit comments

Comments
 (0)