|
3 | 3 | from keras.src.backend import standardize_dtype |
4 | 4 | from keras.src.initializers.initializer import Initializer |
5 | 5 | from keras.src.saving import serialization_lib |
| 6 | +from keras.src.utils.module_utils import scipy |
6 | 7 |
|
7 | 8 |
|
8 | 9 | @keras_export(["keras.initializers.Constant", "keras.initializers.constant"]) |
@@ -151,3 +152,120 @@ def __call__(self, shape, dtype=None): |
151 | 152 | ) |
152 | 153 | dtype = standardize_dtype(dtype) |
153 | 154 | return self.gain * ops.eye(*shape, dtype=dtype) |
| 155 | + |
| 156 | + |
| 157 | +@keras_export(["keras.initializers.STFTInitializer"]) |
| 158 | +class STFTInitializer(Initializer): |
| 159 | + """Initializer of Conv kernels for Short-term Fourier Transformation (STFT). |
| 160 | +
|
| 161 | + Since the formula involves complex numbers, this class compute either the |
| 162 | + real or the imaginary components of the final output. |
| 163 | +
|
| 164 | + Additionally, this initializer supports windowing functions across the time |
| 165 | + dimension as commonly used in STFT. Windowing functions from the module |
| 166 | + `scipy.signal.windows` are supported, including the common `hann` and |
| 167 | + `hamming` windowing functions. This layer supports periodic windows and |
| 168 | + scaling-based normalization. |
| 169 | +
|
| 170 | + This is primarly intended for use in the `STFTSpectrogram` layer. |
| 171 | +
|
| 172 | + Examples: |
| 173 | +
|
| 174 | + >>> # Standalone usage: |
| 175 | + >>> initializer = STFTInitializer("real", "hann", "density", False) |
| 176 | + >>> values = initializer(shape=(128, 1, 513)) |
| 177 | +
|
| 178 | + Args: |
| 179 | + side: String, `"real"` or `"imag"` deciding if the kernel will compute |
| 180 | + the real side or the imaginary side of the output. |
| 181 | + window: String for the name of the windowing function in the |
| 182 | + `scipy.signal.windows` module, or array_like for the window values, |
| 183 | + or `None` for no windowing. |
| 184 | + scaling: String, `"density"` or `"spectrum"` for scaling of the window |
| 185 | + for normalization, either L2 or L1 normalization. |
| 186 | + `None` for no scaling. |
| 187 | + periodic: Boolean, if True, the window function will be treated as |
| 188 | + periodic. Defaults to `False`. |
| 189 | + """ |
| 190 | + |
| 191 | + def __init__(self, side, window="hann", scaling="density", periodic=False): |
| 192 | + if side not in ["real", "imag"]: |
| 193 | + raise ValueError(f"side should be 'real' or 'imag', not {side}") |
| 194 | + if isinstance(window, str): |
| 195 | + # throws an exception for invalid window function |
| 196 | + scipy.signal.get_window(window, 1) |
| 197 | + if scaling is not None and scaling not in ["density", "spectrum"]: |
| 198 | + raise ValueError( |
| 199 | + "Scaling is invalid, it must be `None`, 'density' " |
| 200 | + f"or 'spectrum'. Received scaling={scaling}" |
| 201 | + ) |
| 202 | + self.side = side |
| 203 | + self.window = window |
| 204 | + self.scaling = scaling |
| 205 | + self.periodic = periodic |
| 206 | + |
| 207 | + def __call__(self, shape, dtype=None): |
| 208 | + """Returns a tensor object initialized as specified by the initializer. |
| 209 | +
|
| 210 | + The shape is assumed to be `(T, 1, F // 2 + 1)`, where `T` is the size |
| 211 | + of the given window, and `F` is the number of frequency bands. Only half |
| 212 | + the frequency bands are used, which is a common practice in STFT, |
| 213 | + because the second half are the conjugates of the first half in |
| 214 | + a reversed order. |
| 215 | +
|
| 216 | + Args: |
| 217 | + shape: Shape of the tensor. |
| 218 | + dtype: Optional dtype of the tensor. Only numeric or boolean dtypes |
| 219 | + are supported. If not specified, `keras.backend.floatx()` |
| 220 | + is used, which default to `float32` unless you configured it |
| 221 | + otherwise (via `keras.backend.set_floatx(float_dtype)`). |
| 222 | + """ |
| 223 | + dtype = standardize_dtype(dtype) |
| 224 | + frame_length, input_channels, fft_length = shape |
| 225 | + |
| 226 | + win = None |
| 227 | + scaling = 1 |
| 228 | + if self.window is not None: |
| 229 | + win = self.window |
| 230 | + if isinstance(win, str): |
| 231 | + # Using SciPy since it provides more windowing functions, |
| 232 | + # easier to be compatible with multiple backends. |
| 233 | + win = scipy.signal.get_window(win, frame_length, self.periodic) |
| 234 | + win = ops.convert_to_tensor(win, dtype=dtype) |
| 235 | + if len(win.shape) != 1 or win.shape[-1] != frame_length: |
| 236 | + raise ValueError( |
| 237 | + "The shape of `window` must be equal to [frame_length]." |
| 238 | + f"Received: window shape={win.shape}" |
| 239 | + ) |
| 240 | + win = ops.reshape(win, [frame_length, 1, 1]) |
| 241 | + if self.scaling == "density": |
| 242 | + scaling = ops.sqrt(ops.sum(ops.square(win))) |
| 243 | + elif self.scaling == "spectrum": |
| 244 | + scaling = ops.sum(ops.abs(win)) |
| 245 | + |
| 246 | + _fft_length = (fft_length - 1) * 2 |
| 247 | + freq = ( |
| 248 | + ops.reshape(ops.arange(fft_length, dtype=dtype), (1, 1, fft_length)) |
| 249 | + / _fft_length |
| 250 | + ) |
| 251 | + time = ops.reshape( |
| 252 | + ops.arange(frame_length, dtype=dtype), (frame_length, 1, 1) |
| 253 | + ) |
| 254 | + args = -2 * time * freq * ops.arccos(ops.cast(-1, dtype)) |
| 255 | + |
| 256 | + if self.side == "real": |
| 257 | + kernel = ops.cast(ops.cos(args), dtype) |
| 258 | + else: |
| 259 | + kernel = ops.cast(ops.sin(args), dtype) |
| 260 | + |
| 261 | + if win is not None: |
| 262 | + kernel = kernel * win / scaling |
| 263 | + return kernel |
| 264 | + |
| 265 | + def get_config(self): |
| 266 | + return { |
| 267 | + "side": self.side, |
| 268 | + "window": self.window, |
| 269 | + "periodic": self.periodic, |
| 270 | + "scaling": self.scaling, |
| 271 | + } |
0 commit comments