Skip to content

Commit 7a08ba5

Browse files
authored
Merge pull request #1: Added a new Spectrogram layer based on Conv1D operations, supporting GPU-parallelization and fine-tuning
Merged from original PR #20313 Original: keras-team/keras#20313
2 parents 6070e4c + 2669d99 commit 7a08ba5

File tree

10 files changed

+949
-0
lines changed

10 files changed

+949
-0
lines changed

keras/api/_tf_keras/keras/initializers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from keras.src.initializers.constant_initializers import Identity as identity
1717
from keras.src.initializers.constant_initializers import Ones
1818
from keras.src.initializers.constant_initializers import Ones as ones
19+
from keras.src.initializers.constant_initializers import STFTInitializer
1920
from keras.src.initializers.constant_initializers import Zeros
2021
from keras.src.initializers.constant_initializers import Zeros as zeros
2122
from keras.src.initializers.initializer import Initializer

keras/api/_tf_keras/keras/layers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@
175175
from keras.src.layers.preprocessing.normalization import Normalization
176176
from keras.src.layers.preprocessing.pipeline import Pipeline
177177
from keras.src.layers.preprocessing.rescaling import Rescaling
178+
from keras.src.layers.preprocessing.stft_spectrogram import STFTSpectrogram
178179
from keras.src.layers.preprocessing.string_lookup import StringLookup
179180
from keras.src.layers.preprocessing.text_vectorization import TextVectorization
180181
from keras.src.layers.regularization.activity_regularization import (

keras/api/initializers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from keras.src.initializers.constant_initializers import Identity as identity
1717
from keras.src.initializers.constant_initializers import Ones
1818
from keras.src.initializers.constant_initializers import Ones as ones
19+
from keras.src.initializers.constant_initializers import STFTInitializer
1920
from keras.src.initializers.constant_initializers import Zeros
2021
from keras.src.initializers.constant_initializers import Zeros as zeros
2122
from keras.src.initializers.initializer import Initializer

keras/api/layers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@
175175
from keras.src.layers.preprocessing.normalization import Normalization
176176
from keras.src.layers.preprocessing.pipeline import Pipeline
177177
from keras.src.layers.preprocessing.rescaling import Rescaling
178+
from keras.src.layers.preprocessing.stft_spectrogram import STFTSpectrogram
178179
from keras.src.layers.preprocessing.string_lookup import StringLookup
179180
from keras.src.layers.preprocessing.text_vectorization import TextVectorization
180181
from keras.src.layers.regularization.activity_regularization import (

keras/src/initializers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from keras.src.initializers.constant_initializers import Constant
55
from keras.src.initializers.constant_initializers import Identity
66
from keras.src.initializers.constant_initializers import Ones
7+
from keras.src.initializers.constant_initializers import STFTInitializer
78
from keras.src.initializers.constant_initializers import Zeros
89
from keras.src.initializers.initializer import Initializer
910
from keras.src.initializers.random_initializers import GlorotNormal
@@ -25,6 +26,7 @@
2526
Constant,
2627
Identity,
2728
Ones,
29+
STFTInitializer,
2830
Zeros,
2931
GlorotNormal,
3032
GlorotUniform,

keras/src/initializers/constant_initializers.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from keras.src.backend import standardize_dtype
44
from keras.src.initializers.initializer import Initializer
55
from keras.src.saving import serialization_lib
6+
from keras.src.utils.module_utils import scipy
67

78

89
@keras_export(["keras.initializers.Constant", "keras.initializers.constant"])
@@ -151,3 +152,120 @@ def __call__(self, shape, dtype=None):
151152
)
152153
dtype = standardize_dtype(dtype)
153154
return self.gain * ops.eye(*shape, dtype=dtype)
155+
156+
157+
@keras_export(["keras.initializers.STFTInitializer"])
class STFTInitializer(Initializer):
    """Initializer of Conv kernels for Short-term Fourier Transformation (STFT).

    Since the formula involves complex numbers, this class computes either the
    real or the imaginary components of the final output.

    Additionally, this initializer supports windowing functions across the time
    dimension as commonly used in STFT. Windowing functions from the module
    `scipy.signal.windows` are supported, including the common `hann` and
    `hamming` windowing functions. This initializer supports periodic windows
    and scaling-based normalization.

    This is primarily intended for use in the `STFTSpectrogram` layer.

    Examples:

    >>> # Standalone usage:
    >>> initializer = STFTInitializer("real", "hann", "density", False)
    >>> values = initializer(shape=(128, 1, 513))

    Args:
        side: String, `"real"` or `"imag"` deciding if the kernel will compute
            the real side or the imaginary side of the output.
        window: String for the name of the windowing function in the
            `scipy.signal.windows` module, or array_like for the window values,
            or `None` for no windowing.
        scaling: String, `"density"` or `"spectrum"` for scaling of the window
            for normalization, either L2 or L1 normalization.
            `None` for no scaling. The scaling is only applied when a
            `window` is given.
        periodic: Boolean, if True, the window function will be treated as
            periodic. Defaults to `False`.

    Raises:
        ValueError: If `side` is not `"real"` or `"imag"`, or if `scaling` is
            not `None`, `"density"` or `"spectrum"`.
    """

    def __init__(self, side, window="hann", scaling="density", periodic=False):
        if side not in ["real", "imag"]:
            raise ValueError(f"side should be 'real' or 'imag', not {side}")
        if isinstance(window, str):
            # Fail early: building a length-1 window makes SciPy reject an
            # invalid window name at construction time.
            scipy.signal.get_window(window, 1)
        if scaling is not None and scaling not in ["density", "spectrum"]:
            raise ValueError(
                "Scaling is invalid, it must be `None`, 'density' "
                f"or 'spectrum'. Received scaling={scaling}"
            )
        self.side = side
        self.window = window
        self.scaling = scaling
        self.periodic = periodic

    def __call__(self, shape, dtype=None):
        """Returns a tensor object initialized as specified by the initializer.

        The shape is assumed to be `(T, 1, F // 2 + 1)`, where `T` is the size
        of the given window, and `F` is the number of frequency bands. Only half
        the frequency bands are used, which is a common practice in STFT,
        because the second half are the conjugates of the first half in
        a reversed order.

        Args:
            shape: Shape of the tensor.
            dtype: Optional dtype of the tensor. Only numeric or boolean dtypes
                are supported. If not specified, `keras.backend.floatx()`
                is used, which defaults to `float32` unless you configured it
                otherwise (via `keras.backend.set_floatx(float_dtype)`).

        Returns:
            A tensor of shape `(T, 1, F // 2 + 1)` holding the cosine
            (`side="real"`) or sine (`side="imag"`) STFT kernel, multiplied
            by the scaled window when a window is configured.

        Raises:
            ValueError: If the window is not 1D of length `T`.
        """
        dtype = standardize_dtype(dtype)
        # The middle (input channel) dimension is not used by the computation;
        # the returned kernel always has size 1 along that axis.
        frame_length, _, num_freq_bins = shape

        win = None
        scaling = 1
        if self.window is not None:
            win = self.window
            if isinstance(win, str):
                # Using SciPy since it provides more windowing functions,
                # easier to be compatible with multiple backends.
                win = scipy.signal.get_window(win, frame_length, self.periodic)
            win = ops.convert_to_tensor(win, dtype=dtype)
            if len(win.shape) != 1 or win.shape[-1] != frame_length:
                raise ValueError(
                    "The shape of `window` must be equal to [frame_length]. "
                    f"Received: window shape={win.shape}"
                )
            win = ops.reshape(win, [frame_length, 1, 1])
            if self.scaling == "density":
                # L2 norm of the window.
                scaling = ops.sqrt(ops.sum(ops.square(win)))
            elif self.scaling == "spectrum":
                # L1 norm of the window.
                scaling = ops.sum(ops.abs(win))

        # Recover the full FFT length F from the number of kept bins F // 2 + 1.
        full_fft_length = (num_freq_bins - 1) * 2
        freq = (
            ops.reshape(
                ops.arange(num_freq_bins, dtype=dtype), (1, 1, num_freq_bins)
            )
            / full_fft_length
        )
        time = ops.reshape(
            ops.arange(frame_length, dtype=dtype), (frame_length, 1, 1)
        )
        # arccos(-1) == pi, computed in the target dtype;
        # args[t, 0, k] = -2 * pi * t * k / F.
        args = -2 * time * freq * ops.arccos(ops.cast(-1, dtype))

        if self.side == "real":
            kernel = ops.cast(ops.cos(args), dtype)
        else:
            kernel = ops.cast(ops.sin(args), dtype)

        if win is not None:
            kernel = kernel * win / scaling
        return kernel

    def get_config(self):
        """Returns the constructor arguments as a config dictionary."""
        return {
            "side": self.side,
            "window": self.window,
            "periodic": self.periodic,
            "scaling": self.scaling,
        }

keras/src/initializers/constant_initializers_test.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import scipy.signal
23

34
from keras.src import backend
45
from keras.src import initializers
@@ -67,3 +68,65 @@ def test_identity_initializer(self):
6768
self.assertAllClose(np_values, np.eye(*shape) * gain)
6869

6970
self.run_class_serialization_test(initializer)
71+
72+
def test_stft_initializer(self):
    """Checks `STFTInitializer` kernels against a NumPy/SciPy reference.

    Covers an un-windowed real kernel, a periodic Hamming window without
    scaling, a Tukey window with density scaling, a custom list window
    with spectrum scaling, and invalid constructor arguments.
    """
    frame_length, num_bins = 256, 513
    kernel_shape = (frame_length, 1, num_bins)

    # float64 reference phase grid: -2 * pi * t * (k / 1024).
    pi = np.arccos(np.float64(-1))
    t_axis = np.arange(frame_length).reshape((-1, 1, 1))
    f_axis = (np.arange(num_bins) / 1024.0).reshape((1, 1, -1))
    phase = -2 * pi * t_axis * f_axis

    # TODO(mostafa-mahmoud): investigate the cases
    # of non-small error in jax and torch
    tolerances = (
        {"atol": 1e-4, "rtol": 1e-6} if backend.backend() == "jax" else {}
    )

    # Real side, no window, default dtype.
    init = initializers.STFTInitializer("real", None)
    got = backend.convert_to_numpy(init(kernel_shape))
    self.assertAllClose(np.cos(phase), got, atol=1e-4)
    self.run_class_serialization_test(init)

    # Real side, periodic Hamming window, no scaling.
    init = initializers.STFTInitializer("real", "hamming", None, True)
    hamming = (
        scipy.signal.windows.get_window("hamming", frame_length, True)
        .astype("float64")
        .reshape((-1, 1, 1))
    )
    got = backend.convert_to_numpy(init(kernel_shape, "float64"))
    self.assertAllClose(np.cos(phase) * hamming, got, **tolerances)
    self.run_class_serialization_test(init)

    # Imaginary side, symmetric Tukey window, density (L2) scaling.
    init = initializers.STFTInitializer("imag", "tukey", "density", False)
    tukey = (
        scipy.signal.windows.get_window("tukey", frame_length, False)
        .astype("float64")
        .reshape((-1, 1, 1))
    )
    tukey = tukey / np.sqrt(np.sum(tukey**2))
    got = backend.convert_to_numpy(init(kernel_shape, "float64"))
    self.assertAllClose(np.sin(phase) * tukey, got, **tolerances)
    self.run_class_serialization_test(init)

    # Imaginary side, explicit window values, spectrum (L1) scaling.
    init = initializers.STFTInitializer(
        "imag",
        list(range(1, frame_length + 1)),
        "spectrum",
    )
    ramp = (
        np.arange(1, frame_length + 1)
        .astype("float64")
        .reshape((-1, 1, 1))
    )
    ramp = ramp / np.sum(ramp)
    got = backend.convert_to_numpy(init(kernel_shape, "float64"))
    self.assertAllClose(np.sin(phase) * ramp, got, **tolerances)
    self.run_class_serialization_test(init)

    # Invalid side, scaling and window name must each be rejected.
    for bad_args, bad_kwargs in (
        (("imaginary",), {}),
        (("real",), {"scaling": "l2"}),
        (("real",), {"window": "unknown"}),
    ):
        with self.assertRaises(ValueError):
            initializers.STFTInitializer(*bad_args, **bad_kwargs)

keras/src/layers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@
119119
from keras.src.layers.preprocessing.normalization import Normalization
120120
from keras.src.layers.preprocessing.pipeline import Pipeline
121121
from keras.src.layers.preprocessing.rescaling import Rescaling
122+
from keras.src.layers.preprocessing.stft_spectrogram import STFTSpectrogram
122123
from keras.src.layers.preprocessing.string_lookup import StringLookup
123124
from keras.src.layers.preprocessing.text_vectorization import TextVectorization
124125
from keras.src.layers.regularization.activity_regularization import (

0 commit comments

Comments
 (0)