Skip to content

Commit 15104cc

Browse files
committed
Refactor: Integrate Whisper for speech unit extraction
This commit revamps the speech resynthesis pipeline by replacing the mHuBERT-based self-supervised unit extraction (via textlesslib) with a supervised Whisper-based approach. This change significantly impacts data processing, model configuration, and training/evaluation workflows. Key changes include: - **Core Unit Extraction:** - Removed `textlesslib` dependency and its associated `SpeechEncoder` (mHuBERT + K-means). - Integrated `WhisperFeatureExtractor` and `WhisperEncoder` from `src.flow_matching.utils.whisper` for supervised discrete unit extraction. - Updated Python version from 3.9 to 3.10 and added `faiss-gpu` to the environment. - **Dataset and Configuration:** - Default dataset changed to `ryota-komatsu/LibriTTS-R-whisper-large-v3-4096units`. - Vocabulary size updated from 2000 to 4096 units. - Replaced mHuBERT-based config files with `whisper-large-v3-4096-bigvgan.yaml`, reflecting new model parameters (e.g., `dim_cond_emb`, tokenizer settings) and ASR model (`openai/whisper-large-v3`). - **Pipeline Simplification:** - Removed `tokenize` and `extract_features` stages from data preprocessing. - Removed the `train_hifigan` stage and the `ConditionalFlowMatchingWithHifiGan` model and its configuration. - Removed the dedicated `src/flow_matching/eval.py` module and the `evaluate` task in `main_resynth.py`. - Removed custom ASR evaluation utilities (`src/flow_matching/utils/phi/`) and textless utilities (`src/flow_matching/utils/textless.py`). - Removed the entire `src/hifigan/` directory. - **Data Handling and Training:** - Replaced `UnitDataset` with direct loading from Hugging Face datasets using `load_dataset` and a new `get_collate_fn` in `src/flow_matching/data.py`. - Updated validation in `src/flow_matching/train.py` to use `transformers.pipeline` with `openai/whisper-large-v3` for ASR and `processor.tokenizer.normalize` for text normalization. - Updated unit embedding in `train_flow_matching` to use `WhisperEncoder.from_pretrained(...).quantizer`. - **Usage and Demo:** - Updated `README.md` and `demo.ipynb` to reflect the new setup, usage patterns, and Whisper integration for unit encoding. - Updated `scripts/setup.sh` to remove `textlesslib` cloning. This refactoring aims to leverage supervised Whisper models for potentially higher-quality unit extraction and simplifies the overall codebase by relying more on Hugging Face's ecosystem for common ASR and data handling tasks.
1 parent 1a54261 commit 15104cc

28 files changed

Lines changed: 159 additions & 4020 deletions

README.md

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,36 @@
11
# Speech Resynthesis and Language Modeling with Flow Matching and Llama
22

