Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added assets/mel_filters.npz
Binary file not shown.
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,5 @@ transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.*;python_version<'3.10'
transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.0;python_version>='3.10'
x-transformers==1.44.4
torchdiffeq==0.2.5
openai-whisper==20240930
httpx==0.28.1
gradio==5.23.1
116 changes: 114 additions & 2 deletions tts/frontend_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import torch
import torch.nn.functional as F
import whisper
import librosa
from copy import deepcopy
from tts.utils.text_utils.ph_tone_convert import split_ph_timestamp, split_ph
from tts.utils.audio_utils.align import mel2token_to_dur
from subprocess import CalledProcessError, run
import numpy as np

''' Grapheme to phoneme function '''
def g2p(self, text_inp):
Expand All @@ -40,7 +42,7 @@ def g2p(self, text_inp):
def align(self, wav):
with torch.inference_mode():
whisper_wav = librosa.resample(wav, orig_sr=self.sr, target_sr=16000)
mel = torch.FloatTensor(whisper.log_mel_spectrogram(whisper_wav).T).to(self.device)[None].transpose(1,2)
mel = torch.FloatTensor(log_mel_spectrogram(whisper_wav).T).to(self.device)[None].transpose(1,2)
prompt_max_frame = mel.size(2) // self.fm * self.fm
mel = mel[:, :, :prompt_max_frame]
token = torch.LongTensor([[798]]).to(self.device)
Expand Down Expand Up @@ -179,3 +181,113 @@ def prepare_inputs_for_dit(self, mel2ph_ref, mel2ph_pred, ph_ref, tone_ref, ph_p
"ctx_mask": ctx_mask,
"dur": mel2ph_pred,
}


def mel_filters(device, n_mels: int) -> torch.Tensor:
    """
    Return the precomputed Mel filterbank used to project an STFT power
    spectrum onto the Mel scale.

    The matrices are read from ``assets/mel_filters.npz``, where ``assets``
    sits in the parent of the directory that contains this module. Shipping
    the matrices on disk removes the need to call librosa here. The archive
    was produced with:

        np.savez_compressed(
            "mel_filters.npz",
            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
            mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
        )

    Parameters
    ----------
    device
        Device the returned tensor is moved to (anything ``Tensor.to`` accepts).

    n_mels: int
        Number of Mel bands; only 80 and 128 are stored in the archive.

    Returns
    -------
    torch.Tensor
        The array stored under the key ``mel_<n_mels>``, placed on ``device``.

    Raises
    ------
    AssertionError
        If ``n_mels`` is neither 80 nor 128.
    """
    assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"

    # Two dirname() hops: this file -> its directory -> that directory's parent.
    base_dir = os.path.dirname(os.path.dirname(__file__))
    archive_path = os.path.join(base_dir, "assets", "mel_filters.npz")

    # The archive is closed on exit from the `with`; the array is read before that.
    with np.load(archive_path, allow_pickle=False) as archive:
        bank = archive[f"mel_{n_mels}"]
        return torch.from_numpy(bank).to(device)

# Audio front-end constants shared by load_audio() and log_mel_spectrogram().
# Default sampling rate (Hz) that load_audio() asks ffmpeg to resample to.
SAMPLE_RATE = 16000
# STFT size in samples; also the length of the Hann window (400 / 16000 Hz = 25 ms).
N_FFT = 400
# Hop between consecutive STFT frames in samples (160 / 16000 Hz = 10 ms).
HOP_LENGTH = 160

def load_audio(file: str, sr: int = SAMPLE_RATE):
    """
    Decode an audio file into a mono waveform at the requested sample rate.

    Decoding is delegated to the ``ffmpeg`` command-line tool, which must be
    available on PATH. ffmpeg down-mixes to one channel, resamples to ``sr``
    and writes raw signed 16-bit little-endian PCM to its stdout.

    Parameters
    ----------
    file: str
        The audio file to open

    sr: int
        The sample rate to resample the audio to if necessary

    Returns
    -------
    A 1-D NumPy array of float32 samples, i.e. the int16 PCM values
    divided by 32768.

    Raises
    ------
    RuntimeError
        If ffmpeg exits with a non-zero status; the message carries the
        decoded stderr output of the process.
    """
    ffmpeg_cmd = ["ffmpeg", "-nostdin"]          # never wait on interactive input
    ffmpeg_cmd += ["-threads", "0"]
    ffmpeg_cmd += ["-i", file]
    ffmpeg_cmd += ["-f", "s16le"]                # raw PCM container
    ffmpeg_cmd += ["-ac", "1"]                   # single (mono) channel
    ffmpeg_cmd += ["-acodec", "pcm_s16le"]
    ffmpeg_cmd += ["-ar", str(sr)]               # output sample rate
    ffmpeg_cmd.append("-")                       # write to stdout

    try:
        completed = run(ffmpeg_cmd, capture_output=True, check=True)
    except CalledProcessError as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

    # Interpret the captured bytes as int16 samples, then scale by 1/32768.
    pcm = np.frombuffer(completed.stdout, np.int16).flatten()
    return pcm.astype(np.float32) / 32768.0

def log_mel_spectrogram(
    audio,
    n_mels: int = 80,
    padding: int = 0,
    device = None,
):
    """
    Compute the log-Mel spectrogram of an audio waveform.

    Parameters
    ----------
    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
        Either a path to an audio file (decoded through ``load_audio``), or a
        NumPy array / Tensor holding the waveform, expected at 16 kHz

    n_mels: int
        The number of Mel-frequency filters, only 80 and 128 are supported

    padding: int
        Number of zero samples appended on the right before the STFT

    device: Optional[Union[str, torch.device]]
        If given, the audio tensor is moved to this device before STFT

    Returns
    -------
    torch.Tensor, shape = (n_mels, n_frames)
        A Tensor that contains the normalized log-Mel spectrogram
    """
    # Bring every accepted input form to a torch.Tensor.
    if not torch.is_tensor(audio):
        samples = load_audio(audio) if isinstance(audio, str) else audio
        audio = torch.from_numpy(samples)

    if device is not None:
        audio = audio.to(device)
    if padding > 0:
        audio = F.pad(audio, (0, padding))

    hann = torch.hann_window(N_FFT).to(audio.device)
    spectrum = torch.stft(audio, N_FFT, HOP_LENGTH, window=hann, return_complex=True)
    # Power spectrum, with the last index along the final axis dropped.
    power = spectrum[..., :-1].abs() ** 2

    mel_bank = mel_filters(audio.device, n_mels)
    mel_power = mel_bank @ power

    # Floor at 1e-10 so log10 never sees zero.
    log_mel = torch.clamp(mel_power, min=1e-10).log10()
    # Limit the range to at most 8 log10 units below the global peak.
    log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
    # Affine rescale: (x + 4) / 4.
    return (log_mel + 4.0) / 4.0