Skip to content

Commit dd5d7e3

Browse files
yuekaizhangyuekaizyuekaizyuekaizJinZr
authored
F5-TTS Training Recipe for WenetSpeech4TTS (k2-fsa#1846)
* add f5 * add infer * add dit * add README * update pretrained checkpoint usage --------- Co-authored-by: yuekaiz <[email protected]> Co-authored-by: yuekaiz <[email protected]> Co-authored-by: yuekaiz <[email protected]> Co-authored-by: zr_jin <[email protected]>
1 parent 39c466e commit dd5d7e3

File tree

20 files changed

+7115
-3
lines changed

20 files changed

+7115
-3
lines changed

egs/ljspeech/TTS/matcha/fbank.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ class MatchaFbankConfig:
1717
win_length: int
1818
f_min: float
1919
f_max: float
20+
device: str = "cuda"
2021

2122

2223
@register_extractor
@@ -46,7 +47,7 @@ def extract(
4647
f"Mismatched sampling rate: extractor expects {expected_sr}, "
4748
f"got {sampling_rate}"
4849
)
49-
samples = torch.from_numpy(samples)
50+
samples = torch.from_numpy(samples).to(self.device)
5051
assert samples.ndim == 2, samples.shape
5152
assert samples.shape[0] == 1, samples.shape
5253

@@ -81,7 +82,7 @@ def extract(
8182
mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
8283
).squeeze(0)
8384

84-
return mel.numpy()
85+
return mel.cpu().numpy()
8586

8687
@property
8788
def frame_shift(self) -> Seconds:

egs/wenetspeech4tts/TTS/README.md

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,5 +68,69 @@ python3 valle/infer.py --output-dir demos_epoch_${epoch}_avg_${avg}_top_p_${top_
6868
--text-extractor pypinyin_initials_finals --top-p ${top_p}
6969
```
7070

71+
# [F5-TTS](https://arxiv.org/abs/2410.06885)
72+
73+
./f5-tts contains the code for training F5-TTS model.
74+
75+
Generated samples and training logs of wenetspeech basic 7k hours data can be found [here](https://huggingface.co/yuekai/f5-tts-small-wenetspeech4tts-basic/tensorboard).
76+
77+
Preparation:
78+
79+
```
80+
bash prepare.sh --stage 5 --stop_stage 6
81+
```
82+
(Note: To compatiable with F5-TTS official checkpoint, we direclty use `vocab.txt` from [here.](https://github.com/SWivid/F5-TTS/blob/129014c5b43f135b0100d49a0c6804dd4cf673e1/data/Emilia_ZH_EN_pinyin/vocab.txt) To generate your own `vocab.txt`, you may refer to [the script](https://github.com/SWivid/F5-TTS/blob/main/src/f5_tts/train/datasets/prepare_emilia.py).)
83+
84+
The training command is given below:
85+
86+
```
87+
# docker: ghcr.io/swivid/f5-tts:main
88+
# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html
89+
# pip install kaldialign lhotse tensorboard bigvganinference sentencepiece
90+
91+
world_size=8
92+
exp_dir=exp/f5-tts-small
93+
python3 f5-tts/train.py --max-duration 700 --filter-min-duration 0.5 --filter-max-duration 20 \
94+
--num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 10000 \
95+
--base-lr 7.5e-5 --warmup-steps 20000 --num-epochs 60 \
96+
--num-decoder-layers 18 --nhead 12 --decoder-dim 768 \
97+
--exp-dir ${exp_dir} --world-size ${world_size}
98+
```
99+
100+
To inference with Icefall Wenetspeech4TTS trained F5-Small, use:
101+
```
102+
huggingface-cli login
103+
huggingface-cli download --local-dir seed_tts_eval yuekai/seed_tts_eval --repo-type dataset
104+
huggingface-cli download --local-dir ${exp_dir} yuekai/f5-tts-small-wenetspeech4tts-basic
105+
huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x
106+
107+
manifest=./seed_tts_eval/seedtts_testset/zh/meta.lst
108+
model_path=f5-tts-small-wenetspeech4tts-basic/epoch-56-avg-14.pt
109+
# skip
110+
python3 f5-tts/generate_averaged_model.py \
111+
--epoch 56 \
112+
--avg 14 --decoder-dim 768 --nhead 12 --num-decoder-layers 18 \
113+
--exp-dir exp/f5_small
114+
115+
116+
accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18
117+
bash local/compute_wer.sh $output_dir $manifest
118+
```
119+
120+
To inference with official Emilia trained F5-Base, use:
121+
```
122+
huggingface-cli login
123+
huggingface-cli download --local-dir seed_tts_eval yuekai/seed_tts_eval --repo-type dataset
124+
huggingface-cli download --local-dir F5-TTS SWivid/F5-TTS
125+
huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x
126+
127+
manifest=./seed_tts_eval/seedtts_testset/zh/meta.lst
128+
model_path=./F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt
129+
130+
accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir
131+
bash local/compute_wer.sh $output_dir $manifest
132+
```
133+
71134
# Credits
72-
- [vall-e](https://github.com/lifeiteng/vall-e)
135+
- [VALL-E](https://github.com/lifeiteng/vall-e)
136+
- [F5-TTS](https://github.com/SWivid/F5-TTS)
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
#!/usr/bin/env python3
2+
#
3+
# Copyright 2021-2022 Xiaomi Corporation (Author: Yifan Yang)
4+
# Copyright 2024 Yuekai Zhang
5+
#
6+
# See ../../../../LICENSE for clarification regarding multiple authors
7+
#
8+
# Licensed under the Apache License, Version 2.0 (the "License");
9+
# you may not use this file except in compliance with the License.
10+
# You may obtain a copy of the License at
11+
#
12+
# http://www.apache.org/licenses/LICENSE-2.0
13+
#
14+
# Unless required by applicable law or agreed to in writing, software
15+
# distributed under the License is distributed on an "AS IS" BASIS,
16+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
# See the License for the specific language governing permissions and
18+
# limitations under the License.
19+
"""
20+
Usage:
21+
(1) use the checkpoint exp_dir/epoch-xxx.pt
22+
python3 bin/generate_averaged_model.py \
23+
--epoch 40 \
24+
--avg 5 \
25+
--exp-dir ${exp_dir}
26+
27+
It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`.
28+
You can later load it by `torch.load("epoch-28-avg-15.pt")`.
29+
"""
30+
31+
32+
import argparse
33+
from pathlib import Path
34+
35+
import k2
36+
import torch
37+
from train import add_model_arguments, get_model
38+
39+
from icefall.checkpoint import (
40+
average_checkpoints,
41+
average_checkpoints_with_averaged_model,
42+
find_checkpoints,
43+
)
44+
from icefall.utils import AttributeDict
45+
46+
47+
def get_parser():
48+
parser = argparse.ArgumentParser(
49+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
50+
)
51+
52+
parser.add_argument(
53+
"--epoch",
54+
type=int,
55+
default=30,
56+
help="""It specifies the checkpoint to use for decoding.
57+
Note: Epoch counts from 1.
58+
You can specify --avg to use more checkpoints for model averaging.""",
59+
)
60+
61+
parser.add_argument(
62+
"--iter",
63+
type=int,
64+
default=0,
65+
help="""If positive, --epoch is ignored and it
66+
will use the checkpoint exp_dir/checkpoint-iter.pt.
67+
You can specify --avg to use more checkpoints for model averaging.
68+
""",
69+
)
70+
71+
parser.add_argument(
72+
"--avg",
73+
type=int,
74+
default=9,
75+
help="Number of checkpoints to average. Automatically select "
76+
"consecutive checkpoints before the checkpoint specified by "
77+
"'--epoch' and '--iter'",
78+
)
79+
80+
parser.add_argument(
81+
"--exp-dir",
82+
type=str,
83+
default="zipformer/exp",
84+
help="The experiment dir",
85+
)
86+
add_model_arguments(parser)
87+
return parser
88+
89+
90+
@torch.no_grad()
91+
def main():
92+
parser = get_parser()
93+
94+
args = parser.parse_args()
95+
args.exp_dir = Path(args.exp_dir)
96+
97+
params = AttributeDict()
98+
params.update(vars(args))
99+
100+
if params.iter > 0:
101+
params.suffix = f"checkpoint-{params.iter}-avg-{params.avg}"
102+
else:
103+
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
104+
105+
print("Script started")
106+
107+
device = torch.device("cpu")
108+
print(f"Device: {device}")
109+
110+
print("About to create model")
111+
filename = f"{params.exp_dir}/epoch-{params.epoch}.pt"
112+
checkpoint = torch.load(filename, map_location=device)
113+
args = AttributeDict(checkpoint)
114+
model = get_model(args)
115+
116+
if params.iter > 0:
117+
# TODO FIX ME
118+
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
119+
: params.avg + 1
120+
]
121+
if len(filenames) == 0:
122+
raise ValueError(
123+
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
124+
)
125+
elif len(filenames) < params.avg + 1:
126+
raise ValueError(
127+
f"Not enough checkpoints ({len(filenames)}) found for"
128+
f" --iter {params.iter}, --avg {params.avg}"
129+
)
130+
filename_start = filenames[-1]
131+
filename_end = filenames[0]
132+
print(
133+
"Calculating the averaged model over iteration checkpoints"
134+
f" from {filename_start} (excluded) to {filename_end}"
135+
)
136+
model.to(device)
137+
model.load_state_dict(
138+
average_checkpoints_with_averaged_model(
139+
filename_start=filename_start,
140+
filename_end=filename_end,
141+
device=device,
142+
)
143+
)
144+
filename = params.exp_dir / f"checkpoint-{params.iter}-avg-{params.avg}.pt"
145+
torch.save({"model": model.state_dict()}, filename)
146+
else:
147+
assert params.avg > 0, params.avg
148+
start = params.epoch - params.avg
149+
assert start >= 1, start
150+
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
151+
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
152+
print(
153+
f"Calculating the averaged model over epoch range from "
154+
f"{start} (excluded) to {params.epoch}"
155+
)
156+
filenames = [
157+
f"{params.exp_dir}/epoch-{i}.pt" for i in range(start, params.epoch + 1)
158+
]
159+
model.to(device)
160+
model.load_state_dict(average_checkpoints(filenames, device=device))
161+
162+
filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt"
163+
checkpoint["model"] = model.state_dict()
164+
torch.save(checkpoint, filename)
165+
166+
num_param = sum([p.numel() for p in model.parameters()])
167+
print(f"Number of model parameters: {num_param}")
168+
169+
print("Done!")
170+
171+
172+
if __name__ == "__main__":
173+
main()

0 commit comments

Comments
 (0)