Skip to content

Commit dd1ab80

Browse files
committed
fix: ivy.diff for tf backend
1 parent 6d1ee32 commit dd1ab80

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

ivy/functional/backends/tensorflow/experimental/elementwise.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -243,11 +243,24 @@ def diff(
243243
out: Optional[Union[tf.Tensor, tf.Variable]] = None,
244244
) -> Union[tf.Tensor, tf.Variable]:
245245
if n == 0:
246-
return x
246+
return tf.experimental.numpy.asanyarray(x)
247+
248+
x = tf.convert_to_tensor(x)
249+
247250
if prepend is not None:
248-
x = tf.experimental.numpy.append(prepend, x, axis=axis if axis != -1 else None)
251+
prepend = tf.convert_to_tensor(prepend)
252+
promoted_type = tf.experimental.numpy.result_type(x.dtype, prepend.dtype)
253+
x = tf.concat(
254+
[tf.cast(prepend, promoted_type), tf.cast(x, promoted_type)], axis=axis
255+
)
256+
249257
if append is not None:
250-
x = tf.experimental.numpy.append(x, append, axis=axis if axis != -1 else None)
258+
append = tf.convert_to_tensor(append)
259+
promoted_type = tf.experimental.numpy.result_type(x.dtype, append.dtype)
260+
x = tf.concat(
261+
[tf.cast(x, promoted_type), tf.cast(append, promoted_type)], axis=axis
262+
)
263+
251264
return tf.experimental.numpy.diff(x, n=n, axis=axis)
252265

253266

0 commit comments

Comments
 (0)