Skip to content

Commit 8b0c561

Browse files
committed
Update code by gemini reveiw
1 parent e682f7c commit 8b0c561

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

keras/src/backend/numpy/numpy.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -784,8 +784,6 @@ def ldexp(x1, x2):
784784
f"Received: x2 dtype={x2.dtype}"
785785
)
786786

787-
x1 = np.asarray(x1).astype(np.float32)
788-
x2 = np.asarray(x2).astype(np.int32)
789787
return np.ldexp(x1, x2).astype(dtype)
790788

791789

keras/src/backend/tensorflow/numpy.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1855,9 +1855,12 @@ def ldexp(x1, x2):
18551855
f"Received: x2 dtype={x2.dtype}"
18561856
)
18571857

1858-
x1 = tf.cast(x1, tf.float32)
1859-
x2 = tf.cast(x2, tf.float32)
1860-
return tf.cast(x1 * tf.pow(2.0, x2), dtype)
1858+
x1 = tf.cast(x1, dtypes.result_type(x1.dtype, float))
1859+
1860+
x1 = tf.cast(x1, tf.float32 if not x1.dtype.is_floating else x1.dtype)
1861+
x2 = tf.cast(x2, x1.dtype)
1862+
result = x1 * tf.pow(tf.constant(2.0, dtype=x1.dtype), x2)
1863+
return tf.cast(tf.where(tf.math.is_inf(x1) | (x1 == 0), x1, result), dtype)
18611864

18621865

18631866
def less(x1, x2):

0 commit comments

Comments
 (0)