33
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
4-
[![Python](https://img.shields.io/badge/python-3.9-blue.svg)](https://www.python.org)
4+
[![Python](https://img.shields.io/badge/python-3.10-blue.svg)](https://www.python.org)
55
[![model](https://img.shields.io/badge/%F0%9F%A4%97-Models-blue)](https://huggingface.co/ryota-komatsu/flow_matching_with_bigvgan)
6-
[![dataset](https://img.shields.io/badge/%F0%9F%A4%97-Datasets-blue)](https://huggingface.co/datasets/ryota-komatsu/libritts-r-mhubert-2000units)
6+
[![dataset](https://img.shields.io/badge/%F0%9F%A4%97-Datasets-blue)](https://huggingface.co/datasets/ryota-komatsu/LibriTTS-R-whisper-large-v3-4096units)
77

88
## Setup
99

1010
```shell
1111
sudo apt install git-lfs # for UTMOS
1212

13-
conda create -y -n py39 python=3.9.21 pip=24.0
14-
conda activate py39
15-
pip install -r requirements/requirements.txt
13+
conda create -y -n py310 -c pytorch -c nvidia -c conda-forge python=3.10.17 pip=24.0 faiss-gpu=1.10.0
14+
conda activate py310
15+
pip install -r requirements.txt
1616
pip install flash-attn --no-build-isolation # optional
1717

18-
sh scripts/setup.sh # download textlesslib and UTMOS
19-
20-
cd src/textlesslib
21-
pip install -e .
22-
cd -
18+
sh scripts/setup.sh # download UTMOS
2319
```
2420

25-
## Usage: sampling multi-speaker speech from self-supervised discrete units
21+
## Usage: sampling multi-speaker speech from supervised discrete units
2622

2723
```python
2824
import torchaudio
29-
from textless.data.speech_encoder import SpeechEncoder
3025

3126
from src.flow_matching.models import ConditionalFlowMatchingWithBigVGan
27+
from src.flow_matching.utils.whisper import WhisperFeatureExtractor, WhisperEncoder
3228

3329
wav_path = "/path/to/wav"
3430

35-
encoder = SpeechEncoder.by_name(
36-
dense_model_name="mhubert-base-vp_mls_cv_8lang",
37-
quantizer_model_name="kmeans-expresso",
38-
vocab_size=2000,
39-
deduplicate=False,
40-
need_f0=False,
41-
).cuda()
31+
# load model and processor
32+
feature_extractor = WhisperFeatureExtractor.from_pretrained("ryota-komatsu/whisper-large-v3-tokenizer")
33+
encoder = WhisperEncoder.from_pretrained("ryota-komatsu/whisper-large-v3-tokenizer").cuda()
4234

4335
# download a pretrained model from hugging face hub
4436
decoder = ConditionalFlowMatchingWithBigVGan.from_pretrained("ryota-komatsu/flow_matching_with_bigvgan").cuda()
@@ -47,8 +39,16 @@ decoder = ConditionalFlowMatchingWithBigVGan.from_pretrained("ryota-komatsu/flow
4739
waveform, sr = torchaudio.load(wav_path)
4840
waveform = torchaudio.functional.resample(waveform, sr, 16000)
4941

42+
input_features = feature_extractor(
43+
waveform.squeeze(0).numpy(),
44+
return_tensors="pt",
45+
sampling_rate=16000,
46+
device="cuda",
47+
padding="do_not_pad",
48+
).input_features.to("cuda")
49+
5050
# encode a waveform into pseudo-phonetic units
51-
units = encoder(waveform.cuda())["units"]
51+
units = encoder(input_features, out_layer=15)
5252
units = units.unsqueeze(0) + 1 # 0: pad
5353

5454
# resynthesis
@@ -105,7 +105,7 @@ Jupyter notebook demo is found [here](demo.ipynb).
105105

106106
## Data Preparation
107107

108-
If you already have LibriTTS-R, you can use it by editing [a config file](configs/unit2speech/mhubert-expresso-2000.yaml#L6);
108+
If you already have LibriTTS-R, you can use it by editing [a config file](configs/unit2speech/whisper-large-v3-4096-bigvgan.yaml#L7);
109109
```yaml
110110
dataset:
111111
wav_dir_orig: "/path/to/LibriTTS-R" # ${dataset.wav_dir_orig}/train-clean-100, train-clean-360, ...
@@ -129,18 +129,14 @@ sh scripts/download_slm21.sh # download sWUGGY and sBLIMP
129129
## Training a unit-to-speech synthesizer
130130

131131
```shell
132-
python main_resynth.py --config=configs/unit2speech/mhubert-expresso-2000.yaml
132+
python main_resynth.py --config=configs/unit2speech/whisper-large-v3-4096-bigvgan.yaml
133133
```
134134

135135
To run only a specific stage, pass it as an argument.
136136

137137
Supported processing stages
138138
1. resample
139-
1. tokenize
140-
1. extract_features
141-
1. train_bigvgan # can be skipped when using a pretrained model
142139
1. train_flow_matching
143-
1. evaluate
144140
1. synthesize
145141

146142
```shell

configs/unit2speech/mhubert-expresso-2000-duration-prediction.yaml

Lines changed: 0 additions & 102 deletions
This file was deleted.

configs/unit2speech/mhubert-expresso-2000.yaml

Lines changed: 0 additions & 102 deletions
This file was deleted.

configs/unit2speech/mhubert-expresso-2000-bigvgan.yaml renamed to configs/unit2speech/whisper-large-v3-4096-bigvgan.yaml

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,14 @@ common:
22
seed: 0
33

44
dataset:
5+
name: "ryota-komatsu/LibriTTS-R-whisper-large-v3-4096units" # https://huggingface.co/datasets
56
wav_dir: "data/LibriTTS_R_16k" # ${root}/train-clean-100, train-clean-360, ...
67
wav_dir_orig: "data/LibriTTS_R" # if wav_dir == wav_dir_orig, original wav files are overwritten with 16 kHz waveforms
7-
spectrogram_dir: "data/LibriTTS_R_16k/spectrogram" # 34GB
88
vad: false
99

1010
ext_audio: ".wav"
1111
ext_txt: ".normalized.txt"
1212

13-
# json file format
14-
# "name": {"units": List[int], "durations": List[int], "transcript": str}
15-
train_file: "data/resynth/train.json" # 354,729 samples
16-
dev_file: "data/resynth/dev.json" # 5,736 samples
17-
test_file: "data/resynth/test.json" # 4,837 samples
18-
1913
synthesis:
2014
src_dir: ${dataset.wav_dir}
2115
tgt_dir: ${dataset.wav_dir}_resynth
@@ -39,17 +33,14 @@ flow_matching:
3933
save_interval_epoch: 20
4034

4135
# inference
42-
dt: 0.0625
43-
truncation_value: 1.0 # truncation trick (https://arxiv.org/abs/1809.11096)
36+
dt: 0.1
37+
truncation_value: null # truncation trick (https://arxiv.org/abs/1809.11096)
4438

45-
# textless.data.speech_encoder.SpeechEncoder
46-
dense_model_name: "mhubert-base-vp_mls_cv_8lang"
47-
quantizer_model_name: "kmeans-expresso"
48-
vocab_size: 2000
39+
vocab_size: ${tokenizer.vocab_size}
4940

5041
# src.flow_matching.configs.ConditionalFlowMatchingConfig
5142
dim_in: 80
52-
dim_cond_emb: 768
43+
dim_cond_emb: 1280
5344
hidden_size: 256
5445
depth: 4
5546
heads: 2
@@ -113,9 +104,14 @@ vocoder:
113104
checkpoint_interval: 10000
114105
validation_interval: 10000
115106

107+
tokenizer:
108+
name: "ryota-komatsu/whisper-large-v3-tokenizer"
109+
vocab_size: 4096
110+
out_layer: 15
111+
116112
flow_matching_with_vocoder:
117113
name: "ryota-komatsu/flow_matching_with_bigvgan"
118114
batch_size: 8
119115

120116
asr:
121-
name: "microsoft/Phi-4-multimodal-instruct"
117+
name: "openai/whisper-large-v3"

0 commit comments

Comments
 (0)