-
Notifications
You must be signed in to change notification settings - Fork 31
Expand file tree
/
Copy pathgigaam.py
More file actions
107 lines (82 loc) · 4.06 KB
/
gigaam.py
File metadata and controls
107 lines (82 loc) · 4.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""LogMelSpectrogram feature extractor for GigaAM models."""
import numpy as np
from ml_dtypes import bfloat16
from onnxscript import FLOAT, INT64, script
from onnxscript import opset17 as op
from preprocessors.fbanks import melscale_fbanks
from preprocessors.stft import conv_power_spectrogram, stft_conv_weights
# ---------------------------------------------------------------------------
# Parameters shared by every GigaAM preprocessor variant.
# ---------------------------------------------------------------------------
sample_rate = 16_000
hop_length = sample_rate // 100
n_mels = 64
f_min = 0
f_max = 8_000
clamp_min = 1e-9
clamp_max = 1e9

# ---------------------------------------------------------------------------
# Per-version framing: for both versions the FFT size equals the window length.
# ---------------------------------------------------------------------------
n_fft_v2 = win_length_v2 = sample_rate // 40
n_fft_v3 = win_length_v3 = sample_rate // 50


def _mel_filterbank(n_fft):
    """Build the mel filterbank for an FFT of size *n_fft* with the shared band settings."""
    return melscale_fbanks(n_fft // 2 + 1, f_min, f_max, n_mels, sample_rate)


def _hann(length):
    """Return the first *length* samples of ``np.hanning(length + 1)``."""
    return np.hanning(length + 1)[:-1]


# v2 keeps the filterbank in float32; v3 rounds it through bfloat16 first and
# then stores the rounded values as float32.
melscale_fbanks_v2 = _mel_filterbank(n_fft_v2).astype(np.float32)
melscale_fbanks_v3 = _mel_filterbank(n_fft_v3).astype(bfloat16).astype(np.float32)

# The v3 window gets the same bfloat16 round-trip as the v3 filterbank.
hann_window_v3 = _hann(win_length_v3).astype(bfloat16).astype(np.float32)

# Convolution kernels consumed by the Conv-based preprocessors.
stft_conv_weights_v2 = stft_conv_weights(_hann(win_length_v2).astype(np.float32))
stft_conv_weights_v3 = stft_conv_weights(hann_window_v3)
# ONNX Script graph: log-mel spectrogram front end for GigaAM v2, built on the
# ONNX STFT operator.
#
# Inputs (per the annotations below):
#   waveforms      - FLOAT[batch_size, N]
#   waveforms_lens - INT64[batch_size]
# Outputs:
#   features       - FLOAT[batch_size, n_mels, T]
#   features_lens  - INT64[batch_size], number of frames per batch item
@script(doc_string="LogMelSpectrogram feature extractor for GigaAM v2 models")
def GigaamPreprocessorV2(
    waveforms: FLOAT["batch_size", "N"],
    waveforms_lens: INT64["batch_size"],
) -> tuple[FLOAT["batch_size", n_mels, "T"], INT64["batch_size"]]:
    # Reflect-pad the sample axis by n_fft_v2 // 2 on each side; the batch axis
    # gets no padding. `pads` follows the ONNX Pad layout
    # [axis0_begin, axis1_begin, axis0_end, axis1_end].
    waveforms = op.Pad(
        waveforms,
        pads=op.Constant(value=[0, n_fft_v2 // 2, 0, n_fft_v2 // 2]),
        mode="reflect",
    )
    # The window is generated inside the graph (operator defaults are used).
    hann_window = op.HannWindow(win_length_v2)
    # hop_length is the frame step; hann_window is the analysis window.
    image = op.STFT(waveforms, hop_length, hann_window)
    # Sum of squares over the last axis, which is then dropped (keepdims=0).
    # Presumably that axis is the real/imaginary pair emitted by STFT, making
    # this the power per frequency bin - confirm against the STFT spec.
    spectrogram = op.ReduceSumSquare(image, axes=[-1], keepdims=0)
    # Project onto the precomputed v2 mel filterbank.
    mel_spectrogram = op.MatMul(spectrogram, melscale_fbanks_v2)
    # Clamp to [clamp_min, clamp_max] before the log so zero energy does not
    # produce -inf.
    log_mel_spectrogram = op.Log(op.Clip(mel_spectrogram, clamp_min, clamp_max))
    # Frame count derived from the original (unpadded) lengths. Both operands
    # are integers; presumably this lowers to an integer Div - confirm in the
    # exported graph.
    features_lens = waveforms_lens / hop_length + 1
    # Swap the last two axes to match the declared [batch_size, n_mels, T] output.
    features = op.Transpose(log_mel_spectrogram, perm=[0, 2, 1])
    return features, features_lens
# ONNX Script graph: log-mel spectrogram front end for GigaAM v3, built on the
# ONNX STFT operator. Unlike the v2 graph, no padding is applied to the input
# and the window / filterbank are the precomputed v3 module constants.
#
# Inputs (per the annotations below):
#   waveforms      - FLOAT[batch_size, N]
#   waveforms_lens - INT64[batch_size]
# Outputs:
#   features       - FLOAT[batch_size, n_mels, T]
#   features_lens  - INT64[batch_size], number of frames per batch item
@script(doc_string="LogMelSpectrogram feature extractor for GigaAM v3 models")
def GigaamPreprocessorV3(
    waveforms: FLOAT["batch_size", "N"],
    waveforms_lens: INT64["batch_size"],
) -> tuple[FLOAT["batch_size", n_mels, "T"], INT64["batch_size"]]:
    # Cast the signal to the dtype of the precomputed window, then run STFT with
    # hop_length as the frame step and hann_window_v3 as the analysis window.
    image = op.STFT(op.CastLike(waveforms, hann_window_v3), hop_length, hann_window_v3)
    # Sum of squares over the last axis, which is then dropped (keepdims=0).
    # Presumably that axis is the real/imaginary pair emitted by STFT, making
    # this the power per frequency bin - confirm against the STFT spec.
    spectrogram = op.ReduceSumSquare(image, axes=[-1], keepdims=0)
    # Cast to the filterbank dtype and project onto the v3 mel filterbank.
    mel_spectrogram = op.MatMul(op.CastLike(spectrogram, melscale_fbanks_v3), melscale_fbanks_v3)
    # Clamp to [clamp_min, clamp_max] before the log so zero energy does not
    # produce -inf.
    log_mel_spectrogram = op.Log(op.Clip(mel_spectrogram, clamp_min, clamp_max))
    # No padding here, so one full window is subtracted before counting hops.
    # Operands are integers; presumably this lowers to an integer Div - confirm
    # in the exported graph.
    features_lens = (waveforms_lens - win_length_v3) / hop_length + 1
    # Swap the last two axes to match the declared [batch_size, n_mels, T] output.
    features = op.Transpose(log_mel_spectrogram, perm=[0, 2, 1])
    return features, features_lens
# ONNX Script graph: log-mel spectrogram front end for GigaAM v2 that obtains
# the spectrogram from conv_power_spectrogram (imported from preprocessors.stft)
# instead of the ONNX STFT operator.
#
# Inputs (per the annotations below):
#   waveforms      - FLOAT[batch_size, N]
#   waveforms_lens - INT64[batch_size]
# Outputs:
#   features       - FLOAT[batch_size, n_mels, T]
#   features_lens  - INT64[batch_size], number of frames per batch item
@script(doc_string="LogMelSpectrogram feature extractor for GigaAM v2 models (Conv-based STFT)")
def GigaamPreprocessorV2Conv(
    waveforms: FLOAT["batch_size", "N"],
    waveforms_lens: INT64["batch_size"],
) -> tuple[FLOAT["batch_size", n_mels, "T"], INT64["batch_size"]]:
    # Reflect-pad the sample axis by n_fft_v2 // 2 on each side; the batch axis
    # gets no padding. `pads` follows the ONNX Pad layout
    # [axis0_begin, axis1_begin, axis0_end, axis1_end].
    waveforms = op.Pad(
        waveforms,
        pads=op.Constant(value=[0, n_fft_v2 // 2, 0, n_fft_v2 // 2]),
        mode="reflect",
    )
    # NOTE(review): presumably returns the power spectrogram in the layout the
    # following MatMul expects (frames x bins) - confirm in preprocessors.stft.
    spectrogram = conv_power_spectrogram(waveforms, stft_conv_weights_v2)
    # Project onto the precomputed v2 mel filterbank.
    mel_spectrogram = op.MatMul(spectrogram, melscale_fbanks_v2)
    # Clamp to [clamp_min, clamp_max] before the log so zero energy does not
    # produce -inf.
    log_mel_spectrogram = op.Log(op.Clip(mel_spectrogram, clamp_min, clamp_max))
    # Frame count derived from the original (unpadded) lengths. Both operands
    # are integers; presumably this lowers to an integer Div - confirm in the
    # exported graph.
    features_lens = waveforms_lens / hop_length + 1
    # Swap the last two axes to match the declared [batch_size, n_mels, T] output.
    features = op.Transpose(log_mel_spectrogram, perm=[0, 2, 1])
    return features, features_lens
# ONNX Script graph: log-mel spectrogram front end for GigaAM v3 that obtains
# the spectrogram from conv_power_spectrogram (imported from preprocessors.stft)
# instead of the ONNX STFT operator. No padding is applied to the input.
#
# Inputs (per the annotations below):
#   waveforms      - FLOAT[batch_size, N]
#   waveforms_lens - INT64[batch_size]
# Outputs:
#   features       - FLOAT[batch_size, n_mels, T]
#   features_lens  - INT64[batch_size], number of frames per batch item
@script(doc_string="LogMelSpectrogram feature extractor for GigaAM v3 models (Conv-based STFT)")
def GigaamPreprocessorV3Conv(
    waveforms: FLOAT["batch_size", "N"],
    waveforms_lens: INT64["batch_size"],
) -> tuple[FLOAT["batch_size", n_mels, "T"], INT64["batch_size"]]:
    # NOTE(review): presumably returns the power spectrogram in the layout the
    # following MatMul expects (frames x bins) - confirm in preprocessors.stft.
    spectrogram = conv_power_spectrogram(waveforms, stft_conv_weights_v3)
    # Project onto the precomputed v3 mel filterbank.
    mel_spectrogram = op.MatMul(spectrogram, melscale_fbanks_v3)
    # Clamp to [clamp_min, clamp_max] before the log so zero energy does not
    # produce -inf.
    log_mel_spectrogram = op.Log(op.Clip(mel_spectrogram, clamp_min, clamp_max))
    # No padding here, so one full window is subtracted before counting hops.
    # Operands are integers; presumably this lowers to an integer Div - confirm
    # in the exported graph.
    features_lens = (waveforms_lens - win_length_v3) / hop_length + 1
    # Swap the last two axes to match the declared [batch_size, n_mels, T] output.
    features = op.Transpose(log_mel_spectrogram, perm=[0, 2, 1])
    return features, features_lens