Skip to content

Commit 8f04616

Browse files
Specify window_length dtype requirement in tf.keras.ops.istft in math.py (#20728)
The `window_length` parameter in `tf.keras.ops.istft` requires `tf.int32` dtype, but this isn't documented. This can cause unexpected `ValueError` when using `tf.int64` and `tf.int16` Here is the Example case: ``` import tensorflow as tf input_dict = { 'stfts': tf.constant([[-0.87817144+1.14583987j, -0.32066484+0.25565411j]], dtype=tf.complex128), 'frame_length': tf.constant(256, dtype=tf.int16), 'frame_step': tf.constant(5120,dtype=tf.int64) } result = tf.signal.inverse_stft(**input_dict) print(result) ``` The code throws the following error: ``` ValueError: window_length: Tensor conversion requested dtype int32 for Tensor with dtype int64 ```
1 parent 1adaaec commit 8f04616

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

keras/src/ops/math.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -888,7 +888,7 @@ def istft(
888888
sequence_length: An integer representing the sequence length.
889889
sequence_stride: An integer representing the sequence hop size.
890890
fft_length: An integer representing the size of the FFT that produced
891-
`stft`.
891+
`stft`. Should be of type `int32`.
892892
length: An integer representing the output is clipped to exactly length.
893893
If not specified, no padding or clipping take place. Defaults to
894894
`None`.

0 commit comments

Comments
 (0)