@@ -254,6 +254,11 @@ def test_lcm(self):
254254 y = KerasTensor ((2 , None ))
255255 self .assertEqual (knp .lcm (x , y ).shape , (2 , 3 ))
256256
257+ def test_ldexp (self ):
258+ x = KerasTensor ((None , 3 ))
259+ y = KerasTensor ((1 , 3 ))
260+ self .assertEqual (knp .ldexp (x , y ).shape , (None , 3 ))
261+
257262 def test_less (self ):
258263 x = KerasTensor ((None , 3 ))
259264 y = KerasTensor ((2 , None ))
@@ -837,6 +842,15 @@ def test_lcm(self):
837842 y = KerasTensor ((2 , 3 ))
838843 self .assertEqual (knp .lcm (x , y ).shape , (2 , 3 ))
839844
845+ def test_ldexp (self ):
846+ x = KerasTensor ((2 , 3 ))
847+ y = KerasTensor ((2 , 3 ))
848+ self .assertEqual (knp .ldexp (x , y ).shape , (2 , 3 ))
849+
850+ x = KerasTensor ((2 , 3 ))
851+ y = KerasTensor ((1 , 3 ))
852+ self .assertEqual (knp .ldexp (x , y ).shape , (2 , 3 ))
853+
840854 def test_less (self ):
841855 x = KerasTensor ((2 , 3 ))
842856 y = KerasTensor ((2 , 3 ))
@@ -3114,6 +3128,12 @@ def test_lcm(self):
31143128 self .assertAllClose (knp .lcm (x , y ), np .lcm (x , y ))
31153129 self .assertAllClose (knp .Lcm ()(x , y ), np .lcm (x , y ))
31163130
3131+ def test_ldexp (self ):
3132+ x = np .array ([[1 , 2 , 3 ], [3 , 2 , 1 ]])
3133+ y = np .array ([[4 , 5 , 6 ], [3 , 2 , 1 ]])
3134+ self .assertAllClose (knp .ldexp (x , y ), np .ldexp (x , y ))
3135+ self .assertAllClose (knp .Ldexp ()(x , y ), np .ldexp (x , y ))
3136+
31173137 def test_less (self ):
31183138 x = np .array ([[1 , 2 , 3 ], [3 , 2 , 1 ]])
31193139 y = np .array ([[4 , 5 , 6 ], [3 , 2 , 1 ]])
@@ -7884,6 +7904,27 @@ def test_lcm(self, dtypes):
78847904 expected_dtype ,
78857905 )
78867906
7907+ @parameterized .named_parameters (
7908+ named_product (dtypes = list (itertools .product (ALL_DTYPES , INT_DTYPES )))
7909+ )
7910+ def test_ldexp (self , dtypes ):
7911+ import jax .numpy as jnp
7912+
7913+ dtype1 , dtype2 = dtypes
7914+ x1 = knp .ones ((), dtype = dtype1 )
7915+ x2 = knp .ones ((), dtype = dtype2 )
7916+ x1_jax = jnp .ones ((), dtype = dtype1 )
7917+ x2_jax = jnp .ones ((), dtype = dtype2 )
7918+ expected_dtype = standardize_dtype (jnp .ldexp (x1_jax , x2_jax ).dtype )
7919+
7920+ self .assertEqual (
7921+ standardize_dtype (knp .ldexp (x1 , x2 ).dtype ), expected_dtype
7922+ )
7923+ self .assertEqual (
7924+ standardize_dtype (knp .Ldexp ().symbolic_call (x1 , x2 ).dtype ),
7925+ expected_dtype ,
7926+ )
7927+
78877928 @parameterized .named_parameters (
78887929 named_product (dtypes = itertools .combinations (ALL_DTYPES , 2 ))
78897930 )
0 commit comments