Skip to content

Commit c685a8e

Browse files
lucasnewman and Blaizzy authored
Add wav2vec2 model for Spark (#131)
* Add wav2vec2 model for Spark.
* fix indexing
* fix mel
* fix indices

---------

Co-authored-by: Prince Canuma <[email protected]>
1 parent 07cea9a commit c685a8e

File tree

6 files changed

+191
-1051
lines changed

6 files changed

+191
-1051
lines changed

mlx_audio/codec/models/vocos/mel.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,11 @@ def mel_to_hz(mels, mel_scale="htk"):
4141
min_log_hz = 1000.0
4242
min_log_mel = (min_log_hz - f_min) / f_sp
4343
logstep = math.log(6.4) / 27.0
44-
log_t = mels >= min_log_mel
45-
freqs[log_t] = min_log_hz * mx.exp(logstep * (mels[log_t] - min_log_mel))
44+
freqs = mx.where(
45+
mels >= min_log_mel,
46+
min_log_hz * mx.exp(logstep * (mels - min_log_mel)),
47+
freqs,
48+
)
4649
return freqs
4750

4851
f_max = f_max or sample_rate / 2
@@ -103,6 +106,9 @@ def _pad(x, padding, pad_mode="constant"):
103106
else:
104107
raise ValueError(f"Invalid pad_mode {pad_mode}")
105108

109+
if window.shape[0] < nfft:
110+
window = mx.pad(window, (0, nfft - window.shape[0]))
111+
106112
padding = nperseg // 2
107113
x = _pad(x, padding, pad_mode)
108114

0 commit comments

Comments
 (0)