[Not for merge] Add Emilia Training Recipe for Llasa (cosyvoice2 token) #1887

76 changes: 76 additions & 0 deletions egs/emilia/TTS/README.md
@@ -0,0 +1,76 @@
# Results
| LLM Model | Flow matching Model | Seed-TTS test_zh CER | Comment |
|---------------------------------------|----------|-----------|--------|
| pretrained cosyvoice2 llm | pretrained cosyvoice2 unet | 1.45% | See [paper](https://arxiv.org/abs/2412.10117)|
| pretrained cosyvoice2 llm | f5-tts-small (wenetspeech4tts) | 1.79% (16 steps) | See [PR](https://github.com/k2-fsa/icefall/pull/1880)|
| llasa_cosyvoice2_token llm (Emilia 50k hours ZH) | f5-tts-small (wenetspeech4tts) | 1.81% (16 steps) | |

# Introduction

[**Emilia**](https://huggingface.co/datasets/amphion/Emilia-Dataset) contains over 101k
hours of speech across six languages, covering a wide range of speaking styles to enable more natural and spontaneous speech generation.

See https://arxiv.org/pdf/2407.05361.

# Llasa (cosyvoice2 token)

The `./llasa_cosyvoice2_token` directory contains the code for training Qwen2.5-0.5B models to predict CosyVoice2 semantic tokens.
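
Concretely, each training example is a chat-formatted sequence: the transcript goes in the user turn and the CosyVoice2 codes fill the assistant turn. The sketch below mirrors the collator in `llasa_cosyvoice2_token/train.py`; the transcript and codes are hypothetical, and the real collator supplies its own minimal chat template and label masking:

```
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
original_vocab_size = len(tokenizer)
# 6561 CosyVoice2 codebook tokens plus a placeholder token.
tokenizer.add_tokens(
    [f"<|s_{i}|>" for i in range(6561)] + ["<|SPEECH_GENERATION_START|>"]
)

text = "你好"            # hypothetical transcript
codes = [12, 345, 6000]  # hypothetical CosyVoice2 semantic tokens

message = [
    {"role": "user", "content": f"Convert the text to speech: {text}"},
    {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>"},
]
input_ids = tokenizer.apply_chat_template(message, tokenize=True)

# Swap the placeholder for the speech codes, offset past the original text
# vocabulary so they land on the newly added <|s_i|> token ids.
idx = input_ids.index(tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_START|>"))
input_ids = input_ids[:idx] + [c + original_vocab_size for c in codes] + input_ids[idx + 1 :]
```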

Generated samples and training logs of [Emilia](https://huggingface.co/datasets/amphion/Emilia-Dataset) 50k hours Chinese data can be found [here](https://huggingface.co/yuekai/llasa_cosyvoice2_token_qwen_0.5b/tree/main).

Preparation:

```
# extract cosyvoice2 semantic tokens
bash prepare.sh --stage 3 --stop_stage 4

# Or you could use the prepared tokens.
huggingface-cli download yuekai/emilia_cosyvoice_v2_token --local-dir emilia_cosyvoice_v2_token
```
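
Under the hood, the extraction stage relies on [S3Tokenizer](https://github.com/xingchensong/S3Tokenizer/tree/main). As a rough standalone sketch (the wav path is a placeholder; check prepare.sh for the exact model variant the recipe uses):

```
import s3tokenizer

# Load the 25 Hz CosyVoice2 speech tokenizer.
tokenizer = s3tokenizer.load_model("speech_tokenizer_v2_25hz")

# Compute log-mel features for one utterance, pad to a batch, and quantize.
audio = s3tokenizer.load_audio("example.wav")  # placeholder path
mels, mels_lens = s3tokenizer.padding([s3tokenizer.log_mel_spectrogram(audio)])
codes, codes_lens = tokenizer.quantize(mels, mels_lens)
print(codes[0, : codes_lens[0].item()].tolist())
```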

The training command is given below:

```
# docker: ghcr.io/swivid/f5-tts:main
# pip install -r llasa_cosyvoice2_token/requirements.txt

WANDB_KEY=$your_wandb_key
wandb login ${WANDB_KEY}
huggingface-cli download Qwen/Qwen2.5-0.5B-Instruct --local-dir Qwen2.5-0.5B-Instruct
torchrun --nproc_per_node=8 train.py config.json
```
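
With the settings in `config.json`, one optimizer step consumes 8 GPUs × `per_device_train_batch_size` 8 × `gradient_accumulation_steps` 1 = 64 sequences.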

To run inference with the Emilia-trained Chinese llasa_cosyvoice2_token model in icefall, we need the CosyVoice2-token flow-matching [model](https://github.com/k2-fsa/icefall/pull/1880):
```
cd icefall/egs/wenetspeech4tts/TTS
huggingface-cli login
huggingface-cli download --local-dir ${exp_dir} yuekai/llasa_cosyvoice2_token_qwen_0.5b
huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x
vocoder=./bigvgan_v2_24khz_100band_256x
split=test_zh
llm_path=llasa_cosyvoice2_token_qwen_0.5b/checkpoint-800000
output_dir=./results_${split}  # any writable directory

huggingface-cli download --local-dir f5-tts-small-wenetspeech4tts-basic yuekai/f5-tts-semantic-token-small-wenetspeech4tts-basic
model_path=f5-tts-small-wenetspeech4tts-basic/epoch-10-avg-5.pt
torchrun --nproc_per_node=2 \
f5-tts/infer_dist.py \
--output_dir $output_dir \
--batch_size 1 \
--num_workers 2 \
--llm-model-name-or-path $llm_path \
--flow-matching-model-path $model_path \
--decoder-dim 768 --nhead 12 --num-decoder-layers 18 \
--use-cosyvoice-semantic-token True \
--vocoder-dir $vocoder \
--split-name $split --top-k 50 --top-p 0.95 --temperature 0.8 \
--tokenizer-dir Qwen/Qwen2.5-0.5B-Instruct
# compute CER
huggingface-cli download yuekai/seed_tts_eval --local-dir seed_tts_eval --repo-type dataset
manifest=./seed_tts_eval/seedtts_testset/zh/meta.lst
bash local/compute_wer.sh $output_dir $manifest
```

# Credits
- [Llasa](https://arxiv.org/abs/2502.04128)
- [CosyVoice](https://github.com/FunAudioLLM/CosyVoice)
- [S3Tokenizer](https://github.com/xingchensong/S3Tokenizer/tree/main)
27 changes: 27 additions & 0 deletions egs/emilia/TTS/llasa_cosyvoice2_token/config.json
@@ -0,0 +1,27 @@
{
  "llm_model_name_or_path": "./Qwen2.5-0.5B-Instruct",
  "data_path": ["./emilia_cosyvoice_v2_token/cosy_v2_tokens_ZH.jsonl"],
  "bf16": false,
  "output_dir": "./exp_zh",
  "num_train_epochs": 3,
  "per_device_train_batch_size": 8,
  "per_device_eval_batch_size": 8,
  "gradient_accumulation_steps": 1,
  "evaluation_strategy": "steps",
  "eval_steps": 1000,
  "save_strategy": "steps",
  "save_steps": 5000,
  "save_total_limit": 100,
  "learning_rate": 0.00005,
  "weight_decay": 0.01,
  "adam_beta2": 0.95,
  "warmup_ratio": 0.03,
  "lr_scheduler_type": "cosine",
  "logging_steps": 100,
  "report_to": "wandb",
  "model_max_length": 2048,
  "gradient_checkpointing": false,
  "dataloader_num_workers": 4,
  "dataloader_prefetch_factor": 4,
  "deepspeed": "ds_config_zero2.json"
}
47 changes: 47 additions & 0 deletions egs/emilia/TTS/llasa_cosyvoice2_token/ds_config_zero2.json
@@ -0,0 +1,47 @@
{
  "fp16": {
    "enabled": "auto",
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "initial_scale_power": 64,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "bf16": {
    "enabled": "auto"
  },
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": "auto",
      "betas": "auto",
      "eps": "auto",
      "weight_decay": "auto"
    }
  },

  "scheduler": {
    "type": "WarmupCosineLR",
    "params": {
      "total_num_steps": "auto",
      "warmup_min_ratio": 0.03,
      "warmup_num_steps": "auto",
      "cos_min_ratio": 0.1
    }
  },

  "zero_optimization": {
    "stage": 2,
    "overlap_comm": false,
    "contiguous_gradients": true,
    "sub_group_size": 1e9,
    "reduce_bucket_size": "auto"
  },

  "gradient_accumulation_steps": "auto",
  "gradient_clipping": 1.0,
  "steps_per_print": 100,
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "wall_clock_breakdown": false
}
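
The `"auto"` entries above are placeholders: when training is launched through the HuggingFace Trainer, its DeepSpeed integration fills them in from the matching fields in `config.json` (learning rate, betas, batch sizes, warmup steps, and so on), so the two files stay consistent.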
8 changes: 8 additions & 0 deletions egs/emilia/TTS/llasa_cosyvoice2_token/requirements.txt
@@ -0,0 +1,8 @@
torch
transformers
wandb
datasets
accelerate>=0.26.0
deepspeed
flash-attn
s3tokenizer
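
Note: `flash-attn` needs a build matching your local CUDA/PyTorch; it is required because `train.py` loads the model with `attn_implementation="flash_attention_2"`.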
184 changes: 184 additions & 0 deletions egs/emilia/TTS/llasa_cosyvoice2_token/train.py
@@ -0,0 +1,184 @@
# Modified from https://github.com/zhenye234/LLaSA_training/blob/main/train_tts.py
""" Example Usage
WANDB_KEY=$your_wandb_key
wandb login ${WANDB_KEY}
huggingface-cli download yuekai/emilia_cosyvoice_v2_token --local-dir emilia_cosyvoice_v2_token
huggingface-cli download Qwen/Qwen2.5-0.5B-Instruct --local-dir Qwen2.5-0.5B-Instruct
torchrun --nproc_per_node=8 train.py config.json
"""
import json
import os
import random
import sys
from dataclasses import dataclass, field
from functools import partial
from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn
import transformers
import wandb
from datasets import load_dataset
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
DataCollatorForSeq2Seq,
HfArgumentParser,
Trainer,
TrainingArguments,
)
from transformers.trainer_pt_utils import LabelSmoother

IGNORE_TOKEN_ID = LabelSmoother.ignore_index
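# Minimal ChatML-style template without a default system prompt; the
# <|SPEECH_GENERATION_START|> placeholder in the assistant turn is replaced
# by the utterance's speech token ids in the data collator below.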
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"


@dataclass
class ModelArguments:
    llm_model_name_or_path: Optional[str] = field(
        default="meta-llama/Llama-3.2-1B-Instruct"
    )


@dataclass
class DataArguments:
    data_path: List[str] = field(
        default=None,
        metadata={"help": "Root path(s) to the data. Can be a single path or a list."},
    )


@dataclass
class CustomTrainingArguments(TrainingArguments):
    optim: str = field(default="adamw_torch_fused")
    model_max_length: int = field(
        default=2048,
        metadata={"help": "Maximum sequence length"},
    )
    logging_steps: int = field(default=100, metadata={"help": "Log every X updates"})
    report_to: Optional[str] = field(
        default=None,
        metadata={"help": "The integration to report the results and logs to."},
    )
    run_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the run for logging."}
    )
    gradient_checkpointing: bool = field(default=False)
    lr_scheduler_type: str = field(
        default="cosine", metadata={"help": "The learning rate scheduler to use."}
    )
    remove_unused_columns: bool = field(default=False)


def data_collator(batch, tokenizer, original_tokenizer_vocab_size, cut_off_len=2048):
    speech_generation_start_index = tokenizer.convert_tokens_to_ids(
        "<|SPEECH_GENERATION_START|>"
    )
    assistant_index = tokenizer.convert_tokens_to_ids("assistant")
    input_ids_list = []
    for item in batch:
        text, code = item["text"], item["code"]
        message = [
            {"role": "user", "content": f"Convert the text to speech: {text}"},
            {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>"},
        ]

        input_ids = tokenizer.apply_chat_template(
            message,
            tokenize=True,
            chat_template=TEMPLATE,
        )

        # Offset the speech codes past the original text vocabulary so they
        # land on the newly added <|s_i|> token ids.
        code = [c + original_tokenizer_vocab_size for c in code]

        # Splice the codes in place of the <|SPEECH_GENERATION_START|> placeholder.
        idx = input_ids.index(speech_generation_start_index)
        input_ids = input_ids[:idx] + code + input_ids[idx + 1 :]
        # Sequences longer than cut_off_len are dropped from the batch.
        if len(input_ids) < cut_off_len:
            input_ids_list.append(input_ids)

    max_len = max([len(input_ids) for input_ids in input_ids_list])
    input_ids_list = [
        input_ids + [tokenizer.pad_token_id] * (max_len - len(input_ids))
        for input_ids in input_ids_list
    ]
    input_ids = torch.tensor(input_ids_list, dtype=torch.int)
    attention_mask = input_ids.ne(tokenizer.pad_token_id)

    # Train only on the assistant's speech tokens: padding and everything up
    # to (and including) the 'assistant\n' prefix is ignored by the loss.
    target_ids = input_ids.clone()
    target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
    mask_indices = torch.where(input_ids == assistant_index)
    for i in range(mask_indices[0].size(0)):
        row = mask_indices[0][i]
        col = mask_indices[1][i]
        # + 2 to skip: 'assistant', '\n'
        target_ids[row, : col + 2] = IGNORE_TOKEN_ID
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": target_ids.to(dtype=torch.int64),
    }


def main():
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, CustomTrainingArguments)
    )
    assert len(sys.argv) == 2 and sys.argv[1].endswith(".json")
    (
        model_args,
        data_args,
        training_args,
    ) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))

    is_main_process = training_args.local_rank in [-1, 0]
    if training_args.report_to == "wandb" and is_main_process:
        wandb.init(
            project="llm_tts",
            config=training_args.to_sanitized_dict(),
            name=training_args.run_name,
        )

    model = AutoModelForCausalLM.from_pretrained(
        model_args.llm_model_name_or_path,
        torch_dtype=torch.float16,
        attn_implementation="flash_attention_2",
    )

    # Extend the tokenizer with the 6561 CosyVoice2 codebook tokens plus the
    # <|SPEECH_GENERATION_START|> placeholder, then resize the embeddings.
    tokenizer = AutoTokenizer.from_pretrained(model_args.llm_model_name_or_path)
    original_tokenizer_vocab_size = len(tokenizer)
    cosyvoice2_token_size = 6561
    new_tokens = [f"<|s_{i}|>" for i in range(cosyvoice2_token_size)] + [
        "<|SPEECH_GENERATION_START|>"
    ]
    num_added_tokens = tokenizer.add_tokens(new_tokens)

    model.resize_token_embeddings(len(tokenizer))
    model.vocab_size = len(tokenizer)

    # Hold out 100 random examples for periodic evaluation.
    dataset = load_dataset("json", data_files=data_args.data_path)
    dataset = dataset["train"]
    train_test_split = dataset.train_test_split(test_size=100, seed=42)
    train_dataset, eval_dataset = train_test_split["train"], train_test_split["test"]

    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=lambda features: data_collator(
            features, tokenizer, original_tokenizer_vocab_size
        ),
    )

    if is_main_process:
        trainer.add_callback(transformers.integrations.WandbCallback())

    trainer.train(resume_from_checkpoint=None)
    trainer.save_model(training_args.output_dir)


if __name__ == "__main__":
    main()