Commit 0d475f7

Refactor: Integrate Whisper for speech unit extraction

This commit revamps the speech resynthesis pipeline by replacing the mHuBERT-based self-supervised unit extraction (via textlesslib) with a supervised Whisper-based approach. This change significantly impacts data processing, model configuration, and training/evaluation workflows.

Key changes include:

- **Core Unit Extraction:**
  - Removed `textlesslib` dependency and its associated `SpeechEncoder` (mHuBERT + K-means).
  - Integrated `WhisperFeatureExtractor` and `WhisperEncoder` from `src.flow_matching.utils.whisper` for supervised discrete unit extraction.
  - Updated Python version from 3.9 to 3.10 and added `faiss-gpu` to the environment.
- **Dataset and Configuration:**
  - Default dataset changed to `ryota-komatsu/LibriTTS-R-whisper-large-v3-4096units`.
  - Vocabulary size updated from 2000 to 4096 units.
  - Replaced mHuBERT-based config files with `whisper-large-v3-4096-bigvgan.yaml`, reflecting new model parameters (e.g., `dim_cond_emb`, tokenizer settings) and ASR model (`openai/whisper-large-v3`).
- **Pipeline Simplification:**
  - Removed `tokenize` and `extract_features` stages from data preprocessing.
  - Removed the `train_hifigan` stage and the `ConditionalFlowMatchingWithHifiGan` model and its configuration.
  - Removed the dedicated `src/flow_matching/eval.py` module and the `evaluate` task in `main_resynth.py`.
  - Removed custom ASR evaluation utilities (`src/flow_matching/utils/phi/`) and textless utilities (`src/flow_matching/utils/textless.py`).
  - Removed the entire `src/hifigan/` directory.
- **Data Handling and Training:**
  - Replaced `UnitDataset` with direct loading from Hugging Face datasets using `load_dataset` and a new `get_collate_fn` in `src/flow_matching/data.py`.
  - Updated validation in `src/flow_matching/train.py` to use `transformers.pipeline` with `openai/whisper-large-v3` for ASR and `processor.tokenizer.normalize` for text normalization.
  - Updated unit embedding in `train_flow_matching` to use `WhisperEncoder.from_pretrained(...).quantizer`.
- **Usage and Demo:**
  - Updated `README.md` and `demo.ipynb` to reflect the new setup, usage patterns, and Whisper integration for unit encoding.
  - Updated `scripts/setup.sh` to remove `textlesslib` cloning.

This refactoring aims to leverage supervised Whisper models for potentially higher-quality unit extraction and simplifies the overall codebase by relying more on Hugging Face's ecosystem for common ASR and data handling tasks.
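The validation change described above (ASR transcription followed by text normalization) can be illustrated with a small, dependency-free sketch. It assumes the normalized transcripts are compared with a word error rate, which the commit message does not state. The `normalize` function below is a simplified stand-in for `processor.tokenizer.normalize`, and `word_error_rate` is a hypothetical helper name, not code from this repository.

```python
import re


def normalize(text: str) -> list[str]:
    """Simplified stand-in for Whisper's normalizer: lowercase, drop punctuation, split on spaces."""
    return re.sub(r"[^a-z0-9' ]+", " ", text.lower()).split()


def word_error_rate(reference: str, hypothesis: str) -> float:
    """Word-level Levenshtein distance divided by the number of reference words."""
    ref, hyp = normalize(reference), normalize(hypothesis)
    # prev[j] holds the edit distance between the processed prefix of ref and hyp[:j]
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        curr = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            curr.append(
                min(
                    prev[j] + 1,  # deletion
                    curr[j - 1] + 1,  # insertion
                    prev[j - 1] + (ref_word != hyp_word),  # substitution or match
                )
            )
        prev = curr
    return prev[-1] / max(len(ref), 1)
```

In a real validation loop the hypothesis would come from the ASR pipeline run on the resynthesized audio and the reference from the dataset transcript; both are plain strings here to keep the sketch runnable without model downloads.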
1 parent 1a54261 commit 0d475f7

30 files changed

Lines changed: 215 additions & 4066 deletions

README.md

Lines changed: 38 additions & 42 deletions
@@ -1,44 +1,36 @@
 # Speech Resynthesis and Language Modeling with Flow Matching and Llama
 
 [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
-[![Python](https://img.shields.io/badge/python-3.9-blue.svg)](https://www.python.org)
+[![Python](https://img.shields.io/badge/python-3.10-blue.svg)](https://www.python.org)
 [![model](https://img.shields.io/badge/%F0%9F%A4%97-Models-blue)](https://huggingface.co/ryota-komatsu/flow_matching_with_bigvgan)
-[![dataset](https://img.shields.io/badge/%F0%9F%A4%97-Datasets-blue)](https://huggingface.co/datasets/ryota-komatsu/libritts-r-mhubert-2000units)
+[![dataset](https://img.shields.io/badge/%F0%9F%A4%97-Datasets-blue)](https://huggingface.co/datasets/ryota-komatsu/LibriTTS-R-whisper-large-v3-4096units)
 
 ## Setup
 
 ```shell
 sudo apt install git-lfs  # for UTMOS
 
-conda create -y -n py39 python=3.9.21 pip=24.0
-conda activate py39
-pip install -r requirements/requirements.txt
+conda create -y -n py310 -c pytorch -c nvidia -c conda-forge python=3.10.17 pip=24.0 faiss-gpu=1.10.0
+conda activate py310
+pip install -r requirements.txt
 pip install flash-attn --no-build-isolation  # optional
 
-sh scripts/setup.sh  # download textlesslib and UTMOS
-
-cd src/textlesslib
-pip install -e .
-cd -
+sh scripts/setup.sh  # download UTMOS
 ```
 
-## Usage: sampling multi-speaker speech from self-supervised discrete units
+## Usage: sampling multi-speaker speech from supervised discrete units
 
 ```python
 import torchaudio
-from textless.data.speech_encoder import SpeechEncoder
 
 from src.flow_matching.models import ConditionalFlowMatchingWithBigVGan
+from src.flow_matching.utils.whisper import WhisperFeatureExtractor, WhisperEncoder
 
 wav_path = "/path/to/wav"
 
-encoder = SpeechEncoder.by_name(
-    dense_model_name="mhubert-base-vp_mls_cv_8lang",
-    quantizer_model_name="kmeans-expresso",
-    vocab_size=2000,
-    deduplicate=False,
-    need_f0=False,
-).cuda()
+# load model and processor
+feature_extractor = WhisperFeatureExtractor.from_pretrained("ryota-komatsu/whisper-large-v3-tokenizer")
+encoder = WhisperEncoder.from_pretrained("ryota-komatsu/whisper-large-v3-tokenizer").cuda()
 
 # download a pretrained model from hugging face hub
 decoder = ConditionalFlowMatchingWithBigVGan.from_pretrained("ryota-komatsu/flow_matching_with_bigvgan").cuda()
@@ -47,8 +39,16 @@ decoder = ConditionalFlowMatchingWithBigVGan.from_pretrained("ryota-komatsu/flow
 waveform, sr = torchaudio.load(wav_path)
 waveform = torchaudio.functional.resample(waveform, sr, 16000)
 
+input_features = feature_extractor(
+    waveform.squeeze(0).numpy(),
+    return_tensors="pt",
+    sampling_rate=16000,
+    device="cuda",
+    padding="do_not_pad",
+).input_features.to("cuda")
+
 # encode a waveform into pseudo-phonetic units
-units = encoder(waveform.cuda())["units"]
+units = encoder(input_features, out_layer=15)
 units = units.unsqueeze(0) + 1  # 0: pad
 
 # resynthesis
@@ -60,21 +60,17 @@ audio_values = decoder(units)
 ```python
 import torch
 import torchaudio
-from textless.data.speech_encoder import SpeechEncoder
 from tokenizers import Tokenizer
 from transformers import LlamaForCausalLM
 
+from src.flow_matching.utils.whisper import WhisperFeatureExtractor, WhisperEncoder
 from src.speechlm.utils import convert_units_to_unicode
 
 wav_path = "/path/to/wav"
 
-encoder = SpeechEncoder.by_name(
-    dense_model_name="hubert-base-ls960",
-    quantizer_model_name="kmeans",
-    vocab_size=100,
-    deduplicate=True,
-    need_f0=False,
-).cuda()
+# load model and processor
+feature_extractor = WhisperFeatureExtractor.from_pretrained("ryota-komatsu/whisper-large-v3-tokenizer")
+encoder = WhisperEncoder.from_pretrained("ryota-komatsu/whisper-large-v3-tokenizer").cuda()
 
 # BPE tokenizer
 tokenizer = Tokenizer.from_file("/path/to/pretrained/tokenizer.json")
@@ -85,8 +81,16 @@ model = LlamaForCausalLM.from_pretrained("/path/to/pretrained/model").cuda()
 waveform, sr = torchaudio.load(wav_path)
 waveform = torchaudio.functional.resample(waveform, sr, 16000)
 
+input_features = feature_extractor(
+    waveform.squeeze(0).numpy(),
+    return_tensors="pt",
+    sampling_rate=16000,
+    device="cuda",
+    padding="do_not_pad",
+).input_features.to("cuda")
+
 # encode a waveform into pseudo-phonetic units
-units = encoder(waveform.cuda())["units"].tolist()
+units = encoder(input_features, out_layer=15).tolist()
 unicodes = convert_units_to_unicode(units)
 
 # BPE
@@ -105,7 +109,7 @@ Jupyter notebook demo is found [here](demo.ipynb).
 
 ## Data Preparation
 
-If you already have LibriTTS-R, you can use it by editing [a config file](configs/unit2speech/mhubert-expresso-2000.yaml#L6);
+If you already have LibriTTS-R, you can use it by editing [a config file](configs/unit2speech/whisper-large-v3-4096-bigvgan.yaml#L7);
 ```yaml
 dataset:
   wav_dir_orig: "/path/to/LibriTTS-R" # ${dataset.wav_dir_orig}/train-clean-100, train-clean-360, ...
@@ -128,23 +132,15 @@ sh scripts/download_slm21.sh # download sWUGGY and sBLIMP
 
 ## Training a unit-to-speech synthesizer
 
-```shell
-python main_resynth.py --config=configs/unit2speech/mhubert-expresso-2000.yaml
-```
-
 To run only a specific stage, pass it as an argument.
 
 Supported processing stages
 1. resample
-1. tokenize
-1. extract_features
-1. train_bigvgan  # can be skipped when using a pretrained model
 1. train_flow_matching
-1. evaluate
 1. synthesize
 
 ```shell
-python main_resynth.py tokenize --config=configs/unit2speech/mhubert-expresso-2000.yaml
+python main_resynth.py train_flow_matching --config=configs/unit2speech/whisper-large-v3-4096-bigvgan.yaml
 ```
 
 ## Training a speech language model
@@ -159,7 +155,7 @@ torchrun \
     --rdzv_backend=c10d \
     --rdzv_endpoint=localhost:29400 \
     main_speechlm.py \
-    --config=configs/speechlm/hubert.yaml
+    --config=configs/speechlm/whisper.yaml
 ```
 
 To run only a sub-task (encode, tokenize, or train), specify it as an argument.
@@ -172,13 +168,13 @@ torchrun \
     --rdzv_backend=c10d \
     --rdzv_endpoint=localhost:29400 \
     main_speechlm.py encode \
-    --config=configs/speechlm/hubert.yaml
+    --config=configs/speechlm/whisper.yaml
 ```
 
 ## Evaluation of a speech language model
 
 See [Zero Resource Speech homepage](https://zerospeech.com/tasks/task_4/tasks_goals/) and [paper](https://arxiv.org/abs/2011.11588) for task details.
 
 ```shell
-python main_speechlm.py eval --config=configs/speechlm/hubert.yaml
+python main_speechlm.py eval --config=configs/speechlm/whisper.yaml
 ```
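The README usage in this diff shifts unit IDs by one (`units.unsqueeze(0) + 1  # 0: pad`) so that ID 0 can act as padding. The sketch below shows how a batch of variable-length unit sequences could be padded under that convention. It uses plain Python lists rather than tensors, and `pad_units` and `PAD_ID` are hypothetical names; this is not the repository's `get_collate_fn`, whose implementation does not appear in this diff.

```python
PAD_ID = 0  # assumption taken from the README comment "0: pad"


def pad_units(batch: list[list[int]], vocab_size: int = 4096) -> tuple[list[list[int]], list[int]]:
    """Shift raw unit IDs by +1 and right-pad every sequence to the batch maximum."""
    lengths = [len(units) for units in batch]
    max_len = max(lengths, default=0)
    padded = []
    for units in batch:
        if any(not 0 <= u < vocab_size for u in units):
            raise ValueError("unit ID outside [0, vocab_size)")
        shifted = [u + 1 for u in units]  # raw units 0..4095 become 1..4096
        padded.append(shifted + [PAD_ID] * (max_len - len(shifted)))
    return padded, lengths
```

With 4096 units plus the padding ID, an embedding table indexed by these values would need 4097 rows. The returned lengths let downstream code build a mask that ignores padded positions.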
Lines changed: 2 additions & 3 deletions
@@ -42,9 +42,8 @@ optim:
   total_steps: 200000
 
 s2u:
-  dense_model_name: "hubert-base-ls960"
-  quantizer_model_name: "kmeans"
-  vocab_size: 100
+  name: "ryota-komatsu/whisper-large-v3-tokenizer"
+  vocab_size: 4096
 
 tokenizer_path: "models/speechlm/hubert/tokenizer.json"
 
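The `s2u` block now declares 4096 units, and the README passes encoder output through `convert_units_to_unicode` before BPE tokenization. That helper's implementation is not part of this diff. As a rough illustration only, one way to give each unit a distinct, non-whitespace character is to add a fixed code point offset. The offset `0x4E00` (start of the CJK Unified Ideographs block) and both function names are assumptions, not the repository's code.

```python
UNICODE_OFFSET = 0x4E00  # assumed offset; any contiguous printable range of >= 4096 code points would do


def units_to_unicode(units: list[int], vocab_size: int = 4096) -> str:
    """Map each unit ID to one character so a text BPE tokenizer can consume the sequence."""
    for u in units:
        if not 0 <= u < vocab_size:
            raise ValueError(f"unit {u} outside [0, {vocab_size})")
    return "".join(chr(UNICODE_OFFSET + u) for u in units)


def unicode_to_units(text: str) -> list[int]:
    """Inverse mapping, e.g. for turning generated text back into unit IDs."""
    return [ord(ch) - UNICODE_OFFSET for ch in text]
```

A round trip such as `unicode_to_units(units_to_unicode([0, 15, 4095]))` returns the original list, which is a quick sanity check when changing `vocab_size` in the config.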

configs/unit2speech/mhubert-expresso-2000-duration-prediction.yaml

Lines changed: 0 additions & 102 deletions
This file was deleted.

configs/unit2speech/mhubert-expresso-2000.yaml

Lines changed: 0 additions & 102 deletions
This file was deleted.
