Skip to content

Commit a3b7271

Browse files
committed
add duration prediction
1 parent 0d475f7 commit a3b7271

12 files changed

Lines changed: 115 additions & 69 deletions

File tree

README.md

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
44
[![Python](https://img.shields.io/badge/python-3.10-blue.svg)](https://www.python.org)
5+
[![colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ryota-komatsu/speech_resynth/blob/main/demo.ipynb)
56
[![model](https://img.shields.io/badge/%F0%9F%A4%97-Models-blue)](https://huggingface.co/ryota-komatsu/flow_matching_with_bigvgan)
67
[![dataset](https://img.shields.io/badge/%F0%9F%A4%97-Datasets-blue)](https://huggingface.co/datasets/ryota-komatsu/LibriTTS-R-whisper-large-v3-4096units)
78

@@ -48,7 +49,7 @@ input_features = feature_extractor(
4849
).input_features.to("cuda")
4950

5051
# encode a waveform into pseudo-phonetic units
51-
units = encoder(input_features, out_layer=15)
52+
units = encoder.encode(input_features)
5253
units = units.unsqueeze(0) + 1 # 0: pad
5354

5455
# resynthesis
@@ -90,7 +91,7 @@ input_features = feature_extractor(
9091
).input_features.to("cuda")
9192

9293
# encode a waveform into pseudo-phonetic units
93-
units = encoder(input_features, out_layer=15).tolist()
94+
units = encoder.encode(input_features).tolist()
9495
unicodes = convert_units_to_unicode(units)
9596

9697
# BPE
@@ -105,7 +106,7 @@ logits = model(input_ids=input_ids).logits
105106

106107
Visit [demo page](https://ryota-komatsu.github.io/speech_resynth) for speech samples.
107108

108-
Jupyter notebook demo is found [here](demo.ipynb).
109+
Google colab demo is found [here](https://colab.research.google.com/github/ryota-komatsu/speech_resynth/blob/main/demo.ipynb).
109110

110111
## Data Preparation
111112

@@ -136,6 +137,10 @@ To run only a specific stage, pass it as an argument.
136137

137138
Supported processing stages
138139
1. resample
140+
1. extract_features # can be skipped when using a pretrained BigVGan
141+
1. train_bigvgan # can be skipped when using a pretrained BigVGan
142+
1. train_tokenizer # can be skipped when using a pretrained model
143+
1. tokenize_dataset # can be skipped when using a Hugging Face datasets
139144
1. train_flow_matching
140145
1. synthesize
141146

configs/unit2speech/whisper-large-v3-4096-bigvgan.yaml

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ dataset:
55
name: "ryota-komatsu/LibriTTS-R-whisper-large-v3-4096units" # https://huggingface.co/datasets
66
wav_dir: "data/LibriTTS_R_16k" # ${root}/train-clean-100, train-clean-360, ...
77
wav_dir_orig: "data/LibriTTS_R" # if wav_dir == wav_dir_orig, original wav files are overwritten with 16 kHz waveforms
8+
spectrogram_dir: "data/LibriTTS_R_16k/spectrogram" # 34GB for BigVGAN
89
vad: false
910

1011
ext_audio: ".wav"
@@ -16,9 +17,6 @@ synthesis:
1617
split: "test-*"
1718
ext_audio: ${dataset.ext_audio}
1819

19-
eval:
20-
result_path: "results/resynth/score.csv"
21-
2220
flow_matching:
2321
path: "models/flow_matching"
2422
batch_size: 2700 # work with single 24GB VRAM GPU
@@ -36,9 +34,8 @@ flow_matching:
3634
dt: 0.1
3735
truncation_value: null # truncation trick (https://arxiv.org/abs/1809.11096)
3836

39-
vocab_size: ${tokenizer.vocab_size}
40-
4137
# src.flow_matching.configs.ConditionalFlowMatchingConfig
38+
vocab_size: ${tokenizer.vocab_size}
4239
dim_in: 80
4340
dim_cond_emb: 1280
4441
hidden_size: 256
@@ -105,9 +102,10 @@ vocoder:
105102
validation_interval: 10000
106103

107104
tokenizer:
105+
base: "openai/whisper-large-v3"
108106
name: "ryota-komatsu/whisper-large-v3-tokenizer"
109107
vocab_size: 4096
110-
out_layer: 15
108+
encoder_layers: 16
111109

112110
flow_matching_with_vocoder:
113111
name: "ryota-komatsu/flow_matching_with_bigvgan"

demo.ipynb

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"cell_type": "markdown",
55
"metadata": {},
66
"source": [
7-
"# Textless Speech Resynthesis using Conditional Flow Matching and HuBERT units"
7+
"# Speech Resynthesis Using Conditional Flow Matching and Whisper Units"
88
]
99
},
1010
{
@@ -13,7 +13,19 @@
1313
"metadata": {},
1414
"outputs": [],
1515
"source": [
16-
"!pip install -r requirements.txt"
16+
"!pip install datasets==3.6.0 \\\n",
17+
" gcsfs==2025.3.0 \\\n",
18+
" nvidia-cublas-cu12==12.4.5.8 \\\n",
19+
" nvidia-cuda-cupti-cu12==12.4.127 \\\n",
20+
" nvidia-cuda-nvrtc-cu12==12.4.127 \\\n",
21+
" nvidia-cuda-runtime-cu12==12.4.127 \\\n",
22+
" nvidia-cudnn-cu12==9.1.0.70 \\\n",
23+
" nvidia-cufft-cu12==11.2.1.3 \\\n",
24+
" nvidia-curand-cu12==10.3.5.147 \\\n",
25+
" nvidia-cusolver-cu12==11.6.1.9 \\\n",
26+
" nvidia-cusparse-cu12==12.3.1.170 \\\n",
27+
" nvidia-nvjitlink-cu12==12.4.127 \\\n",
28+
" einx"
1729
]
1830
},
1931
{
@@ -111,7 +123,7 @@
111123
" padding=\"do_not_pad\",\n",
112124
").input_features.to(\"cuda\")\n",
113125
"\n",
114-
"units = encoder(input_features, out_layer=15)\n",
126+
"units = encoder.encode(input_features)\n",
115127
"units = units.unsqueeze(0) + 1 # 0: pad"
116128
]
117129
},

main_resynth.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,29 @@
22
from omegaconf import OmegaConf
33

44
from src.bigvgan.train import train_bigvgan
5-
from src.flow_matching.preprocess import resample
5+
from src.flow_matching.preprocess import extract_features, resample
66
from src.flow_matching.synthesize import synthesize
77
from src.flow_matching.train import train_flow_matching
8+
from src.flow_matching.utils.whisper import tokenize_dataset, train_tokenizer
89

910

1011
class TaskRunner:
1112
def resample(self, config: str = "configs/unit2speech/whisper-large-v3-4096-bigvgan.yaml"):
1213
config = OmegaConf.load(config)
1314
resample(config)
1415

16+
def extract_features(self, config: str = "configs/unit2speech/whisper-large-v3-4096-bigvgan.yaml"):
17+
config = OmegaConf.load(config)
18+
extract_features(config)
19+
20+
def train_tokenizer(self, config: str = "configs/unit2speech/whisper-large-v3-4096-bigvgan.yaml"):
21+
config = OmegaConf.load(config)
22+
train_tokenizer(config)
23+
24+
def tokenize_dataset(self, config: str = "configs/unit2speech/whisper-large-v3-4096-bigvgan.yaml"):
25+
config = OmegaConf.load(config)
26+
tokenize_dataset(config)
27+
1528
def train_bigvgan(self, config: str = "configs/unit2speech/whisper-large-v3-4096-bigvgan.yaml"):
1629
config = OmegaConf.load(config)
1730
train_bigvgan(config)

requirements.txt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
--extra-index-url https://download.pytorch.org/whl/cu121
22
accelerate==1.6.0
3-
backoff
43
datasets
54
einops==0.8.1
65
einx==0.3.0
@@ -11,11 +10,9 @@ jiwer @ git+https://github.com/jitsi/jiwer.git@c1b0d5e005431f5ce4fa6797f48639a8c
1110
lightning==2.5.1
1211
matplotlib==3.8.4
1312
numpy==1.22.0
14-
peft
1513
scikit-learn==1.4.2
1614
tensorboard==2.17.0
1715
torch==2.5.1+cu121
1816
torchaudio==2.5.1+cu121
19-
torchvision==0.20.1+cu121
2017
transformers==4.51.2
2118
zerospeech-benchmarks==0.9.4

src/flow_matching/configs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88
class ConditionalFlowMatchingConfig(PretrainedConfig):
99
def __init__(
1010
self,
11-
vocab_size: int = 2000,
11+
vocab_size: int = 4096,
1212
dim_in: int = 80,
1313
dim_cond_emb: int = 768,
1414
hidden_size: int = 256,
1515
depth: int = 4,
1616
heads: int = 2,
17-
intermediate_size: int = 896,
17+
intermediate_size: int = 768,
1818
attn_dropout: float = 0.0,
1919
ff_dropout: float = 0.0,
2020
use_unet_skip_connection: bool = False,

src/flow_matching/data.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,11 @@ def get_collate_fn(
109109
wav_dir: Optional[str] = None,
110110
frames_per_seg: Optional[int] = None,
111111
ext_audio: str = ".wav",
112+
predict_duration: bool = False,
112113
):
114+
if predict_duration:
115+
assert frames_per_seg is None
116+
113117
def parse_item(item: Dict[str, Any]):
114118
input_ids = item["units"] + 1 # 0: pad
115119
spectrogram_labels = item["spectrogram"]
@@ -122,6 +126,11 @@ def parse_item(item: Dict[str, Any]):
122126
wav, sr = torchaudio.load(wav_path)
123127
wav = wav.squeeze(0)
124128

129+
if predict_duration:
130+
input_ids, durations = torch.unique_consecutive(input_ids, return_counts=True)
131+
else:
132+
durations = torch.ones_like(input_ids)
133+
125134
if frames_per_seg is not None:
126135
diff = len(input_ids) - frames_per_seg
127136

@@ -130,30 +139,34 @@ def parse_item(item: Dict[str, Any]):
130139
input_ids = input_ids[start : start + frames_per_seg]
131140
spectrogram_labels = spectrogram_labels[start : start + frames_per_seg]
132141

133-
return input_ids, spectrogram_labels, transcript, id, wav
142+
return input_ids, spectrogram_labels, durations, transcript, id, wav
134143

135144
def collate_fn(batch):
136145
input_ids = []
137146
spectrogram_labels = []
147+
duration_labels = []
138148
transcripts = []
139149
names = []
140150
input_values = []
141151

142152
for item in batch:
143-
units, spectrogram, transcript, id, wav = parse_item(item)
153+
units, spectrogram, durations, transcript, id, wav = parse_item(item)
144154
input_ids.append(units)
145155
spectrogram_labels.append(spectrogram)
156+
duration_labels.append(durations)
146157
transcripts.append(transcript)
147158
names.append(id)
148159
input_values.append(wav)
149160

150161
input_ids = pad_sequence(input_ids, batch_first=True)
151162
spectrogram_labels = pad_sequence(spectrogram_labels, batch_first=True, padding_value=-100)
163+
duration_labels = pad_sequence(duration_labels, batch_first=True)
152164
input_values = pad_sequence(input_values, batch_first=True)
153165

154166
return {
155167
"input_ids": input_ids,
156168
"spectrogram_labels": spectrogram_labels,
169+
"duration_labels": duration_labels,
157170
"transcripts": transcripts,
158171
"names": names,
159172
"input_values": input_values,

src/flow_matching/preprocess.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import torchaudio
66
from tqdm import tqdm
77

8+
from ..bigvgan.data import mel_spectrogram
9+
810

911
def resample(config):
1012
wav_dir_orig = Path(config.dataset.wav_dir_orig)
@@ -27,3 +29,27 @@ def resample(config):
2729
wav_path.parent.mkdir(parents=True, exist_ok=True)
2830
wav_path = str(wav_path) # for sox backend
2931
torchaudio.save(wav_path, wav, 16000)
32+
33+
34+
def extract_features(config):
35+
wav_dir = Path(config.dataset.wav_dir)
36+
spectrogram_dir = Path(config.dataset.spectrogram_dir)
37+
wav_paths = list(wav_dir.glob("**/*" + config.dataset.ext_audio))
38+
39+
for wav_path in tqdm(wav_paths):
40+
wav_name = wav_path.relative_to(wav_dir).with_suffix("")
41+
spectrogram_path = spectrogram_dir / wav_name.with_suffix(".pt")
42+
if spectrogram_path.is_file():
43+
continue
44+
spectrogram_path.parent.mkdir(parents=True, exist_ok=True)
45+
46+
wav_path = str(wav_path)
47+
wav, sr = torchaudio.load(wav_path)
48+
wav = wav.cuda()
49+
wav = wav / wav.abs().max() * 0.95
50+
51+
spectrogram_labels = mel_spectrogram(wav) # (1, 80, len)
52+
spectrogram_labels = spectrogram_labels.transpose(1, 2) # (1, len, 80)
53+
spectrogram_labels = spectrogram_labels.cpu()
54+
55+
torch.save(spectrogram_labels, spectrogram_path)

src/flow_matching/synthesize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def synthesize(config):
3939
device="cuda",
4040
padding="do_not_pad",
4141
).input_features.to("cuda")
42-
units = encoder(input_features, out_layer=config.tokenizer.out_layer)
42+
units = encoder.encode(input_features)
4343
units = units + 1 # 0: pad
4444
input_ids.append(units)
4545

src/flow_matching/train.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,9 @@ def validate(config, dataloader, model: ConditionalFlowMatchingModel, step: int,
112112
def train_flow_matching(config):
113113
fix_random_seed(config.common.seed)
114114

115-
train_set = load_dataset(config.dataset.name, split="train").with_format("torch")
116-
dev_set = load_dataset(config.dataset.name, split="dev").with_format("torch")
115+
train_set = load_dataset(config.dataset.name, split="train", keep_in_memory=True).with_format("torch")
116+
dev_set = load_dataset(config.dataset.name, split="dev", keep_in_memory=True).with_format("torch")
117+
117118
train_loader = torch.utils.data.DataLoader(
118119
train_set,
119120
batch_size=config.flow_matching.batch_size,
@@ -122,6 +123,7 @@ def train_flow_matching(config):
122123
collate_fn=get_collate_fn(
123124
frames_per_seg=config.flow_matching.frames_per_seg,
124125
ext_audio=config.dataset.ext_audio,
126+
predict_duration=config.flow_matching.predict_duration,
125127
),
126128
)
127129
dev_loader = torch.utils.data.DataLoader(
@@ -131,6 +133,7 @@ def train_flow_matching(config):
131133
wav_dir=config.dataset.wav_dir,
132134
frames_per_seg=config.flow_matching.frames_per_seg,
133135
ext_audio=config.dataset.ext_audio,
136+
predict_duration=config.flow_matching.predict_duration,
134137
),
135138
)
136139

@@ -189,6 +192,7 @@ def train_flow_matching(config):
189192
loss = model(
190193
input_ids=batch["input_ids"].cuda(),
191194
spectrogram_labels=batch["spectrogram_labels"].cuda(),
195+
duration_labels=batch["duration_labels"].cuda(),
192196
)
193197
scaler.scale(loss).backward()
194198

0 commit comments

Comments
 (0)