diff --git a/keras/api/_tf_keras/keras/initializers/__init__.py b/keras/api/_tf_keras/keras/initializers/__init__.py index 5819d1b285eb..405fa1142812 100644 --- a/keras/api/_tf_keras/keras/initializers/__init__.py +++ b/keras/api/_tf_keras/keras/initializers/__init__.py @@ -16,6 +16,7 @@ from keras.src.initializers.constant_initializers import Identity as identity from keras.src.initializers.constant_initializers import Ones from keras.src.initializers.constant_initializers import Ones as ones +from keras.src.initializers.constant_initializers import STFTInitializer from keras.src.initializers.constant_initializers import Zeros from keras.src.initializers.constant_initializers import Zeros as zeros from keras.src.initializers.initializer import Initializer diff --git a/keras/api/_tf_keras/keras/layers/__init__.py b/keras/api/_tf_keras/keras/layers/__init__.py index 7c905b9efad2..370ae1358c7d 100644 --- a/keras/api/_tf_keras/keras/layers/__init__.py +++ b/keras/api/_tf_keras/keras/layers/__init__.py @@ -175,6 +175,7 @@ from keras.src.layers.preprocessing.normalization import Normalization from keras.src.layers.preprocessing.pipeline import Pipeline from keras.src.layers.preprocessing.rescaling import Rescaling +from keras.src.layers.preprocessing.stft_spectrogram import STFTSpectrogram from keras.src.layers.preprocessing.string_lookup import StringLookup from keras.src.layers.preprocessing.text_vectorization import TextVectorization from keras.src.layers.regularization.activity_regularization import ( diff --git a/keras/api/initializers/__init__.py b/keras/api/initializers/__init__.py index 5819d1b285eb..405fa1142812 100644 --- a/keras/api/initializers/__init__.py +++ b/keras/api/initializers/__init__.py @@ -16,6 +16,7 @@ from keras.src.initializers.constant_initializers import Identity as identity from keras.src.initializers.constant_initializers import Ones from keras.src.initializers.constant_initializers import Ones as ones +from keras.src.initializers.constant_initializers import STFTInitializer from keras.src.initializers.constant_initializers import Zeros from keras.src.initializers.constant_initializers import Zeros as zeros from keras.src.initializers.initializer import Initializer diff --git a/keras/api/layers/__init__.py b/keras/api/layers/__init__.py index 2c1b3d576434..2ae17dcfbd93 100644 --- a/keras/api/layers/__init__.py +++ b/keras/api/layers/__init__.py @@ -175,6 +175,7 @@ from keras.src.layers.preprocessing.normalization import Normalization from keras.src.layers.preprocessing.pipeline import Pipeline from keras.src.layers.preprocessing.rescaling import Rescaling +from keras.src.layers.preprocessing.stft_spectrogram import STFTSpectrogram from keras.src.layers.preprocessing.string_lookup import StringLookup from keras.src.layers.preprocessing.text_vectorization import TextVectorization from keras.src.layers.regularization.activity_regularization import ( diff --git a/keras/src/initializers/__init__.py b/keras/src/initializers/__init__.py index e7cf6f76e3ef..d3208c470fc6 100644 --- a/keras/src/initializers/__init__.py +++ b/keras/src/initializers/__init__.py @@ -4,6 +4,7 @@ from keras.src.initializers.constant_initializers import Constant from keras.src.initializers.constant_initializers import Identity from keras.src.initializers.constant_initializers import Ones +from keras.src.initializers.constant_initializers import STFTInitializer from keras.src.initializers.constant_initializers import Zeros from keras.src.initializers.initializer import Initializer from keras.src.initializers.random_initializers import GlorotNormal @@ -25,6 +26,7 @@ Constant, Identity, Ones, + STFTInitializer, Zeros, GlorotNormal, GlorotUniform, diff --git a/keras/src/initializers/constant_initializers.py b/keras/src/initializers/constant_initializers.py index c5ab6a42d6b2..2bb9274e3746 100644 --- a/keras/src/initializers/constant_initializers.py +++ b/keras/src/initializers/constant_initializers.py @@ -3,6 +3,7 @@ from keras.src.backend import standardize_dtype from keras.src.initializers.initializer import Initializer from keras.src.saving import serialization_lib +from keras.src.utils.module_utils import scipy @keras_export(["keras.initializers.Constant", "keras.initializers.constant"]) @@ -151,3 +152,120 @@ def __call__(self, shape, dtype=None): ) dtype = standardize_dtype(dtype) return self.gain * ops.eye(*shape, dtype=dtype) + + +@keras_export(["keras.initializers.STFTInitializer"]) +class STFTInitializer(Initializer): + """Initializer of Conv kernels for Short-term Fourier Transformation (STFT). + + Since the formula involves complex numbers, this class compute either the + real or the imaginary components of the final output. + + Additionally, this initializer supports windowing functions across the time + dimension as commonly used in STFT. Windowing functions from the module + `scipy.signal.windows` are supported, including the common `hann` and + `hamming` windowing functions. This layer supports periodic windows and + scaling-based normalization. + + This is primarly intended for use in the `STFTSpectrogram` layer. + + Examples: + + >>> # Standalone usage: + >>> initializer = STFTInitializer("real", "hann", "density", False) + >>> values = initializer(shape=(128, 1, 513)) + + Args: + side: String, `"real"` or `"imag"` deciding if the kernel will compute + the real side or the imaginary side of the output. + window: String for the name of the windowing function in the + `scipy.signal.windows` module, or array_like for the window values, + or `None` for no windowing. + scaling: String, `"density"` or `"spectrum"` for scaling of the window + for normalization, either L2 or L1 normalization. + `None` for no scaling. + periodic: Boolean, if True, the window function will be treated as + periodic. Defaults to `False`. + """ + + def __init__(self, side, window="hann", scaling="density", periodic=False): + if side not in ["real", "imag"]: + raise ValueError(f"side should be 'real' or 'imag', not {side}") + if isinstance(window, str): + # throws an exception for invalid window function + scipy.signal.get_window(window, 1) + if scaling is not None and scaling not in ["density", "spectrum"]: + raise ValueError( + "Scaling is invalid, it must be `None`, 'density' " + f"or 'spectrum'. Received scaling={scaling}" + ) + self.side = side + self.window = window + self.scaling = scaling + self.periodic = periodic + + def __call__(self, shape, dtype=None): + """Returns a tensor object initialized as specified by the initializer. + + The shape is assumed to be `(T, 1, F // 2 + 1)`, where `T` is the size + of the given window, and `F` is the number of frequency bands. Only half + the frequency bands are used, which is a common practice in STFT, + because the second half are the conjugates of the first half in + a reversed order. + + Args: + shape: Shape of the tensor. + dtype: Optional dtype of the tensor. Only numeric or boolean dtypes + are supported. If not specified, `keras.backend.floatx()` + is used, which default to `float32` unless you configured it + otherwise (via `keras.backend.set_floatx(float_dtype)`). + """ + dtype = standardize_dtype(dtype) + frame_length, input_channels, fft_length = shape + + win = None + scaling = 1 + if self.window is not None: + win = self.window + if isinstance(win, str): + # Using SciPy since it provides more windowing functions, + # easier to be compatible with multiple backends. + win = scipy.signal.get_window(win, frame_length, self.periodic) + win = ops.convert_to_tensor(win, dtype=dtype) + if len(win.shape) != 1 or win.shape[-1] != frame_length: + raise ValueError( + "The shape of `window` must be equal to [frame_length]." + f"Received: window shape={win.shape}" + ) + win = ops.reshape(win, [frame_length, 1, 1]) + if self.scaling == "density": + scaling = ops.sqrt(ops.sum(ops.square(win))) + elif self.scaling == "spectrum": + scaling = ops.sum(ops.abs(win)) + + _fft_length = (fft_length - 1) * 2 + freq = ( + ops.reshape(ops.arange(fft_length, dtype=dtype), (1, 1, fft_length)) + / _fft_length + ) + time = ops.reshape( + ops.arange(frame_length, dtype=dtype), (frame_length, 1, 1) + ) + args = -2 * time * freq * ops.arccos(ops.cast(-1, dtype)) + + if self.side == "real": + kernel = ops.cast(ops.cos(args), dtype) + else: + kernel = ops.cast(ops.sin(args), dtype) + + if win is not None: + kernel = kernel * win / scaling + return kernel + + def get_config(self): + return { + "side": self.side, + "window": self.window, + "periodic": self.periodic, + "scaling": self.scaling, + } diff --git a/keras/src/initializers/constant_initializers_test.py b/keras/src/initializers/constant_initializers_test.py index ace475b499e1..055ff1e7dae0 100644 --- a/keras/src/initializers/constant_initializers_test.py +++ b/keras/src/initializers/constant_initializers_test.py @@ -1,4 +1,5 @@ import numpy as np +import scipy.signal from keras.src import backend from keras.src import initializers @@ -67,3 +68,65 @@ def test_identity_initializer(self): self.assertAllClose(np_values, np.eye(*shape) * gain) self.run_class_serialization_test(initializer) + + def test_stft_initializer(self): + shape = (256, 1, 513) + time_range = np.arange(256).reshape((-1, 1, 1)) + freq_range = (np.arange(513) / 1024.0).reshape((1, 1, -1)) + pi = np.arccos(np.float64(-1)) + args = -2 * pi * time_range * freq_range + + tol_kwargs = {} + if backend.backend() == "jax": + # TODO(mostafa-mahmoud): investigate the cases + # of non-small error in jax and torch + tol_kwargs = {"atol": 1e-4, "rtol": 1e-6} + + initializer = initializers.STFTInitializer("real", None) + values = backend.convert_to_numpy(initializer(shape)) + self.assertAllClose(np.cos(args), values, atol=1e-4) + self.run_class_serialization_test(initializer) + + initializer = initializers.STFTInitializer( + "real", + "hamming", + None, + True, + ) + window = scipy.signal.windows.get_window("hamming", 256, True) + window = window.astype("float64").reshape((-1, 1, 1)) + values = backend.convert_to_numpy(initializer(shape, "float64")) + self.assertAllClose(np.cos(args) * window, values, **tol_kwargs) + self.run_class_serialization_test(initializer) + + initializer = initializers.STFTInitializer( + "imag", + "tukey", + "density", + False, + ) + window = scipy.signal.windows.get_window("tukey", 256, False) + window = window.astype("float64").reshape((-1, 1, 1)) + window = window / np.sqrt(np.sum(window**2)) + values = backend.convert_to_numpy(initializer(shape, "float64")) + self.assertAllClose(np.sin(args) * window, values, **tol_kwargs) + self.run_class_serialization_test(initializer) + + initializer = initializers.STFTInitializer( + "imag", + list(range(1, 257)), + "spectrum", + ) + window = np.arange(1, 257) + window = window.astype("float64").reshape((-1, 1, 1)) + window = window / np.sum(window) + values = backend.convert_to_numpy(initializer(shape, "float64")) + self.assertAllClose(np.sin(args) * window, values, **tol_kwargs) + self.run_class_serialization_test(initializer) + + with self.assertRaises(ValueError): + initializers.STFTInitializer("imaginary") + with self.assertRaises(ValueError): + initializers.STFTInitializer("real", scaling="l2") + with self.assertRaises(ValueError): + initializers.STFTInitializer("real", window="unknown") diff --git a/keras/src/layers/__init__.py b/keras/src/layers/__init__.py index 5d39266c910d..7c425cdf8136 100644 --- a/keras/src/layers/__init__.py +++ b/keras/src/layers/__init__.py @@ -119,6 +119,7 @@ from keras.src.layers.preprocessing.normalization import Normalization from keras.src.layers.preprocessing.pipeline import Pipeline from keras.src.layers.preprocessing.rescaling import Rescaling +from keras.src.layers.preprocessing.stft_spectrogram import STFTSpectrogram from keras.src.layers.preprocessing.string_lookup import StringLookup from keras.src.layers.preprocessing.text_vectorization import TextVectorization from keras.src.layers.regularization.activity_regularization import ( diff --git a/keras/src/layers/preprocessing/stft_spectrogram.py b/keras/src/layers/preprocessing/stft_spectrogram.py new file mode 100644 index 000000000000..736eaeb52c78 --- /dev/null +++ b/keras/src/layers/preprocessing/stft_spectrogram.py @@ -0,0 +1,384 @@ +import math +import warnings + +from keras.src import backend +from keras.src import initializers +from keras.src import layers +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.utils.module_utils import scipy + + +@keras_export("keras.layers.STFTSpectrogram") +class STFTSpectrogram(layers.Layer): + """Layer to compute the Short-Time Fourier Transform (STFT) on a 1D signal. + + A layer that computes Spectrograms of the input signal to produce + a spectrogram. This layers utilizes Short-Time Fourier Transform (STFT) by + The layer computes Spectrograms based on STFT by utilizing convolution + kernels, which allows parallelization on GPUs and trainable kernels for + fine-tuning support. This layer allows different modes of output + (e.g., log-scaled magnitude, phase, power spectral density, etc.) and + provides flexibility in windowing, padding, and scaling options for the + STFT calculation. + + Examples: + + Apply it as a non-trainable preprocessing layer on 3 audio tracks of + 1 channel, 10 seconds and sampled at 16 kHz. + + >>> layer = keras.layers.STFTSpectrogram( + ... mode='log', + ... frame_length=256, + ... frame_step=128, # 50% overlap + ... fft_length=512, + ... window="hann", + ... padding="valid", + ... trainable=False, # non-trainable, preprocessing only + ... ) + >>> layer(keras.random.uniform(shape=(3, 160000, 1))).shape + (3, 1249, 257) + + Apply it as a trainable processing layer on 3 stereo audio tracks of + 2 channels, 10 seconds and sampled at 16 kHz. This is initialized as the + non-trainable layer, but then can be trained jointly within a model. + + >>> layer = keras.layers.STFTSpectrogram( + ... mode='log', + ... frame_length=256, + ... frame_step=128, # 50% overlap + ... fft_length=512, + ... window="hamming", # hamming windowing function + ... padding="same", # padding to preserve the time dimension + ... trainable=True, # trainable, this is the default in keras + ... ) + >>> layer(keras.random.uniform(shape=(3, 160000, 2))).shape + (3, 1250, 514) + + Similar to the last example, but add an extra dimension so the output is + an image to be used with image models. We apply this here on a signal of + 3 input channels to output an image tensor, hence is directly applicable + with an image model. + + >>> layer = keras.layers.STFTSpectrogram( + ... mode='log', + ... frame_length=256, + ... frame_step=128, + ... fft_length=512, + ... padding="same", + ... expand_dims=True, # this adds the extra dimension + ... ) + >>> layer(keras.random.uniform(shape=(3, 160000, 3))).shape + (3, 1250, 257, 3) + + Args: + mode: String, the output type of the spectrogram. Can be one of + `"log"`, `"magnitude`", `"psd"`, `"real`", `"imag`", `"angle`", + `"stft`". Defaults to `"log`". + frame_length: Integer, The length of each frame (window) for STFT in + samples. Defaults to 256. + frame_step: Integer, the step size (hop length) between + consecutive frames. If not provided, defaults to half the + frame_length. Defaults to `frame_length // 2`. + fft_length: Integer, the size of frequency bins used in the Fast-Fourier + Transform (FFT) to apply to each frame. Should be greater than or + equal to `frame_length`. Recommended to be a power of two. Defaults + to the smallest power of two that is greater than or equal + to `frame_length`. + window: (String or array_like), the windowing function to apply to each + frame. Can be `"hann`" (default), `"hamming`", or a custom window + provided as an array_like. + periodic: Boolean, if True, the window function will be treated as + periodic. Defaults to `False`. + scaling: String, type of scaling applied to the window. Can be + `"density`", `"spectrum`", or None. Default is `"density`". + padding: String, padding strategy. Can be `"valid`" or `"same`". + Defaults to `"valid"`. + expand_dims: Boolean, if True, will expand the output into spectrograms + into two dimensions to be compatible with image models. + Defaults to `False`. + data_format: String, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape `(batch, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch, channels, height, weight)`. Defaults to `"channels_last"`. + + Raises: + ValueError: If an invalid value is provided for `"mode`", `"scaling`", + `"padding`", or other input arguments. + TypeError: If the input data type is not one of `"float16`", + `"float32`", or `"float64`". + + Input shape: + A 3D tensor of shape `(batch_size, time_length, input_channels)`, if + `data_format=="channels_last"`, and of shape + `(batch_size, input_channels, time_length)` if + `data_format=="channels_first"`, where `time_length` is the length of + the input signal, and `input_channels` is the number of input channels. + The same kernels are applied to each channel independetly. + + Output shape: + If `data_format=="channels_first" and not expand_dims`, a 3D tensor: + `(batch_size, input_channels * freq_channels, new_time_length)` + If `data_format=="channels_last" and not expand_dims`, a 3D tensor: + `(batch_size, new_time_length, input_channels * freq_channels)` + If `data_format=="channels_first" and expand_dims`, a 4D tensor: + `(batch_size, input_channels, new_time_length, freq_channels)` + If `data_format=="channels_last" and expand_dims`, a 4D tensor: + `(batch_size, new_time_length, freq_channels, input_channels)` + + where `new_time_length` depends on the padding, and `freq_channels` is + the number of FFT bins `(fft_length // 2 + 1)`. + """ + + def __init__( + self, + mode="log", + frame_length=256, + frame_step=None, + fft_length=None, + window="hann", + periodic=False, + scaling="density", + padding="valid", + expand_dims=False, + data_format=None, + **kwargs, + ): + if frame_step is not None and ( + frame_step > frame_length or frame_step < 1 + ): + raise ValueError( + "`frame_step` should be a positive integer not greater than " + f"`frame_length`. Recieved frame_step={frame_step}, " + f"frame_length={frame_length}" + ) + + if fft_length is not None and fft_length < frame_length: + raise ValueError( + "`fft_length` should be not less than `frame_length`. " + f"Recieved fft_length={fft_length}, frame_length={frame_length}" + ) + + if fft_length is not None and (fft_length & -fft_length) != fft_length: + warnings.warn( + "`fft_length` is recommended to be a power of two. " + f"Received fft_length={fft_length}" + ) + + all_modes = ["log", "magnitude", "psd", "real", "imag", "angle", "stft"] + + if mode not in all_modes: + raise ValueError( + "Output mode is invalid, it must be one of " + f"{', '.join(all_modes)}. Received: mode={mode}" + ) + + if scaling is not None and scaling not in ["density", "spectrum"]: + raise ValueError( + "Scaling is invalid, it must be `None`, 'density' " + f"or 'spectrum'. Received scaling={scaling}" + ) + + if padding not in ["valid", "same"]: + raise ValueError( + "Padding is invalid, it should be 'valid', 'same'. " + f"Received: padding={padding}" + ) + + if isinstance(window, str): + # throws an exception for invalid window function + scipy.signal.get_window(window, 1) + + super().__init__(**kwargs) + + self.mode = mode + + self.frame_length = frame_length + self.frame_step = frame_step + self._frame_step = frame_step or self.frame_length // 2 + self.fft_length = fft_length + self._fft_length = fft_length or ( + 2 ** int(math.ceil(math.log2(frame_length))) + ) + + self.window = window + self.periodic = periodic + self.scaling = scaling + self.padding = padding + self.expand_dims = expand_dims + self.data_format = backend.standardize_data_format(data_format) + self.input_spec = layers.input_spec.InputSpec(ndim=3) + + def build(self, input_shape): + shape = (self.frame_length, 1, self._fft_length // 2 + 1) + + if self.mode != "imag": + self.real_kernel = self.add_weight( + name="real_kernel", + shape=shape, + initializer=initializers.STFTInitializer( + "real", self.window, self.scaling, self.periodic + ), + ) + if self.mode != "real": + self.imag_kernel = self.add_weight( + name="imag_kernel", + shape=shape, + initializer=initializers.STFTInitializer( + "imag", self.window, self.scaling, self.periodic + ), + ) + self.built = True + + def _adjust_shapes(self, outputs): + _, channels, freq_channels, time_seq = outputs.shape + batch_size = -1 + if self.data_format == "channels_last": + if self.expand_dims: + outputs = ops.transpose(outputs, [0, 3, 2, 1]) + # [batch_size, time_seq, freq_channels, input_channels] + else: + outputs = ops.reshape( + outputs, + [batch_size, channels * freq_channels, time_seq], + ) + # [batch_size, input_channels * freq_channels, time_seq] + outputs = ops.transpose(outputs, [0, 2, 1]) + else: + if self.expand_dims: + outputs = ops.transpose(outputs, [0, 1, 3, 2]) + # [batch_size, channels, time_seq, freq_channels] + else: + outputs = ops.reshape( + outputs, + [batch_size, channels * freq_channels, time_seq], + ) + return outputs + + def _apply_conv(self, inputs, kernel): + if self.data_format == "channels_last": + _, time_seq, channels = inputs.shape + inputs = ops.transpose(inputs, [0, 2, 1]) + inputs = ops.reshape(inputs, [-1, time_seq, 1]) + else: + _, channels, time_seq = inputs.shape + inputs = ops.reshape(inputs, [-1, 1, time_seq]) + + outputs = ops.conv( + inputs, + ops.cast(kernel, backend.standardize_dtype(inputs.dtype)), + padding=self.padding, + strides=self._frame_step, + data_format=self.data_format, + ) + batch_size = -1 + if self.data_format == "channels_last": + _, time_seq, freq_channels = outputs.shape + outputs = ops.transpose(outputs, [0, 2, 1]) + outputs = ops.reshape( + outputs, + [batch_size, channels, freq_channels, time_seq], + ) + else: + _, freq_channels, time_seq = outputs.shape + outputs = ops.reshape( + outputs, + [batch_size, channels, freq_channels, time_seq], + ) + return outputs + + def call(self, inputs): + dtype = inputs.dtype + if backend.standardize_dtype(dtype) not in { + "float16", + "float32", + "float64", + }: + raise TypeError( + "Invalid input type. Expected `float16`, `float32` or " + f"`float64`. Received: input type={dtype}" + ) + + real_signal = None + imag_signal = None + power = None + + if self.mode != "imag": + real_signal = self._apply_conv(inputs, self.real_kernel) + if self.mode != "real": + imag_signal = self._apply_conv(inputs, self.imag_kernel) + + if self.mode == "real": + return self._adjust_shapes(real_signal) + elif self.mode == "imag": + return self._adjust_shapes(imag_signal) + elif self.mode == "angle": + return self._adjust_shapes(ops.arctan2(imag_signal, real_signal)) + elif self.mode == "stft": + return self._adjust_shapes( + ops.concatenate([real_signal, imag_signal], axis=2) + ) + else: + power = ops.square(real_signal) + ops.square(imag_signal) + + if self.mode == "psd": + return self._adjust_shapes( + power + + ops.pad( + power[:, :, 1:-1, :], [[0, 0], [0, 0], [1, 1], [0, 0]] + ) + ) + linear_stft = self._adjust_shapes( + ops.sqrt(ops.maximum(power, backend.epsilon())) + ) + + if self.mode == "magnitude": + return linear_stft + else: + return ops.log(ops.maximum(linear_stft, backend.epsilon())) + + def compute_output_shape(self, input_shape): + if self.data_format == "channels_last": + channels = input_shape[-1] + else: + channels = input_shape[1] + freq_channels = self._fft_length // 2 + 1 + if self.mode == "stft": + freq_channels *= 2 + shape = ops.operation_utils.compute_conv_output_shape( + input_shape, + freq_channels * channels, + (self.frame_length,), + strides=self._frame_step, + padding=self.padding, + data_format=self.data_format, + ) + if self.data_format == "channels_last": + batch_size, time_seq, _ = shape + else: + batch_size, _, time_seq = shape + if self.expand_dims: + if self.data_format == "channels_last": + return (batch_size, time_seq, freq_channels, channels) + else: + return (batch_size, channels, time_seq, freq_channels) + return shape + + def get_config(self): + config = super().get_config() + config.update( + { + "mode": self.mode, + "frame_length": self.frame_length, + "frame_step": self.frame_step, + "fft_length": self.fft_length, + "window": self.window, + "periodic": self.periodic, + "scaling": self.scaling, + "padding": self.padding, + "data_format": self.data_format, + "expand_dims": self.expand_dims, + } + ) + return config diff --git a/keras/src/layers/preprocessing/stft_spectrogram_test.py b/keras/src/layers/preprocessing/stft_spectrogram_test.py new file mode 100644 index 000000000000..9178b191d46d --- /dev/null +++ b/keras/src/layers/preprocessing/stft_spectrogram_test.py @@ -0,0 +1,377 @@ +import numpy as np +import pytest +import scipy.signal +import tensorflow as tf + +from keras import Input +from keras import Sequential +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class TestSpectrogram(testing.TestCase): + + DTYPE = "float32" if backend.backend() == "torch" else "float64" + + @staticmethod + def _calc_spectrograms( + x, mode, scaling, window, periodic, frame_length, frame_step, fft_length + ): + data_format = backend.image_data_format() + input_shape = (None, 1) if data_format == "channels_last" else (1, None) + + layer = Sequential( + [ + Input(shape=input_shape, dtype=TestSpectrogram.DTYPE), + layers.STFTSpectrogram( + mode=mode, + frame_length=frame_length, + frame_step=frame_step, + fft_length=fft_length, + window=window, + scaling=scaling, + periodic=periodic, + dtype=TestSpectrogram.DTYPE, + ), + ] + ) + if data_format == "channels_first": + y = layer.predict(np.transpose(x, [0, 2, 1]), verbose=0) + y = np.transpose(y, [0, 2, 1]) + else: + y = layer.predict(x, verbose=0) + + window_arr = scipy.signal.get_window(window, frame_length, periodic) + _, _, spec = scipy.signal.spectrogram( + x[..., 0].astype(TestSpectrogram.DTYPE), + window=window_arr.astype(TestSpectrogram.DTYPE), + nperseg=frame_length, + noverlap=frame_length - frame_step, + mode=mode, + scaling=scaling, + detrend=False, + nfft=fft_length, + ) + y_true = np.transpose(spec, [0, 2, 1]) + return y_true, y + + @pytest.mark.requires_trainable_backend + def test_spectrogram_channels_broadcasting(self): + rnd = np.random.RandomState(41) + audio = rnd.uniform(-1, 1, size=(3, 16000, 7)) + + layer_last = Sequential( + [ + Input(shape=(None, 7), dtype=self.DTYPE), + layers.STFTSpectrogram( + mode="psd", dtype=self.DTYPE, data_format="channels_last" + ), + ] + ) + layer_single = Sequential( + [ + Input(shape=(None, 1), dtype=self.DTYPE), + layers.STFTSpectrogram( + mode="psd", dtype=self.DTYPE, data_format="channels_last" + ), + ] + ) + + layer_expand = Sequential( + [ + Input(shape=(None, 7), dtype=self.DTYPE), + layers.STFTSpectrogram( + mode="psd", + dtype=self.DTYPE, + data_format="channels_last", + expand_dims=True, + ), + ] + ) + + y_last = layer_last.predict(audio, verbose=0) + y_expanded = layer_expand.predict(audio, verbose=0) + y_singles = [ + layer_single.predict(audio[..., i : i + 1], verbose=0) + for i in range(audio.shape[-1]) + ] + + self.assertAllClose(y_last, np.concatenate(y_singles, axis=-1)) + self.assertAllClose(y_expanded, np.stack(y_singles, axis=-1)) + + @pytest.mark.skipif( + backend.backend() == "tensorflow", + reason="TF doesn't support channels_first", + ) + @pytest.mark.requires_trainable_backend + def test_spectrogram_channels_first(self): + + rnd = np.random.RandomState(41) + audio = rnd.uniform(-1, 1, size=(3, 16000, 7)) + + layer_first = Sequential( + [ + Input(shape=(7, None), dtype=self.DTYPE), + layers.STFTSpectrogram( + mode="psd", dtype=self.DTYPE, data_format="channels_first" + ), + ] + ) + layer_last = Sequential( + [ + Input(shape=(None, 7), dtype=self.DTYPE), + layers.STFTSpectrogram( + mode="psd", dtype=self.DTYPE, data_format="channels_last" + ), + ] + ) + layer_single = Sequential( + [ + Input(shape=(None, 1), dtype=self.DTYPE), + layers.STFTSpectrogram( + mode="psd", dtype=self.DTYPE, data_format="channels_last" + ), + ] + ) + layer_expand = Sequential( + [ + Input(shape=(7, None), dtype=self.DTYPE), + layers.STFTSpectrogram( + mode="psd", + dtype=self.DTYPE, + data_format="channels_first", + expand_dims=True, + ), + ] + ) + + y_singles = [ + layer_single.predict(audio[..., i : i + 1], verbose=0) + for i in range(audio.shape[-1]) + ] + y_expanded = layer_expand.predict( + np.transpose(audio, [0, 2, 1]), verbose=0 + ) + y_last = layer_last.predict(audio, verbose=0) + y_first = layer_first.predict(np.transpose(audio, [0, 2, 1]), verbose=0) + self.assertAllClose(np.transpose(y_first, [0, 2, 1]), y_last) + self.assertAllClose(y_expanded, np.stack(y_singles, axis=1)) + self.assertAllClose( + y_first, + np.transpose(np.concatenate(y_singles, axis=-1), [0, 2, 1]), + ) + self.run_layer_test( + layers.STFTSpectrogram, + init_kwargs={ + "frame_length": 150, + "frame_step": 10, + "fft_length": 512, + "trainable": False, + "padding": "same", + "expand_dims": True, + "data_format": "channels_first", + }, + input_shape=(2, 3, 160000), + expected_output_shape=(2, 3, 160000 // 10, 257), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=2, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + ) + + @pytest.mark.requires_trainable_backend + def test_spectrogram_basics(self): + self.run_layer_test( + layers.STFTSpectrogram, + init_kwargs={ + "frame_length": 500, + "frame_step": 25, + "fft_length": 1024, + "mode": "stft", + "data_format": "channels_last", + }, + input_shape=(2, 16000, 1), + expected_output_shape=(2, 15500 // 25 + 1, 513 * 2), + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + ) + + self.run_layer_test( + layers.STFTSpectrogram, + init_kwargs={ + "frame_length": 150, + "frame_step": 71, + "fft_length": 4096, + "mode": "real", + "data_format": "channels_last", + }, + input_shape=(2, 160000, 1), + expected_output_shape=(2, 159850 // 71 + 1, 2049), + expected_num_trainable_weights=1, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + ) + + self.run_layer_test( + layers.STFTSpectrogram, + init_kwargs={ + "frame_length": 150, + "frame_step": 43, + "fft_length": 512, + "mode": "imag", + "padding": "same", + "data_format": "channels_last", + }, + input_shape=(2, 160000, 1), + expected_output_shape=(2, 160000 // 43 + 1, 257), + expected_num_trainable_weights=1, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + ) + self.run_layer_test( + layers.STFTSpectrogram, + init_kwargs={ + "frame_length": 150, + "frame_step": 10, + "fft_length": 512, + "trainable": False, + "padding": "same", + "data_format": "channels_last", + }, + input_shape=(2, 160000, 3), + expected_output_shape=(2, 160000 // 10, 257 * 3), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=2, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + ) + self.run_layer_test( + layers.STFTSpectrogram, + init_kwargs={ + "frame_length": 150, + "frame_step": 10, + "fft_length": 512, + "trainable": False, + "padding": "same", + "expand_dims": True, + "data_format": "channels_last", + }, + input_shape=(2, 160000, 3), + expected_output_shape=(2, 160000 // 10, 257, 3), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=2, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + ) + + @pytest.mark.requires_trainable_backend + def test_spectrogram_error(self): + rnd = np.random.RandomState(41) + x = rnd.uniform(low=-1, high=1, size=(4, 160000, 1)).astype(self.DTYPE) + names = [ + "scaling", + "window", + "periodic", + "frame_length", + "frame_step", + "fft_length", + ] + for args in [ + ("density", "hann", False, 512, 256, 1024), + ("spectrum", "blackman", True, 512, 32, 1024), + ("spectrum", "hamming", True, 256, 192, 512), + ("spectrum", "tukey", False, 512, 128, 512), + ("density", "hamming", True, 256, 256, 256), + ("density", "hann", True, 256, 128, 256), + ]: + init_args = dict(zip(names, args)) + + tol_kwargs = {"atol": 5e-4, "rtol": 1e-6} + + init_args["mode"] = "magnitude" + y_true, y = self._calc_spectrograms(x, **init_args) + self.assertEqual(np.shape(y_true), np.shape(y)) + self.assertAllClose(y_true, y, **tol_kwargs) + + init_args["mode"] = "psd" + y_true, y = self._calc_spectrograms(x, **init_args) + self.assertEqual(np.shape(y_true), np.shape(y)) + self.assertAllClose(y_true, y, **tol_kwargs) + + init_args["mode"] = "angle" + y_true, y = self._calc_spectrograms(x, **init_args) + + pi = np.arccos(np.float128(-1)).astype(y_true.dtype) + mask = np.isclose(y, y_true, **tol_kwargs) + mask |= np.isclose(y + 2 * pi, y_true, **tol_kwargs) + mask |= np.isclose(y - 2 * pi, y_true, **tol_kwargs) + mask |= np.isclose(np.cos(y), np.cos(y_true), **tol_kwargs) + mask |= np.isclose(np.sin(y), np.sin(y_true), **tol_kwargs) + + if backend.backend() == "tensorflow": + self.assertTrue(np.all(mask)) + else: + # TODO(mostafa-mahmoud): investigate the rare cases + # of non-small error in jax and torch + self.assertLess(np.mean(~mask), 2e-4) + + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="Requires TF tensors for TF-data module.", + ) + def test_tf_data_compatibility(self): + input_shape = (2, 16000, 1) + output_shape = (2, 16000 // 128, 358) + layer = layers.STFTSpectrogram( + frame_length=256, + frame_step=128, + fft_length=715, + padding="same", + scaling=None, + ) + input_data = np.random.random(input_shape) + ds = tf.data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + for output in ds.take(1): + output = output.numpy() + self.assertEqual(tuple(output.shape), output_shape) + + def test_exceptions(self): + with self.assertRaises(ValueError): + layers.STFTSpectrogram( + frame_length=256, frame_step=1024, fft_length=512 + ) + with self.assertRaises(ValueError): + layers.STFTSpectrogram( + frame_length=256, frame_step=0, fft_length=512 + ) + with self.assertRaises(ValueError): + layers.STFTSpectrogram( + frame_length=256, frame_step=32, fft_length=128 + ) + with self.assertRaises(ValueError): + layers.STFTSpectrogram(padding="mypadding") + with self.assertRaises(ValueError): + layers.STFTSpectrogram(scaling="l2") + with self.assertRaises(ValueError): + layers.STFTSpectrogram(mode="spectrogram") + with self.assertRaises(ValueError): + layers.STFTSpectrogram(window="unknowable") + with self.assertRaises(ValueError): + layers.STFTSpectrogram(scaling="l2") + with self.assertRaises(ValueError): + layers.STFTSpectrogram(padding="divide") + with self.assertRaises(TypeError): + layers.STFTSpectrogram()( + np.random.randint(0, 255, size=(2, 16000, 1)) + )