From 540430d2133fb771f7c7f7df1463a03297b29b84 Mon Sep 17 00:00:00 2001 From: yuekaiz Date: Fri, 28 Feb 2025 10:01:57 +0800 Subject: [PATCH 1/8] add token extraction --- .../TTS/local/extract_cosyvoice2_token.py | 205 ++++++++++++++++++ egs/emilia/TTS/prepare.sh | 178 +++++++++++++++ egs/emilia/TTS/shared | 1 + 3 files changed, 384 insertions(+) create mode 100644 egs/emilia/TTS/local/extract_cosyvoice2_token.py create mode 100755 egs/emilia/TTS/prepare.sh create mode 120000 egs/emilia/TTS/shared diff --git a/egs/emilia/TTS/local/extract_cosyvoice2_token.py b/egs/emilia/TTS/local/extract_cosyvoice2_token.py new file mode 100644 index 0000000000..2c1ccda766 --- /dev/null +++ b/egs/emilia/TTS/local/extract_cosyvoice2_token.py @@ -0,0 +1,205 @@ +# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Example Usage +cpu: + +s3tokenizer --data_dir xxx.scp \ + --device "cpu" \ + --output_dir "./" \ + --batch_size 32 + +gpu: + +torchrun --nproc_per_node=8 --nnodes=1 \ + --rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \ + `which s3tokenizer` --data_dir xxx.scp \ + --device "cuda" \ + --output_dir "./" \ + --batch_size 32 + +""" + +import argparse +import json +import os +from pathlib import Path + +import s3tokenizer +import torch +import torch.distributed as dist +from lhotse.serialization import load_jsonl +from torch.utils.data import DataLoader, Dataset, DistributedSampler +from tqdm import tqdm + + +class AudioDataset(Dataset): + def __init__(self, data_dir, jsonl_file): + self.data = [] + # convert data_dir to Path object + self.data_dir = Path(data_dir) + # jsonl_files = self.data_dir.glob("*.jsonl") + jsonl_files = [self.data_dir / jsonl_file] + for jsonl_file in jsonl_files: + for item in tqdm( + # Note: People's Speech manifest.json is really a JSONL. 
+ load_jsonl(jsonl_file), + desc=f"Processing {jsonl_file}", + ): + self.data.append(item) + break + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + file_path = self.data_dir / self.data[idx]["wav"] + audio = s3tokenizer.load_audio(file_path) + if audio.shape[0] / 16000 > 30: + print( + f"do not support extract speech token for audio longer than 30s, file_path: {file_path}" # noqa + ) + mel = torch.zeros(128, 0) + else: + mel = s3tokenizer.log_mel_spectrogram(audio) + return self.data[idx], mel + + +def collate_fn(batch): + keys = [item[0] for item in batch] + mels = [item[1] for item in batch] + mels, mels_lens = s3tokenizer.padding(mels) + return keys, mels, mels_lens + + +def init_distributed(): + world_size = int(os.environ.get("WORLD_SIZE", 1)) + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + rank = int(os.environ.get("RANK", 0)) + print( + "Inference on multiple gpus, this gpu {}".format(local_rank) + + ", rank {}, world_size {}".format(rank, world_size) + ) + torch.cuda.set_device(local_rank) + dist.init_process_group("nccl") + return world_size, local_rank, rank + + +def get_args(): + parser = argparse.ArgumentParser(description="extract speech code") + parser.add_argument( + "--model", + required=True, + type=str, + choices=[ + "speech_tokenizer_v1", + "speech_tokenizer_v1_25hz", + "speech_tokenizer_v2_25hz", + ], + help="model version", + ) + parser.add_argument( + "--data_dir", + required=True, + type=str, + help="each line contains `wav_name wav_path`", + ) + parser.add_argument( + "--jsonl_file", + required=True, + type=str, + help="each line contains `wav_name wav_path`", + ) + parser.add_argument( + "--device", + required=True, + type=str, + choices=["cuda", "cpu"], + help="device for inference", + ) + parser.add_argument( + "--output_dir", required=True, type=str, help="dir to save result" + ) + parser.add_argument( + "--batch_size", + required=True, + type=int, + help="batch size (per-device) for inference", + ) + parser.add_argument( + "--num_workers", type=int, default=4, help="workers for dataloader" + ) + parser.add_argument( + "--prefetch", type=int, default=5, help="prefetch for dataloader" + ) + args = parser.parse_args() + return args + + +def main(): + args = get_args() + os.makedirs(args.output_dir, exist_ok=True) + + if args.device == "cuda": + assert torch.cuda.is_available() + world_size, local_rank, rank = init_distributed() + else: + world_size, local_rank, rank = 1, 0, 0 + + device = torch.device(args.device) + model = s3tokenizer.load_model(args.model).to(device) + dataset = AudioDataset(args.data_dir, args.jsonl_file) + + if args.device == "cuda": + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[local_rank] + ) + sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank) + else: + sampler = None + + dataloader = DataLoader( + dataset, + batch_size=args.batch_size, + sampler=sampler, + shuffle=False, + num_workers=args.num_workers, + prefetch_factor=args.prefetch, + collate_fn=collate_fn, + ) + + total_steps = len(dataset) + + if rank == 0: + progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs") + + writer = open(f"{args.output_dir}/part_{rank + 1}_of_{world_size}", "w") + for keys, mels, mels_lens in dataloader: + codes, codes_lens = model(mels.to(device), mels_lens.to(device)) + for i, k in enumerate(keys): + code = codes[i, : codes_lens[i].item()].tolist() + k["code"] = code + writer.write(json.dumps(k, ensure_ascii=False) + "\n") + if rank == 0: + 
progress_bar.update(world_size * len(keys)) + + if rank == 0: + progress_bar.close() + writer.close() + if args.device == "cuda": + dist.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/egs/emilia/TTS/prepare.sh b/egs/emilia/TTS/prepare.sh new file mode 100755 index 0000000000..4a0d2df0b7 --- /dev/null +++ b/egs/emilia/TTS/prepare.sh @@ -0,0 +1,178 @@ +#!/usr/bin/env bash + +set -eou pipefail + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python +# pip install lhotse s3tokenizer +stage=6 +stop_stage=6 + +dl_dir=$PWD/download +dl_dir=/workspace_data/Emilia-Dataset/ +prefix="emilia" +# zh, en, ja, ko, de, fr +lang_set=("de" "en" "zh" "ja" "ko" "fr") +lang_set=("de" "en" "zh" "ja" "fr") +. shared/parse_options.sh || exit 1 + + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "dl_dir: $dl_dir" + log "Stage 0: Download data" + #huggingface-cli login + # huggingface-cli download --repo-type dataset --local-dir $dl_dir Wenetspeech4TTS/WenetSpeech4TTS + + # Extract the downloaded data: + for lang in "${lang_set[@]}"; do + lang_upper=$(echo "${lang}" | tr '[:lower:]' '[:upper:]') + folder=$dl_dir/raw/${lang_upper} + for file in $folder/*.tar.gz; do + echo "Processing ${file}" + # e.g. $dl_dir/raw/DE/*tar.gz untar first, DE is the language code in upper case + tar -xzvf $file -C $folder + done + done +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare emilia manifest" + # We assume that you have downloaded the Emilia corpus + # to $dl_dir/emilia + mkdir -p data/manifests + for lang in "${lang_set[@]}"; do + echo "Processing ${lang}" + if [ ! -e data/manifests/.emilia.${lang}.done ]; then + lhotse prepare emilia $dl_dir data/manifests --num-jobs 30 --lang "${lang}" + touch data/manifests/.emilia.${lang}.done + fi + done +fi + + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Generate fbank (used by ./f5-tts)" + mkdir -p data/fbank + for lang in "${lang_set[@]}"; do + echo "Processing ${lang}" + if [ ! -e data/fbank/.emilia.${lang}.done ]; then + ./local/compute_mel_feat.py --dataset-parts $lang --split 100 --prefix ${prefix} + touch data/fbank/.emilia.${lang}.done + fi + done +fi + +if [ $stage -le 16 ] && [ $stop_stage -ge 16 ]; then + log "Stage 6: Split the ${prefix} cuts into train, valid and test sets (used by ./f5-tts)" + if [ ! -f data/fbank/${prefix}_cuts_${subset}.jsonl.gz ]; then + echo "Combining ${prefix} cuts" + pieces=$(find data/fbank/ -name "${prefix}_cuts_${subset}.*.jsonl.gz") + lhotse combine $pieces data/fbank/${prefix}_cuts_${subset}.jsonl.gz + fi + if [ ! 
-e data/fbank/.${prefix}_split.done ]; then + echo "Splitting ${prefix} cuts into train, valid and test sets" + + lhotse subset --last 800 \ + data/fbank/${prefix}_cuts_${subset}.jsonl.gz \ + data/fbank/${prefix}_cuts_validtest.jsonl.gz + lhotse subset --first 400 \ + data/fbank/${prefix}_cuts_validtest.jsonl.gz \ + data/fbank/${prefix}_cuts_valid.jsonl.gz + lhotse subset --last 400 \ + data/fbank/${prefix}_cuts_validtest.jsonl.gz \ + data/fbank/${prefix}_cuts_test.jsonl.gz + + rm data/fbank/${prefix}_cuts_validtest.jsonl.gz + + n=$(( $(gunzip -c data/fbank/${prefix}_cuts_${subset}.jsonl.gz | wc -l) - 800 )) + lhotse subset --first $n \ + data/fbank/${prefix}_cuts_${subset}.jsonl.gz \ + data/fbank/${prefix}_cuts_train.jsonl.gz + touch data/fbank/.${prefix}_split.done + fi +fi + +# zcat test.jsonl.gz | jq -r '.recording.id + " " + .recording.sources[0].source' > wav.scp +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Extract cosyvoice2 FSQ token (used by ./f5-tts semantic token experiment)" + data_dir=$dl_dir/raw/ZH + # for all jsonl files in data_dir + for jsonl_file in $data_dir/*.jsonl; do + # get the file basename + jsonl_file_basename=$(basename $jsonl_file) + echo "Processing $jsonl_file" + output_dir="./cosy_v2_tokens_ZH/${jsonl_file_basename%.jsonl}" + echo "output_dir: $output_dir" + # skip if the output_dir exists + if [ -e $output_dir ]; then + echo "Output directory $output_dir already exists, skipping" + continue + fi + mkdir -p $output_dir + torchrun --nproc_per_node=8 --nnodes=1 \ + --rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \ + local/extract_cosyvoice2_token.py --data_dir $data_dir \ + --jsonl_file $jsonl_file_basename \ + --device "cuda" \ + --output_dir $output_dir \ + --batch_size 32 \ + --num_workers 2 \ + --model "speech_tokenizer_v2_25hz" # or "speech_tokenizer_v1_25hz + done +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Extract cosyvoice2 FSQ token (used by ./f5-tts semantic token experiment)" + for lang in "${lang_set[@]}"; do + lang_upper=$(echo "${lang}" | tr '[:lower:]' '[:upper:]') + data_dir=$dl_dir/raw/${lang_upper} + # for all jsonl files in data_dir + for jsonl_file in $data_dir/*.jsonl; do + # get the file basename + jsonl_file_basename=$(basename $jsonl_file) + echo "Processing $jsonl_file" + output_dir="./cosy_v2_tokens_${lang_upper}/${jsonl_file_basename%.jsonl}" + echo "output_dir: $output_dir" + # skip if the output_dir exists + if [ -e $output_dir ]; then + echo "Output directory $output_dir already exists, skipping" + continue + fi + mkdir -p $output_dir + torchrun --nproc_per_node=8 --nnodes=1 \ + --rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \ + local/extract_cosyvoice2_token.py --data_dir $data_dir \ + --jsonl_file $jsonl_file_basename \ + --device "cuda" \ + --output_dir $output_dir \ + --batch_size 32 \ + --num_workers 2 \ + --model "speech_tokenizer_v2_25hz" # or "speech_tokenizer_v1_25hz + done + done +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then +# cat EN_B00008.tar.gz.* > EN_B00008.tar.gz + for lang in "${lang_set[@]}"; do + lang_upper=$(echo "${lang}" | tr '[:lower:]' '[:upper:]') + cosy_token_dir="./cosy_v2_tokens_${lang_upper}" + for dir in $cosy_token_dir/*; do + echo "Processing $dir" + # get the file basename + dir_basename=$(basename $dir) + echo "dir_basename: $dir_basename" + cat $dir/part* > $dir/${dir_basename}.jsonl + done + cat $cosy_token_dir/${lang_upper}*/*.jsonl > $cosy_token_dir/cosy_v2_tokens_${lang_upper}.jsonl + done 
+fi diff --git a/egs/emilia/TTS/shared b/egs/emilia/TTS/shared new file mode 120000 index 0000000000..4c5e91438c --- /dev/null +++ b/egs/emilia/TTS/shared @@ -0,0 +1 @@ +../../../icefall/shared/ \ No newline at end of file From fa6587010e299c4d9b501beeb79044469e260cb7 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 28 Feb 2025 02:08:05 +0000 Subject: [PATCH 2/8] add training codes --- .../TTS/llasa_cosyvoice2_token/config.json | 27 +++ .../ds_config_zero2.json | 47 +++++ .../llasa_cosyvoice2_token/requirements.txt | 7 + egs/emilia/TTS/llasa_cosyvoice2_token/run.sh | 4 + .../TTS/llasa_cosyvoice2_token/train.py | 171 ++++++++++++++++++ 5 files changed, 256 insertions(+) create mode 100644 egs/emilia/TTS/llasa_cosyvoice2_token/config.json create mode 100644 egs/emilia/TTS/llasa_cosyvoice2_token/ds_config_zero2.json create mode 100644 egs/emilia/TTS/llasa_cosyvoice2_token/requirements.txt create mode 100644 egs/emilia/TTS/llasa_cosyvoice2_token/run.sh create mode 100644 egs/emilia/TTS/llasa_cosyvoice2_token/train.py diff --git a/egs/emilia/TTS/llasa_cosyvoice2_token/config.json b/egs/emilia/TTS/llasa_cosyvoice2_token/config.json new file mode 100644 index 0000000000..06aeb51f1d --- /dev/null +++ b/egs/emilia/TTS/llasa_cosyvoice2_token/config.json @@ -0,0 +1,27 @@ +{ + "llm_model_name_or_path": "/workspace/slam/icefall_omni/egs/speech_llm/SPEECH2SPEECH/models/Qwen2.5-0.5B-Instruct", + "data_path": ["../emilia_cosyvoice_v2_token/cosy_v2_tokens_ZH.jsonl"], + "bf16": false, + "output_dir": "./exp_zh", + "num_train_epochs": 3, + "per_device_train_batch_size": 8, + "per_device_eval_batch_size": 8, + "gradient_accumulation_steps": 1, + "evaluation_strategy": "steps", + "eval_steps": 1000, + "save_strategy": "steps", + "save_steps": 5000, + "save_total_limit": 100, + "learning_rate": 0.00005, + "weight_decay": 0.01, + "adam_beta2": 0.95, + "warmup_ratio": 0.03, + "lr_scheduler_type": "cosine", + "logging_steps": 100, + "report_to": "wandb", + "model_max_length": 2048, + "gradient_checkpointing": false, + "dataloader_num_workers": 4, + "dataloader_prefetch_factor": 4, + "deepspeed": "ds_config_zero2.json" +} diff --git a/egs/emilia/TTS/llasa_cosyvoice2_token/ds_config_zero2.json b/egs/emilia/TTS/llasa_cosyvoice2_token/ds_config_zero2.json new file mode 100644 index 0000000000..b0b139598f --- /dev/null +++ b/egs/emilia/TTS/llasa_cosyvoice2_token/ds_config_zero2.json @@ -0,0 +1,47 @@ +{ + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 64, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "bf16": { + "enabled": "auto" + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": "auto", + "betas": "auto", + "eps": "auto", + "weight_decay": "auto" + } + }, + + "scheduler": { + "type": "WarmupCosineLR", + "params": { + "total_num_steps": "auto", + "warmup_min_ratio": 0.03, + "warmup_num_steps": "auto", + "cos_min_ratio": 0.1 + } + }, + + "zero_optimization": { + "stage": 2, + "overlap_comm": false, + "contiguous_gradients": true, + "sub_group_size": 1e9, + "reduce_bucket_size": "auto" + }, + + "gradient_accumulation_steps": "auto", + "gradient_clipping": 1.0, + "steps_per_print": 100, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} diff --git a/egs/emilia/TTS/llasa_cosyvoice2_token/requirements.txt b/egs/emilia/TTS/llasa_cosyvoice2_token/requirements.txt new file mode 100644 index 0000000000..09e069d3a0 --- /dev/null +++ b/egs/emilia/TTS/llasa_cosyvoice2_token/requirements.txt @@ -0,0 +1,7 @@ 
+torch +transformers +wandb +datasets +accelerate>=0.26.0 +deepspeed +flash-attn diff --git a/egs/emilia/TTS/llasa_cosyvoice2_token/run.sh b/egs/emilia/TTS/llasa_cosyvoice2_token/run.sh new file mode 100644 index 0000000000..a78bba96b0 --- /dev/null +++ b/egs/emilia/TTS/llasa_cosyvoice2_token/run.sh @@ -0,0 +1,4 @@ + +WANDB_KEY=df59308c1f07be8338a87497523163014442d605 # TODO Set YOUR KEY! +wandb login ${WANDB_KEY} +torchrun --nproc_per_node=8 train.py config.json diff --git a/egs/emilia/TTS/llasa_cosyvoice2_token/train.py b/egs/emilia/TTS/llasa_cosyvoice2_token/train.py new file mode 100644 index 0000000000..159e483d75 --- /dev/null +++ b/egs/emilia/TTS/llasa_cosyvoice2_token/train.py @@ -0,0 +1,171 @@ +import json +import os +import random +import sys +from dataclasses import dataclass, field +from functools import partial +from typing import List, Optional + +import numpy as np +import torch +import torch.nn as nn +import transformers +import wandb +from datasets import load_dataset, load_from_disk +from torch.utils.data import DataLoader, Dataset +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + DataCollatorForSeq2Seq, + HfArgumentParser, + Trainer, + TrainingArguments, +) +from transformers.trainer_pt_utils import LabelSmoother + +IGNORE_TOKEN_ID = LabelSmoother.ignore_index +TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}" + + +@dataclass +class ModelArguments: + llm_model_name_or_path: Optional[str] = field( + default="meta-llama/Llama-3.2-1B-Instruct" + ) + + +@dataclass +class DataArguments: + data_path: List[str] = field( + default=None, + metadata={"help": "Root path(s) to the data. 
Can be single path or list."}, + ) + + +@dataclass +class CustomTrainingArguments(TrainingArguments): + optim: str = field(default="adamw_torch_fused") + model_max_length: int = field( + default=2048, + metadata={"help": "Maximum sequence length"}, + ) + logging_steps: int = field(default=100, metadata={"help": "Log every X updates"}) + report_to: Optional[str] = field( + default=None, + metadata={"help": "The integration to report the results and logs to."}, + ) + run_name: Optional[str] = field( + default=None, metadata={"help": "The name of the run for logging."} + ) + gradient_checkpointing: bool = field(default=False) + lr_scheduler_type: str = field( + default="cosine", metadata={"help": "The learning rate scheduler to use."} + ) + remove_unused_columns: bool = field(default=False) + + +def data_collator(batch, tokenizer): + speech_generation_start_index = tokenizer.convert_tokens_to_ids( + "<|SPEECH_GENERATION_START|>" + ) + assistant_index = tokenizer.convert_tokens_to_ids("assistant") + input_ids_list = [] + for i, item in enumerate(batch): + text, code = item["text"], item["code"] + message = [ + {"role": "user", "content": f"Convert the text to speech: {text}"}, + {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>"}, + ] + + input_ids = tokenizer.apply_chat_template( + message, + tokenize=True, + chat_template=TEMPLATE, + ) + + code = [c + 151665 for c in code] + + idx = input_ids.index(speech_generation_start_index) + input_ids = input_ids[:idx] + code + input_ids[idx + 1 :] + if len(input_ids) < 2048: + input_ids_list.append(input_ids) + + max_len = max([len(input_ids) for input_ids in input_ids_list]) + input_ids_list = [ + input_ids + [tokenizer.pad_token_id] * (max_len - len(input_ids)) + for input_ids in input_ids_list + ] + input_ids = torch.tensor(input_ids_list, dtype=torch.int) + attention_mask = input_ids.ne(tokenizer.pad_token_id) + + target_ids = input_ids.clone() + target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID + mask_indices = torch.where(input_ids == assistant_index) + for i in range(mask_indices[0].size(0)): + row = mask_indices[0][i] + col = mask_indices[1][i] + # + 2 to skip: 'assistant', '\n' + target_ids[row, : col + 2] = IGNORE_TOKEN_ID + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": target_ids.to(dtype=torch.int64), + } + + +def main(): + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, CustomTrainingArguments) + ) + assert len(sys.argv) == 2 and sys.argv[1].endswith(".json") + ( + model_args, + data_args, + training_args, + ) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + + is_main_process = training_args.local_rank in [-1, 0] + if training_args.report_to == "wandb" and is_main_process: + wandb.init( + project="llm_tts", + config=training_args.to_sanitized_dict(), + name=training_args.run_name, + ) + + model = AutoModelForCausalLM.from_pretrained( + model_args.llm_model_name_or_path, + torch_dtype=torch.float16, + attn_implementation="flash_attention_2", + ) + + tokenizer = AutoTokenizer.from_pretrained(model_args.llm_model_name_or_path) + new_tokens = [f"<|s_{i}|>" for i in range(6561)] + ["<|SPEECH_GENERATION_START|>"] + num_added_tokens = tokenizer.add_tokens(new_tokens) + + model.resize_token_embeddings(len(tokenizer)) + model.vocab_size = len(tokenizer) + + dataset = load_dataset("json", data_files=data_args.data_path) + dataset = dataset["train"] + train_test_split = dataset.train_test_split(test_size=100, seed=42) + train_dataset, 
eval_dataset = train_test_split["train"], train_test_split["test"] + + trainer = Trainer( + model=model, + tokenizer=tokenizer, + args=training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + data_collator=lambda features: data_collator(features, tokenizer), + ) + + if is_main_process: + trainer.add_callback(transformers.integrations.WandbCallback()) + + trainer.train(resume_from_checkpoint=None) + trainer.save_model(training_args.output_dir) + + +if __name__ == "__main__": + main() From 0f7ebb7ffbac4222856e7d4160b06d387abd4c69 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 28 Feb 2025 06:05:58 +0000 Subject: [PATCH 3/8] add llasa infer --- egs/wenetspeech4tts/TTS/f5-tts/infer_llasa.py | 828 ++++++++++++++++++ 1 file changed, 828 insertions(+) create mode 100644 egs/wenetspeech4tts/TTS/f5-tts/infer_llasa.py diff --git a/egs/wenetspeech4tts/TTS/f5-tts/infer_llasa.py b/egs/wenetspeech4tts/TTS/f5-tts/infer_llasa.py new file mode 100644 index 0000000000..6964a43be6 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/infer_llasa.py @@ -0,0 +1,828 @@ +#!/usr/bin/env python3 +# Modified from https://github.com/SWivid/F5-TTS/blob/main/src/f5_tts/eval/eval_infer_batch.py +""" +Usage: +# docker: ghcr.io/swivid/f5-tts:main +# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html +# pip install kaldialign lhotse tensorboard bigvganinference sentencepiece sherpa-onnx +# huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x +manifest=/path/seed_tts_eval/seedtts_testset/zh/meta.lst +python3 f5-tts/generate_averaged_model.py \ + --epoch 56 \ + --avg 14 --decoder-dim 768 --nhead 12 --num-decoder-layers 18 \ + --exp-dir exp/f5_small + +# command for text token input +accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18 + +# command for cosyvoice semantic token input +split=test_zh # seed_tts_eval test_zh +accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --split-name $split --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18 --use-cosyvoice-semantic-token True + +bash local/compute_wer.sh $output_dir $manifest +""" +import argparse +import logging +import math +import os +import random +import time +from pathlib import Path + +import datasets +import torch +import torch.nn.functional as F +import torchaudio +from accelerate import Accelerator +from bigvganinference import BigVGANInference +from model.cfm import CFM +from model.dit import DiT +from model.modules import MelSpec +from model.utils import convert_char_to_pinyin +from tqdm import tqdm +from train import ( + add_model_arguments, + get_model, + get_tokenizer, + interpolate_tokens, + load_F5_TTS_pretrained_checkpoint, +) + +from icefall.checkpoint import load_checkpoint +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--tokens", + type=str, + default="f5-tts/vocab.txt", + help="Path to the unique text tokens file", + ) + + parser.add_argument( + "--model-path", + type=str, + default="/home/yuekaiz/HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", + help="Path to the unique text tokens file", + ) + + parser.add_argument( + "--seed", + type=int, + default=0, + help="The seed for random generators intended for reproducibility", + ) + + 
parser.add_argument( + "--nfe", + type=int, + default=16, + help="The number of steps for the neural ODE", + ) + + parser.add_argument( + "--manifest-file", + type=str, + default=None, + help="The manifest file in seed_tts_eval format", + ) + + parser.add_argument( + "--output-dir", + type=Path, + default="results", + help="The output directory to save the generated wavs", + ) + + parser.add_argument("-ss", "--swaysampling", default=-1, type=float) + + parser.add_argument( + "--interpolate-token", + type=str2bool, + default=True, + help="Interpolate semantic token to match mel frames for CosyVoice", + ) + + parser.add_argument( + "--use-cosyvoice-semantic-token", + type=str2bool, + default=False, + help="Whether to use cosyvoice semantic token to replace text token.", + ) + + parser.add_argument( + "--split-name", + type=str, + default="wenetspeech4tts", + choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"], + help="huggingface dataset split name", + ) + + add_model_arguments(parser) + return parser.parse_args() + + +def get_inference_prompt( + metainfo, + speed=1.0, + tokenizer="pinyin", + polyphone=True, + target_sample_rate=24000, + n_fft=1024, + win_length=1024, + n_mel_channels=100, + hop_length=256, + mel_spec_type="bigvgan", + target_rms=0.1, + use_truth_duration=False, + infer_batch_size=1, + num_buckets=200, + min_secs=3, + max_secs=40, +): + prompts_all = [] + + min_tokens = min_secs * target_sample_rate // hop_length + max_tokens = max_secs * target_sample_rate // hop_length + + batch_accum = [0] * num_buckets + utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = ( + [[] for _ in range(num_buckets)] for _ in range(6) + ) + + mel_spectrogram = MelSpec( + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + n_mel_channels=n_mel_channels, + target_sample_rate=target_sample_rate, + mel_spec_type=mel_spec_type, + ) + + for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm( + metainfo, desc="Processing prompts..." + ): + # Audio + ref_audio, ref_sr = torchaudio.load(prompt_wav) + ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio))) + if ref_rms < target_rms: + ref_audio = ref_audio * target_rms / ref_rms + assert ( + ref_audio.shape[-1] > 5000 + ), f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue." + if ref_sr != target_sample_rate: + resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate) + ref_audio = resampler(ref_audio) + + # Text + if len(prompt_text[-1].encode("utf-8")) == 1: + prompt_text = prompt_text + " " + text = [prompt_text + gt_text] + if tokenizer == "pinyin": + text_list = convert_char_to_pinyin(text, polyphone=polyphone) + else: + text_list = text + + # Duration, mel frame length + ref_mel_len = ref_audio.shape[-1] // hop_length + if use_truth_duration: + gt_audio, gt_sr = torchaudio.load(gt_wav) + if gt_sr != target_sample_rate: + resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate) + gt_audio = resampler(gt_audio) + total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed) + + # # test vocoder resynthesis + # ref_audio = gt_audio + else: + ref_text_len = len(prompt_text.encode("utf-8")) + gen_text_len = len(gt_text.encode("utf-8")) + total_mel_len = ref_mel_len + int( + ref_mel_len / ref_text_len * gen_text_len / speed + ) + + # to mel spectrogram + ref_mel = mel_spectrogram(ref_audio) + ref_mel = ref_mel.squeeze(0) + + # deal with batch + assert infer_batch_size > 0, "infer_batch_size should be greater than 0." 
+ assert ( + min_tokens <= total_mel_len <= max_tokens + ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]." + bucket_i = math.floor( + (total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets + ) + + utts[bucket_i].append(utt) + ref_rms_list[bucket_i].append(ref_rms) + ref_mels[bucket_i].append(ref_mel) + ref_mel_lens[bucket_i].append(ref_mel_len) + total_mel_lens[bucket_i].append(total_mel_len) + final_text_list[bucket_i].extend(text_list) + + batch_accum[bucket_i] += total_mel_len + + if batch_accum[bucket_i] >= infer_batch_size: + prompts_all.append( + ( + utts[bucket_i], + ref_rms_list[bucket_i], + padded_mel_batch(ref_mels[bucket_i]), + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) + ) + batch_accum[bucket_i] = 0 + ( + utts[bucket_i], + ref_rms_list[bucket_i], + ref_mels[bucket_i], + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) = ( + [], + [], + [], + [], + [], + [], + ) + + # add residual + for bucket_i, bucket_frames in enumerate(batch_accum): + if bucket_frames > 0: + prompts_all.append( + ( + utts[bucket_i], + ref_rms_list[bucket_i], + padded_mel_batch(ref_mels[bucket_i]), + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) + ) + # not only leave easy work for last workers + random.seed(666) + random.shuffle(prompts_all) + + return prompts_all + + +def get_inference_prompt_cosy_voice_huggingface( + dataset, + speed=1.0, + tokenizer="pinyin", + polyphone=True, + target_sample_rate=24000, + n_fft=1024, + win_length=1024, + n_mel_channels=100, + hop_length=256, + mel_spec_type="bigvgan", + target_rms=0.1, + use_truth_duration=False, + infer_batch_size=1, + num_buckets=200, + min_secs=3, + max_secs=40, + interpolate_token=False, +): + prompts_all = [] + + min_tokens = min_secs * target_sample_rate // hop_length + max_tokens = max_secs * target_sample_rate // hop_length + + batch_accum = [0] * num_buckets + utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = ( + [[] for _ in range(num_buckets)] for _ in range(6) + ) + + mel_spectrogram = MelSpec( + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + n_mel_channels=n_mel_channels, + target_sample_rate=target_sample_rate, + mel_spec_type=mel_spec_type, + ) + + for i in range(len(dataset)): + utt = dataset[i]["id"] + ref_audio_org, ref_sr = ( + dataset[i]["prompt_audio"]["array"], + dataset[i]["prompt_audio"]["sampling_rate"], + ) + ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float() + audio_tokens = dataset[i]["target_audio_cosy2_tokens"] + prompt_audio_tokens = dataset[i]["prompt_audio_cosy2_tokens"] + + ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org))) + if ref_rms < target_rms: + ref_audio_org = ref_audio_org * target_rms / ref_rms + + if ref_sr != target_sample_rate: + resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate) + ref_audio = resampler(ref_audio_org) + else: + ref_audio = ref_audio_org + input_tokens = prompt_audio_tokens + audio_tokens + + if interpolate_token: + input_tokens = interpolate_tokens(input_tokens) + text_list = input_tokens + + # Duration, mel frame length + ref_mel_len = ref_audio.shape[-1] // hop_length + + total_mel_len = len(input_tokens) + if not interpolate_token: + total_mel_len = int(total_mel_len / 4 * 15) + + # to mel spectrogram + ref_mel = mel_spectrogram(ref_audio) + ref_mel = ref_mel.squeeze(0) + + # deal with batch + 
assert infer_batch_size > 0, "infer_batch_size should be greater than 0." + if total_mel_len > max_tokens: + print( + f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]." + ) + continue + assert ( + min_tokens <= total_mel_len <= max_tokens + ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]." + bucket_i = math.floor( + (total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets + ) + + utts[bucket_i].append(utt) + ref_rms_list[bucket_i].append(ref_rms) + ref_mels[bucket_i].append(ref_mel) + ref_mel_lens[bucket_i].append(ref_mel_len) + total_mel_lens[bucket_i].append(total_mel_len) + # final_text_list[bucket_i].extend(text_list) + final_text_list[bucket_i].append(text_list) + + batch_accum[bucket_i] += total_mel_len + + if batch_accum[bucket_i] >= infer_batch_size: + prompts_all.append( + ( + utts[bucket_i], + ref_rms_list[bucket_i], + padded_mel_batch(ref_mels[bucket_i]), + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) + ) + batch_accum[bucket_i] = 0 + ( + utts[bucket_i], + ref_rms_list[bucket_i], + ref_mels[bucket_i], + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) = ( + [], + [], + [], + [], + [], + [], + ) + + # add residual + for bucket_i, bucket_frames in enumerate(batch_accum): + if bucket_frames > 0: + prompts_all.append( + ( + utts[bucket_i], + ref_rms_list[bucket_i], + padded_mel_batch(ref_mels[bucket_i]), + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) + ) + # not only leave easy work for last workers + random.seed(666) + random.shuffle(prompts_all) + + return prompts_all + + +def inference_speech_token( + cosyvoice, + tts_text, + prompt_text, + prompt_speech_16k, + stream=False, + speed=1.0, + text_frontend=True, +): + tokens = [] + prompt_text = cosyvoice.frontend.text_normalize( + prompt_text, split=False, text_frontend=text_frontend + ) + for i in cosyvoice.frontend.text_normalize( + tts_text, split=True, text_frontend=text_frontend + ): + + tts_text_token, tts_text_token_len = cosyvoice.frontend._extract_text_token(i) + ( + prompt_text_token, + prompt_text_token_len, + ) = cosyvoice.frontend._extract_text_token(prompt_text) + speech_token, speech_token_len = cosyvoice.frontend._extract_speech_token( + prompt_speech_16k + ) + + for i in cosyvoice.model.llm.inference( + text=tts_text_token.to(cosyvoice.model.device), + text_len=torch.tensor([tts_text_token.shape[1]], dtype=torch.int32).to( + cosyvoice.model.device + ), + prompt_text=prompt_text_token.to(cosyvoice.model.device), + prompt_text_len=torch.tensor( + [prompt_text_token.shape[1]], dtype=torch.int32 + ).to(cosyvoice.model.device), + prompt_speech_token=speech_token.to(cosyvoice.model.device), + prompt_speech_token_len=torch.tensor( + [speech_token.shape[1]], dtype=torch.int32 + ).to(cosyvoice.model.device), + embedding=None, + ): + tokens.append(i) + return tokens, speech_token + + +def get_inference_prompt_cosy_voice( + metainfo, + speed=1.0, + tokenizer="pinyin", + polyphone=True, + target_sample_rate=24000, + n_fft=1024, + win_length=1024, + n_mel_channels=100, + hop_length=256, + mel_spec_type="bigvgan", + target_rms=0.1, + use_truth_duration=False, + infer_batch_size=1, + num_buckets=200, + min_secs=3, + max_secs=40, + interpolate_token=False, +): + + import sys + + # please change the path to the cosyvoice accordingly + 
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS") + sys.path.append("/workspace/CosyVoice") + from cosyvoice.cli.cosyvoice import CosyVoice2 + + # please download the cosyvoice model first + cosyvoice = CosyVoice2( + "/workspace/CosyVoice2-0.5B", load_jit=False, load_trt=False, fp16=False + ) + + prompts_all = [] + + min_tokens = min_secs * target_sample_rate // hop_length + max_tokens = max_secs * target_sample_rate // hop_length + + batch_accum = [0] * num_buckets + utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = ( + [[] for _ in range(num_buckets)] for _ in range(6) + ) + + mel_spectrogram = MelSpec( + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + n_mel_channels=n_mel_channels, + target_sample_rate=target_sample_rate, + mel_spec_type=mel_spec_type, + ) + + for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm( + metainfo, desc="Processing prompts..." + ): + # Audio + ref_audio_org, ref_sr = torchaudio.load(prompt_wav) + + # cosy voice + if ref_sr != 16000: + resampler = torchaudio.transforms.Resample(ref_sr, 16000) + ref_audio_16k = resampler(ref_audio_org) + else: + ref_audio_16k = ref_audio_org + audio_tokens, prompt_audio_tokens = inference_speech_token( + cosyvoice, gt_text, prompt_text, ref_audio_16k, stream=False + ) + + ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org))) + if ref_rms < target_rms: + ref_audio_org = ref_audio_org * target_rms / ref_rms + assert ( + ref_audio_org.shape[-1] > 5000 + ), f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue." + if ref_sr != target_sample_rate: + resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate) + ref_audio = resampler(ref_audio_org) + else: + ref_audio = ref_audio_org + + # Text + # if len(prompt_text[-1].encode("utf-8")) == 1: + # prompt_text = prompt_text + " " + # text = [prompt_text + gt_text] + # if tokenizer == "pinyin": + # text_list = convert_char_to_pinyin(text, polyphone=polyphone) + # else: + # text_list = text + + # concat two tensors: prompt audio tokens with audio tokens --> shape 1, prompt_audio_tokens + audio_tokens + # prompt_audio_tokens shape 1, prompt_audio_tokens + # audio_tokens shape 1, audio_tokens + prompt_audio_tokens = prompt_audio_tokens.squeeze().cpu().tolist() + input_tokens = prompt_audio_tokens + audio_tokens + + # convert it into a list + # input_tokens_list = input_tokens.squeeze().cpu().tolist() + if interpolate_token: + input_tokens = interpolate_tokens(input_tokens) + text_list = input_tokens + + # Duration, mel frame length + ref_mel_len = ref_audio.shape[-1] // hop_length + if use_truth_duration: + gt_audio, gt_sr = torchaudio.load(gt_wav) + if gt_sr != target_sample_rate: + resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate) + gt_audio = resampler(gt_audio) + total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed) + + # # test vocoder resynthesis + # ref_audio = gt_audio + else: + ref_text_len = len(prompt_text.encode("utf-8")) + gen_text_len = len(gt_text.encode("utf-8")) + total_mel_len_compute = ref_mel_len + int( + ref_mel_len / ref_text_len * gen_text_len / speed + ) + total_mel_len = len(input_tokens) + if not interpolate_token: + total_mel_len = int(total_mel_len / 4 * 15) + print( + f"total_mel_len_compute: {total_mel_len_compute}, total_mel_len: {total_mel_len}" + ) + + # to mel spectrogram + ref_mel = mel_spectrogram(ref_audio) + ref_mel = ref_mel.squeeze(0) + + # deal with batch + assert infer_batch_size > 0, "infer_batch_size should be greater 
than 0." + assert ( + min_tokens <= total_mel_len <= max_tokens + ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]." + bucket_i = math.floor( + (total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets + ) + + utts[bucket_i].append(utt) + ref_rms_list[bucket_i].append(ref_rms) + ref_mels[bucket_i].append(ref_mel) + ref_mel_lens[bucket_i].append(ref_mel_len) + total_mel_lens[bucket_i].append(total_mel_len) + # final_text_list[bucket_i].extend(text_list) + final_text_list[bucket_i].append(text_list) + + batch_accum[bucket_i] += total_mel_len + + if batch_accum[bucket_i] >= infer_batch_size: + prompts_all.append( + ( + utts[bucket_i], + ref_rms_list[bucket_i], + padded_mel_batch(ref_mels[bucket_i]), + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) + ) + batch_accum[bucket_i] = 0 + ( + utts[bucket_i], + ref_rms_list[bucket_i], + ref_mels[bucket_i], + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) = ( + [], + [], + [], + [], + [], + [], + ) + + # add residual + for bucket_i, bucket_frames in enumerate(batch_accum): + if bucket_frames > 0: + prompts_all.append( + ( + utts[bucket_i], + ref_rms_list[bucket_i], + padded_mel_batch(ref_mels[bucket_i]), + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) + ) + # not only leave easy work for last workers + random.seed(666) + random.shuffle(prompts_all) + + return prompts_all + + +def padded_mel_batch(ref_mels): + max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax() + padded_ref_mels = [] + for mel in ref_mels: + padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0) + padded_ref_mels.append(padded_ref_mel) + padded_ref_mels = torch.stack(padded_ref_mels) + padded_ref_mels = padded_ref_mels.permute(0, 2, 1) + return padded_ref_mels + + +def get_seedtts_testset_metainfo(metalst): + f = open(metalst) + lines = f.readlines() + f.close() + metainfo = [] + for line in lines: + assert len(line.strip().split("|")) == 4 + utt, prompt_text, prompt_wav, gt_text = line.strip().split("|") + utt = Path(utt).stem + gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav") + if not os.path.isabs(prompt_wav): + prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav) + metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav)) + return metainfo + + +def main(): + args = get_parser() + + accelerator = Accelerator() + device = f"cuda:{accelerator.process_index}" + if args.manifest_file: + metainfo = get_seedtts_testset_metainfo(args.manifest_file) + if not args.use_cosyvoice_semantic_token: + prompts_all = get_inference_prompt( + metainfo, + speed=1.0, + tokenizer="pinyin", + target_sample_rate=24_000, + n_mel_channels=100, + hop_length=256, + mel_spec_type="bigvgan", + target_rms=0.1, + use_truth_duration=False, + infer_batch_size=1, + ) + else: + prompts_all = get_inference_prompt_cosy_voice( + metainfo, + speed=1.0, + tokenizer="pinyin", + target_sample_rate=24_000, + n_mel_channels=100, + hop_length=256, + mel_spec_type="bigvgan", + target_rms=0.1, + use_truth_duration=False, + infer_batch_size=1, + interpolate_token=args.interpolate_token, + ) + else: + assert args.use_cosyvoice_semantic_token + dataset = datasets.load_dataset( + "yuekai/seed_tts_cosy2", + split=args.split_name, + trust_remote_code=True, + ) + prompts_all = get_inference_prompt_cosy_voice_huggingface( + dataset, + speed=1.0, + 
tokenizer="pinyin", + target_sample_rate=24_000, + n_mel_channels=100, + hop_length=256, + mel_spec_type="bigvgan", + target_rms=0.1, + use_truth_duration=False, + infer_batch_size=1, + interpolate_token=args.interpolate_token, + ) + + vocoder = BigVGANInference.from_pretrained( + "./bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False + ) + vocoder = vocoder.eval().to(device) + + model = get_model(args).eval().to(device) + checkpoint = torch.load(args.model_path, map_location="cpu") + if "ema_model_state_dict" in checkpoint or "model_state_dict" in checkpoint: + model = load_F5_TTS_pretrained_checkpoint(model, args.model_path) + else: + _ = load_checkpoint( + args.model_path, + model=model, + ) + + os.makedirs(args.output_dir, exist_ok=True) + + accelerator.wait_for_everyone() + start = time.time() + + with accelerator.split_between_processes(prompts_all) as prompts: + for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process): + ( + utts, + ref_rms_list, + ref_mels, + ref_mel_lens, + total_mel_lens, + final_text_list, + ) = prompt + ref_mels = ref_mels.to(device) + ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device) + total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device) + + if args.use_cosyvoice_semantic_token: + # concat final_text_list + max_len = max([len(tokens) for tokens in final_text_list]) + # pad tokens to the same length + for i, tokens in enumerate(final_text_list): + final_text_list[i] = torch.tensor( + tokens + [-1] * (max_len - len(tokens)), dtype=torch.long + ) + final_text_list = torch.stack(final_text_list).to(device) + + # Inference + with torch.inference_mode(): + generated, _ = model.sample( + cond=ref_mels, + text=final_text_list, + duration=total_mel_lens, + lens=ref_mel_lens, + steps=args.nfe, + cfg_strength=2.0, + sway_sampling_coef=args.swaysampling, + no_ref_audio=False, + seed=args.seed, + ) + for i, gen in enumerate(generated): + gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0) + gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32) + + generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu() + target_rms = 0.1 + target_sample_rate = 24_000 + if ref_rms_list[i] < target_rms: + generated_wave = generated_wave * ref_rms_list[i] / target_rms + torchaudio.save( + f"{args.output_dir}/{utts[i]}.wav", + generated_wave, + target_sample_rate, + ) + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + timediff = time.time() - start + print(f"Done batch inference in {timediff / 60 :.2f} minutes.") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() From d2b473ad99e8b00fb3048dd6fe742af3a139c9c3 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 28 Feb 2025 09:54:22 +0000 Subject: [PATCH 4/8] add eval seed tts --- egs/wenetspeech4tts/TTS/f5-tts/infer_dist.py | 347 +++++++++++++++++++ egs/wenetspeech4tts/TTS/f5-tts/llm_tts.py | 108 ++++++ 2 files changed, 455 insertions(+) create mode 100644 egs/wenetspeech4tts/TTS/f5-tts/infer_dist.py create mode 100644 egs/wenetspeech4tts/TTS/f5-tts/llm_tts.py diff --git a/egs/wenetspeech4tts/TTS/f5-tts/infer_dist.py b/egs/wenetspeech4tts/TTS/f5-tts/infer_dist.py new file mode 100644 index 0000000000..59e222a747 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/infer_dist.py @@ -0,0 +1,347 @@ +# Copyright (c) 2024 Tsinghua Univ. 
(authors: Xingchen Song) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Example Usage +cpu: + +s3tokenizer --data_dir xxx.scp \ + --device "cpu" \ + --output_dir "./" \ + --batch_size 32 + +gpu: + +torchrun --nproc_per_node=8 --nnodes=1 \ + --rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \ + `which s3tokenizer` --data_dir xxx.scp \ + --device "cuda" \ + --output_dir "./" \ + --batch_size 32 + +""" + +import argparse +import json +import os +from pathlib import Path + +import s3tokenizer +import torch +import torch.distributed as dist +import torch.nn.functional as F +import torchaudio +from bigvganinference import BigVGANInference +from datasets import load_dataset +from lhotse.serialization import load_jsonl +from llm_tts import LLMTTS +from model.modules import MelSpec +from torch.utils.data import DataLoader, Dataset, DistributedSampler +from tqdm import tqdm +from train import ( + add_model_arguments, + get_model, + get_tokenizer, + interpolate_tokens, + load_F5_TTS_pretrained_checkpoint, +) + +from icefall.checkpoint import load_checkpoint + +TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}" + + +def get_args(): + parser = argparse.ArgumentParser(description="extract speech code") + parser.add_argument( + "--s3-tokenizer-name", + required=False, + type=str, + choices=[ + "speech_tokenizer_v1", + "speech_tokenizer_v1_25hz", + "speech_tokenizer_v2_25hz", + ], + help="model version", + ) + parser.add_argument( + "--split-name", + type=str, + default="wenetspeech4tts", + choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"], + help="huggingface dataset split name", + ) + parser.add_argument( + "--output_dir", required=True, type=str, help="dir to save result" + ) + parser.add_argument( + "--batch_size", + required=True, + type=int, + help="batch size (per-device) for inference", + ) + parser.add_argument( + "--num_workers", type=int, default=4, help="workers for dataloader" + ) + parser.add_argument( + "--prefetch", type=int, default=5, help="prefetch for dataloader" + ) + parser.add_argument( + "--llm-model-name-or-path", + required=True, + type=str, + help="model version", + ) + parser.add_argument( + "--tokenizer-dir", + required=True, + type=str, + help="tokenizer dir", + ) + parser.add_argument( + "--vocoder-dir", + required=True, + type=str, + help="vocoder dir", + ) + parser.add_argument( + "--flow-matching-model-path", + required=True, + type=str, + help="flow matching model path", + ) + add_model_arguments(parser) + args = parser.parse_args() + return args + + +def padded_mel_batch(ref_mels): + max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax() + padded_ref_mels = [] + for mel in ref_mels: + padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0) + padded_ref_mels.append(padded_ref_mel) + padded_ref_mels = torch.stack(padded_ref_mels) + padded_ref_mels = 
padded_ref_mels.permute(0, 2, 1) + return padded_ref_mels + + +def data_collator(batch, tokenizer, mel_spectrogram): + speech_generation_start_index = tokenizer.convert_tokens_to_ids( + "<|SPEECH_GENERATION_START|>" + ) + assistant_index = tokenizer.convert_tokens_to_ids("assistant") + target_sample_rate = 24000 + hop_length = 256 + target_rms = 0.1 + input_ids_list, ref_mel_list, ref_mel_len_list = [], [], [] + for i, item in enumerate(batch): + prompt_text, target_text, prompt_audio_codes = ( + item["prompt_text"], + item["target_text"], + item["prompt_audio_cosy2_tokens"], + ) + message = [ + { + "role": "user", + "content": f"Convert the text to speech: {prompt_text + target_text}", + }, + {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>"}, + ] + + input_ids = tokenizer.apply_chat_template( + message, + tokenize=True, + chat_template=TEMPLATE, + ) + + prompt_audio_codes = [c + 151665 for c in prompt_audio_codes] + + idx = input_ids.index(speech_generation_start_index) + input_ids = input_ids[:idx] + prompt_audio_codes + input_ids_list.append(input_ids) + + # get flow matching model's prompt mel spectrogram + ref_audio_org, ref_sr = ( + item["prompt_audio"]["array"], + item["prompt_audio"]["sampling_rate"], + ) + ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float() + ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org))) + if ref_rms < target_rms: + ref_audio_org = ref_audio_org * target_rms / ref_rms + + if ref_sr != target_sample_rate: + resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate) + ref_audio = resampler(ref_audio_org) + else: + ref_audio = ref_audio_org + + # Duration in mel frame length + ref_mel_len = ref_audio.shape[-1] // hop_length + # to mel spectrogram + ref_mel = mel_spectrogram(ref_audio) + ref_mel = ref_mel.squeeze(0) + + ref_mel_list.append(ref_mel) + ref_mel_len_list.append(ref_mel_len) + + max_len = max([len(input_ids) for input_ids in input_ids_list]) + input_ids_list = [ + [tokenizer.pad_token_id] * (max_len - len(input_ids)) + input_ids + for input_ids in input_ids_list + ] + input_ids = torch.tensor(input_ids_list, dtype=torch.int64) + attention_mask = input_ids.ne(tokenizer.pad_token_id).long() + ids = [item["id"] for item in batch] + + ref_mel_batch = padded_mel_batch(ref_mel_list) + ref_mel_len_batch = torch.LongTensor(ref_mel_len_list) + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "ids": ids, + "ref_mel_batch": ref_mel_batch, + "ref_mel_len_batch": ref_mel_len_batch, + } + + +def init_distributed(): + world_size = int(os.environ.get("WORLD_SIZE", 1)) + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + rank = int(os.environ.get("RANK", 0)) + print( + "Inference on multiple gpus, this gpu {}".format(local_rank) + + ", rank {}, world_size {}".format(rank, world_size) + ) + torch.cuda.set_device(local_rank) + dist.init_process_group("nccl") + return world_size, local_rank, rank + + +def main(): + args = get_args() + os.makedirs(args.output_dir, exist_ok=True) + + assert torch.cuda.is_available() + world_size, local_rank, rank = init_distributed() + device = torch.device(f"cuda:{local_rank}") + model = LLMTTS( + model_dir=args.llm_model_name_or_path, + tokenizer_dir=args.tokenizer_dir, + s3_tokenizer_name=args.s3_tokenizer_name, + device=device, + ) + + vocoder = BigVGANInference.from_pretrained(args.vocoder_dir, use_cuda_kernel=False) + vocoder = vocoder.eval().to(device) + + flow_matching_model = get_model(args).eval().to(device) + _ = load_checkpoint( + 
args.flow_matching_model_path, + model=flow_matching_model, + ) + + dataset = load_dataset( + "yuekai/seed_tts_cosy2", + split=args.split_name, + trust_remote_code=True, + ) + + sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank) + + mel_spectrogram = MelSpec( + n_fft=1024, + hop_length=256, + win_length=1024, + n_mel_channels=100, + target_sample_rate=24000, + mel_spec_type="bigvgan", + ) + + dataloader = DataLoader( + dataset, + batch_size=args.batch_size, + sampler=sampler, + shuffle=False, + num_workers=args.num_workers, + prefetch_factor=args.prefetch, + collate_fn=lambda x: data_collator(x, model.tokenizer, mel_spectrogram), + ) + + total_steps = len(dataset) + + if rank == 0: + progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs") + + for batch in dataloader: + generate_codes = model.inference_batch( + batch["input_ids"], batch["attention_mask"] + ) + flow_matching_input_tokens, total_mel_lens = [], [] + for i, code in enumerate(generate_codes): + flow_matching_input_token = interpolate_tokens(code) + total_mel_len = len(flow_matching_input_token) + flow_matching_input_tokens.append(flow_matching_input_token) + total_mel_lens.append(total_mel_len) + total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device) + ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch[ + "ref_mel_len_batch" + ].to(device) + + max_len = max([len(tokens) for tokens in flow_matching_input_tokens]) + # pad tokens to the same length + for i, tokens in enumerate(flow_matching_input_tokens): + flow_matching_input_tokens[i] = torch.tensor( + tokens + [-1] * (max_len - len(tokens)), dtype=torch.long + ) + flow_matching_input_tokens = torch.stack(flow_matching_input_tokens).to(device) + generated, _ = flow_matching_model.sample( + cond=ref_mels, + text=flow_matching_input_tokens, + duration=total_mel_lens, + lens=ref_mel_lens, + steps=16, + cfg_strength=2.0, + sway_sampling_coef=-1, + no_ref_audio=False, + seed=0, + ) + + for i, gen in enumerate(generated): + gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0) + gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32) + + generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu() + target_rms = 0.1 + target_sample_rate = 24_000 + # if ref_rms_list[i] < target_rms: + # generated_wave = generated_wave * ref_rms_list[i] / target_rms + utt = batch["ids"][i] + torchaudio.save( + f"{args.output_dir}/{utt}.wav", + generated_wave, + target_sample_rate, + ) + + if rank == 0: + progress_bar.update(world_size * len(batch["ids"])) + + if rank == 0: + progress_bar.close() + + dist.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/egs/wenetspeech4tts/TTS/f5-tts/llm_tts.py b/egs/wenetspeech4tts/TTS/f5-tts/llm_tts.py new file mode 100644 index 0000000000..bf878db513 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/llm_tts.py @@ -0,0 +1,108 @@ +# Copyright (c) 2025 SparkAudio +# 2025 Xinsheng Wang (w.xinshawn@gmail.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# https://github.com/SparkAudio/Spark-TTS/blob/main/cli/SparkTTS.py + +import re +from pathlib import Path +from typing import List + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + + +class LLMTTS: + """ + LLM-TTS for text-to-speech generation. + """ + + def __init__( + self, + model_dir: Path, + tokenizer_dir: Path, + s3_tokenizer_name: str, + device: torch.device, + ): + """ + Initializes the LLMTTS model with the provided configurations and device. + + Args: + model_dir (Path): Directory containing the model and config files. + device (torch.device): The device (CPU/GPU) to run the model on. + """ + self.device = device + + self.model = AutoModelForCausalLM.from_pretrained( + model_dir, + torch_dtype=torch.float16, + device_map=device, + attn_implementation="flash_attention_2", + ) + + tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir) + new_tokens = [f"<|s_{i}|>" for i in range(6561)] + [ + "<|SPEECH_GENERATION_START|>" + ] + num_added_tokens = tokenizer.add_tokens(new_tokens) + tokenizer.padding_side = "left" + self.tokenizer = tokenizer + self.assistant_index = tokenizer.convert_tokens_to_ids("assistant") + + @torch.no_grad() + def inference_batch( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + temperature: float = 0.8, + top_k: float = 50, + top_p: float = 0.95, + ) -> torch.Tensor: + """ + Performs inference to generate speech from text, incorporating prompt audio and/or text. + + Args: + text (str): The text input to be converted to speech. + prompt_speech_path (Path): Path to the audio file used as a prompt. + prompt_text (str, optional): Transcript of the prompt audio. + gender (str): female | male. + pitch (str): very_low | low | moderate | high | very_high + speed (str): very_low | low | moderate | high | very_high + temperature (float, optional): Sampling temperature for controlling randomness. Default is 0.8. + top_k (float, optional): Top-k sampling parameter. Default is 50. + top_p (float, optional): Top-p (nucleus) sampling parameter. Default is 0.95. + + Returns: + torch.Tensor: Generated waveform as a tensor. 
+ """ + # Generate speech using the model + generated_ids = self.model.generate( + input_ids=input_ids.to(self.device), + attention_mask=attention_mask.to(self.device), + max_new_tokens=1024, + do_sample=True, + top_k=top_k, + top_p=top_p, + temperature=temperature, + ) + + results = [] + generated_ids = generated_ids.cpu().tolist() + for i in range(len(generated_ids)): + assistant_index = generated_ids[i].index(self.assistant_index) + padding_index = len(generated_ids[i]) + result = generated_ids[i][assistant_index + 2 :] + result = [token - 151665 for token in result] + result = [token for token in result if token >= 0] + results.append(result) + return results From 7623939fbf4a8afc396ae7c489616d6322fa5dea Mon Sep 17 00:00:00 2001 From: root Date: Mon, 3 Mar 2025 05:40:38 +0000 Subject: [PATCH 5/8] clean code --- egs/emilia/TTS/README.md | 94 ++ .../TTS/llasa_cosyvoice2_token/config.json | 4 +- .../llasa_cosyvoice2_token/requirements.txt | 1 + .../TTS/llasa_cosyvoice2_token/train.py | 27 +- .../TTS/local/extract_cosyvoice2_token.py | 21 +- egs/emilia/TTS/prepare.sh | 87 +- egs/wenetspeech4tts/TTS/README.md | 2 + egs/wenetspeech4tts/TTS/f5-tts/infer.py | 7 - egs/wenetspeech4tts/TTS/f5-tts/infer_dist.py | 66 +- egs/wenetspeech4tts/TTS/f5-tts/infer_llasa.py | 828 ------------------ egs/wenetspeech4tts/TTS/f5-tts/llm_tts.py | 28 +- egs/wenetspeech4tts/TTS/f5-tts/train.py | 14 +- 12 files changed, 207 insertions(+), 972 deletions(-) create mode 100644 egs/emilia/TTS/README.md delete mode 100644 egs/wenetspeech4tts/TTS/f5-tts/infer_llasa.py diff --git a/egs/emilia/TTS/README.md b/egs/emilia/TTS/README.md new file mode 100644 index 0000000000..363ea38422 --- /dev/null +++ b/egs/emilia/TTS/README.md @@ -0,0 +1,94 @@ +# Results +| LLM Model | Flow matching Model | Seed-TTS test_zh CER | Comment | +|---------------------------------------|----------|-----------|--------| +| pretrained cosyvoice2 llm | pretrained cosyvoice2 unet | 1.45% | See [paper](https://arxiv.org/abs/2412.10117)| +| pretrained cosyvoice2 llm | f5-tts-small (wenetspeech4tts) | 1.79% (16 steps) | See [PR](https://github.com/k2-fsa/icefall/pull/1880)| +| llasa_cosyvoice2_token llm (Emilia 50k hours ZH) | f5-tts-small (wenetspeech4tts) | 1.89% (16 steps) | | + +# Introduction + +[**Emilia**](https://huggingface.co/datasets/amphion/Emilia-Dataset) starts with over 101k +hours of speech across six languages, covering a wide range of speaking styles to enable more natural and spontaneous speech generation. + +See https://arxiv.org/pdf/2407.05361. + +> [!CAUTION] +> The next-gen Kaldi framework provides tools and models for generating high-quality, synthetic speech (Text-to-Speech, TTS). +> While these recipes has the potential to advance various fields such as accessibility, language education, and AI-driven solutions, it also carries certain ethical and legal responsibilities. +> +> By using this framework, you agree to the following: +> 1. Legal and Ethical Use: You shall not use this framework, or any models derived from it, for any unlawful or unethical purposes. This includes, but is not limited to: Creating voice clones without the explicit, informed consent of the individual whose voice is being cloned. Engaging in any form of identity theft, impersonation, or fraud using cloned voices. Violating any local, national, or international laws regarding privacy, intellectual property, or personal data. +> +> 2. 
Responsibility of Use: The users of this framework are solely responsible for ensuring that their use of voice cloning technologies complies with all applicable laws and ethical guidelines. We explicitly disclaim any liability for misuse of the technology. +> +> 3. Attribution and Use of Open-Source Components: This project is provided under the Apache 2.0 license. Users must adhere to the terms of this license and provide appropriate attribution when required. +> +> 4. No Warranty: This framework is provided “as-is,” without warranty of any kind, either express or implied. We do not guarantee that the use of this software will comply with legal requirements or that it will not infringe the rights of third parties. + + + + +# Llasa (cosyvoice2 token) + +./llasa_cosyvoice2_token contains the code for training qwen2.5-0.5b models to predict cosyvoice2 semantic tokens. + +Generated samples and training logs for the 50k-hour Chinese subset of [Emilia](https://huggingface.co/datasets/amphion/Emilia-Dataset) can be found [here](https://huggingface.co/yuekai/llasa_cosyvoice2_token_qwen_0.5b/tree/main). + +Preparation: + +``` +# extract cosyvoice2 semantic tokens +bash prepare.sh --stage 3 --stop_stage 4 + +# Or you could use the prepared tokens. +huggingface-cli download yuekai/emilia_cosyvoice_v2_token --local-dir emilia_cosyvoice_v2_token +``` + +The training command is given below: + +``` +# docker: ghcr.io/swivid/f5-tts:main +# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html +# pip install -r llasa_cosyvoice2_token/requirements.txt +# pip install -r icefall/egs/wenetspeech4tts/TTS/f5-tts/requirements.txt + +WANDB_KEY=$your_wandb_key +wandb login ${WANDB_KEY} +huggingface-cli download Qwen/Qwen2.5-0.5B-Instruct --local-dir Qwen2.5-0.5B-Instruct +torchrun --nproc_per_node=8 train.py config.json +``` + +To run inference with the Icefall Emilia-trained Chinese Llasa_cosyvoice2_token model, we need to use the cosyvoice2 token flow matching [model](https://github.com/k2-fsa/icefall/pull/1880): +``` +cd icefall/egs/wenetspeech4tts/TTS +huggingface-cli login +huggingface-cli download --local-dir ${exp_dir} yuekai/llasa_cosyvoice2_token_qwen_0.5b +huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x +vocoder=./bigvgan_v2_24khz_100band_256x +split=test_zh +llm_path=llasa_cosyvoice2_token_qwen_0.5b/checkpoint-800000 + +huggingface-cli download --local-dir f5-tts-small-wenetspeech4tts-basic yuekai/f5-tts-semantic-token-small-wenetspeech4tts-basic +model_path=f5-tts-small-wenetspeech4tts-basic/epoch-10-avg-5.pt +torchrun --nproc_per_node=2 \ + f5-tts/infer_dist.py \ + --output-dir $output_dir \ + --batch-size 1 \ + --num-workers 2 \ + --llm-model-name-or-path $llm_path \ + --flow-matching-model-path $model_path \ + --decoder-dim 768 --nhead 12 --num-decoder-layers 18 \ + --use-cosyvoice-semantic-token True \ + --vocoder-dir $vocoder \ + --split-name $split --top-k 50 --top-p 0.95 --temperature 0.8 \ + --tokenizer-dir Qwen/Qwen2.5-0.5B-Instruct +# compute cer +huggingface-cli download yuekai/seed_tts_eval --local-dir seed_tts_eval --repo-type dataset +manifest=./seed_tts_eval/seedtts_testset/zh/meta.lst +bash local/compute_wer.sh $output_dir $manifest +``` + +# Credits +- [Llasa](https://arxiv.org/abs/2502.04128) +- [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) +- [S3Tokenizer](https://github.com/xingchensong/S3Tokenizer/tree/main) diff --git a/egs/emilia/TTS/llasa_cosyvoice2_token/config.json 
b/egs/emilia/TTS/llasa_cosyvoice2_token/config.json index 06aeb51f1d..858edae84d 100644 --- a/egs/emilia/TTS/llasa_cosyvoice2_token/config.json +++ b/egs/emilia/TTS/llasa_cosyvoice2_token/config.json @@ -1,6 +1,6 @@ { - "llm_model_name_or_path": "/workspace/slam/icefall_omni/egs/speech_llm/SPEECH2SPEECH/models/Qwen2.5-0.5B-Instruct", - "data_path": ["../emilia_cosyvoice_v2_token/cosy_v2_tokens_ZH.jsonl"], + "llm_model_name_or_path": "./Qwen2.5-0.5B-Instruct", + "data_path": ["./emilia_cosyvoice_v2_token/cosy_v2_tokens_ZH.jsonl"], "bf16": false, "output_dir": "./exp_zh", "num_train_epochs": 3, diff --git a/egs/emilia/TTS/llasa_cosyvoice2_token/requirements.txt b/egs/emilia/TTS/llasa_cosyvoice2_token/requirements.txt index 09e069d3a0..11574c1909 100644 --- a/egs/emilia/TTS/llasa_cosyvoice2_token/requirements.txt +++ b/egs/emilia/TTS/llasa_cosyvoice2_token/requirements.txt @@ -5,3 +5,4 @@ datasets accelerate>=0.26.0 deepspeed flash-attn +s3tokenizer diff --git a/egs/emilia/TTS/llasa_cosyvoice2_token/train.py b/egs/emilia/TTS/llasa_cosyvoice2_token/train.py index 159e483d75..e3c6fcae61 100644 --- a/egs/emilia/TTS/llasa_cosyvoice2_token/train.py +++ b/egs/emilia/TTS/llasa_cosyvoice2_token/train.py @@ -1,3 +1,11 @@ +# Modified from https://github.com/zhenye234/LLaSA_training/blob/main/train_tts.py +""" Example Usage +WANDB_KEY=$your_wandb_key +wandb login ${WANDB_KEY} +huggingface-cli download yuekai/emilia_cosyvoice_v2_token --local-dir emilia_cosyvoice_v2_token +huggingface-cli download Qwen/Qwen2.5-0.5B-Instruct --local-dir Qwen2.5-0.5B-Instruct +torchrun --nproc_per_node=8 train.py config.json +""" import json import os import random @@ -11,8 +19,7 @@ import torch.nn as nn import transformers import wandb -from datasets import load_dataset, load_from_disk -from torch.utils.data import DataLoader, Dataset +from datasets import load_dataset from transformers import ( AutoConfig, AutoModelForCausalLM, @@ -65,7 +72,7 @@ class CustomTrainingArguments(TrainingArguments): remove_unused_columns: bool = field(default=False) -def data_collator(batch, tokenizer): +def data_collator(batch, tokenizer, original_tokenizer_vocab_size, cut_off_len=2048): speech_generation_start_index = tokenizer.convert_tokens_to_ids( "<|SPEECH_GENERATION_START|>" ) @@ -84,11 +91,11 @@ def data_collator(batch, tokenizer): chat_template=TEMPLATE, ) - code = [c + 151665 for c in code] + code = [c + original_tokenizer_vocab_size for c in code] idx = input_ids.index(speech_generation_start_index) input_ids = input_ids[:idx] + code + input_ids[idx + 1 :] - if len(input_ids) < 2048: + if len(input_ids) < cut_off_len: input_ids_list.append(input_ids) max_len = max([len(input_ids) for input_ids in input_ids_list]) @@ -140,7 +147,11 @@ def main(): ) tokenizer = AutoTokenizer.from_pretrained(model_args.llm_model_name_or_path) - new_tokens = [f"<|s_{i}|>" for i in range(6561)] + ["<|SPEECH_GENERATION_START|>"] + original_tokenizer_vocab_size = len(tokenizer) + cosyvoice2_token_size = 6561 + new_tokens = [f"<|s_{i}|>" for i in range(cosyvoice2_token_size)] + [ + "<|SPEECH_GENERATION_START|>" + ] num_added_tokens = tokenizer.add_tokens(new_tokens) model.resize_token_embeddings(len(tokenizer)) @@ -157,7 +168,9 @@ def main(): args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset, - data_collator=lambda features: data_collator(features, tokenizer), + data_collator=lambda features: data_collator( + features, tokenizer, original_tokenizer_vocab_size + ), ) if is_main_process: diff --git 
a/egs/emilia/TTS/local/extract_cosyvoice2_token.py b/egs/emilia/TTS/local/extract_cosyvoice2_token.py index 2c1ccda766..2a6d1d3805 100644 --- a/egs/emilia/TTS/local/extract_cosyvoice2_token.py +++ b/egs/emilia/TTS/local/extract_cosyvoice2_token.py @@ -1,4 +1,5 @@ # Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song) +# 2025 (authors: Yuekai Zhang) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,21 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. """ Example Usage -cpu: - -s3tokenizer --data_dir xxx.scp \ - --device "cpu" \ - --output_dir "./" \ - --batch_size 32 - -gpu: - torchrun --nproc_per_node=8 --nnodes=1 \ - --rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \ - `which s3tokenizer` --data_dir xxx.scp \ + --rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \ + local/extract_cosyvoice2_token.py --data_dir $data_dir \ + --jsonl_file $jsonl_file_basename \ --device "cuda" \ - --output_dir "./" \ - --batch_size 32 + --output_dir $output_dir \ + --batch_size 32 \ + --num_workers 2 \ + --model "speech_tokenizer_v2_25hz" """ diff --git a/egs/emilia/TTS/prepare.sh b/egs/emilia/TTS/prepare.sh index 4a0d2df0b7..8abcfaf612 100755 --- a/egs/emilia/TTS/prepare.sh +++ b/egs/emilia/TTS/prepare.sh @@ -4,16 +4,17 @@ set -eou pipefail # fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python -# pip install lhotse s3tokenizer -stage=6 -stop_stage=6 +stage=3 +stop_stage=4 + +# Please download the OpenDataLab format from HuggingFace, you can specify the revision argument to fc71e07e8572f5f3be1dbd02ed3172a4d298f152, which is the old format. +# https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07e8572f5f3be1dbd02ed3172a4d298f152 dl_dir=$PWD/download -dl_dir=/workspace_data/Emilia-Dataset/ + prefix="emilia" # zh, en, ja, ko, de, fr lang_set=("de" "en" "zh" "ja" "ko" "fr") -lang_set=("de" "en" "zh" "ja" "fr") . shared/parse_options.sh || exit 1 @@ -29,23 +30,20 @@ log() { if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then log "dl_dir: $dl_dir" log "Stage 0: Download data" - #huggingface-cli login - # huggingface-cli download --repo-type dataset --local-dir $dl_dir Wenetspeech4TTS/WenetSpeech4TTS - # Extract the downloaded data: + cat $dl_dir/raw/EN/EN_B00008.tar.gz.* > $dl_dir/raw/EN/EN_B00008.tar.gz for lang in "${lang_set[@]}"; do lang_upper=$(echo "${lang}" | tr '[:lower:]' '[:upper:]') folder=$dl_dir/raw/${lang_upper} for file in $folder/*.tar.gz; do echo "Processing ${file}" - # e.g. $dl_dir/raw/DE/*tar.gz untar first, DE is the language code in upper case tar -xzvf $file -C $folder done done fi if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare emilia manifest" + log "Stage 1: Prepare emilia manifest (used by ./f5-tts)" # We assume that you have downloaded the Emilia corpus # to $dl_dir/emilia mkdir -p data/manifests @@ -58,7 +56,6 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then done fi - if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then log "Stage 2: Generate fbank (used by ./f5-tts)" mkdir -p data/fbank @@ -71,67 +68,8 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then done fi -if [ $stage -le 16 ] && [ $stop_stage -ge 16 ]; then - log "Stage 6: Split the ${prefix} cuts into train, valid and test sets (used by ./f5-tts)" - if [ ! 
-f data/fbank/${prefix}_cuts_${subset}.jsonl.gz ]; then - echo "Combining ${prefix} cuts" - pieces=$(find data/fbank/ -name "${prefix}_cuts_${subset}.*.jsonl.gz") - lhotse combine $pieces data/fbank/${prefix}_cuts_${subset}.jsonl.gz - fi - if [ ! -e data/fbank/.${prefix}_split.done ]; then - echo "Splitting ${prefix} cuts into train, valid and test sets" - - lhotse subset --last 800 \ - data/fbank/${prefix}_cuts_${subset}.jsonl.gz \ - data/fbank/${prefix}_cuts_validtest.jsonl.gz - lhotse subset --first 400 \ - data/fbank/${prefix}_cuts_validtest.jsonl.gz \ - data/fbank/${prefix}_cuts_valid.jsonl.gz - lhotse subset --last 400 \ - data/fbank/${prefix}_cuts_validtest.jsonl.gz \ - data/fbank/${prefix}_cuts_test.jsonl.gz - - rm data/fbank/${prefix}_cuts_validtest.jsonl.gz - - n=$(( $(gunzip -c data/fbank/${prefix}_cuts_${subset}.jsonl.gz | wc -l) - 800 )) - lhotse subset --first $n \ - data/fbank/${prefix}_cuts_${subset}.jsonl.gz \ - data/fbank/${prefix}_cuts_train.jsonl.gz - touch data/fbank/.${prefix}_split.done - fi -fi - -# zcat test.jsonl.gz | jq -r '.recording.id + " " + .recording.sources[0].source' > wav.scp -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Extract cosyvoice2 FSQ token (used by ./f5-tts semantic token experiment)" - data_dir=$dl_dir/raw/ZH - # for all jsonl files in data_dir - for jsonl_file in $data_dir/*.jsonl; do - # get the file basename - jsonl_file_basename=$(basename $jsonl_file) - echo "Processing $jsonl_file" - output_dir="./cosy_v2_tokens_ZH/${jsonl_file_basename%.jsonl}" - echo "output_dir: $output_dir" - # skip if the output_dir exists - if [ -e $output_dir ]; then - echo "Output directory $output_dir already exists, skipping" - continue - fi - mkdir -p $output_dir - torchrun --nproc_per_node=8 --nnodes=1 \ - --rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \ - local/extract_cosyvoice2_token.py --data_dir $data_dir \ - --jsonl_file $jsonl_file_basename \ - --device "cuda" \ - --output_dir $output_dir \ - --batch_size 32 \ - --num_workers 2 \ - --model "speech_tokenizer_v2_25hz" # or "speech_tokenizer_v1_25hz - done -fi - -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Extract cosyvoice2 FSQ token (used by ./f5-tts semantic token experiment)" +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Extract cosyvoice2 FSQ token (used by ./llaasa_cosyvoice2_token)" for lang in "${lang_set[@]}"; do lang_upper=$(echo "${lang}" | tr '[:lower:]' '[:upper:]') data_dir=$dl_dir/raw/${lang_upper} @@ -161,14 +99,13 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then done fi -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then -# cat EN_B00008.tar.gz.* > EN_B00008.tar.gz +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Merge cosyvoice2 FSQ token (used by ./llaasa_cosyvoice2_token)" for lang in "${lang_set[@]}"; do lang_upper=$(echo "${lang}" | tr '[:lower:]' '[:upper:]') cosy_token_dir="./cosy_v2_tokens_${lang_upper}" for dir in $cosy_token_dir/*; do echo "Processing $dir" - # get the file basename dir_basename=$(basename $dir) echo "dir_basename: $dir_basename" cat $dir/part* > $dir/${dir_basename}.jsonl diff --git a/egs/wenetspeech4tts/TTS/README.md b/egs/wenetspeech4tts/TTS/README.md index 8329ae9484..9a48bd1969 100644 --- a/egs/wenetspeech4tts/TTS/README.md +++ b/egs/wenetspeech4tts/TTS/README.md @@ -186,3 +186,5 @@ bash local/compute_wer.sh $output_dir $manifest - [VALL-E](https://github.com/lifeiteng/vall-e) - [F5-TTS](https://github.com/SWivid/F5-TTS) - 
[CosyVoice](https://github.com/FunAudioLLM/CosyVoice) +- [S3Tokenizer](https://github.com/xingchensong/S3Tokenizer/tree/main) +- [Spark-TTS](https://github.com/SparkAudio/Spark-TTS) diff --git a/egs/wenetspeech4tts/TTS/f5-tts/infer.py b/egs/wenetspeech4tts/TTS/f5-tts/infer.py index 6964a43be6..b90657d0e2 100644 --- a/egs/wenetspeech4tts/TTS/f5-tts/infer.py +++ b/egs/wenetspeech4tts/TTS/f5-tts/infer.py @@ -108,13 +108,6 @@ def get_parser(): help="Interpolate semantic token to match mel frames for CosyVoice", ) - parser.add_argument( - "--use-cosyvoice-semantic-token", - type=str2bool, - default=False, - help="Whether to use cosyvoice semantic token to replace text token.", - ) - parser.add_argument( "--split-name", type=str, diff --git a/egs/wenetspeech4tts/TTS/f5-tts/infer_dist.py b/egs/wenetspeech4tts/TTS/f5-tts/infer_dist.py index 59e222a747..636720f032 100644 --- a/egs/wenetspeech4tts/TTS/f5-tts/infer_dist.py +++ b/egs/wenetspeech4tts/TTS/f5-tts/infer_dist.py @@ -1,4 +1,5 @@ # Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song) +# 2025 (authors: Yuekai Zhang) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,23 +12,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# Modified from https://github.com/xingchensong/S3Tokenizer/blob/main/s3tokenizer/cli.py """ Example Usage -cpu: - -s3tokenizer --data_dir xxx.scp \ - --device "cpu" \ - --output_dir "./" \ - --batch_size 32 - -gpu: - -torchrun --nproc_per_node=8 --nnodes=1 \ - --rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \ - `which s3tokenizer` --data_dir xxx.scp \ - --device "cuda" \ - --output_dir "./" \ - --batch_size 32 - +split=test_zh +llm_path=f5-tts/exp_zh/checkpoint-805000 +huggingface-cli download --local-dir f5-tts-small-wenetspeech4tts-basic yuekai/f5-tts-semantic-token-small-wenetspeech4tts-basic +model_path=f5-tts-small-wenetspeech4tts-basic/epoch-10-avg-5.pt +huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir ./bigvgan_v2_24khz_100band_256x +vocoder=./bigvgan_v2_24khz_100band_256x +torchrun --nproc_per_node=2 \ + f5-tts/infer_dist.py \ + --output-dir $output_dir \ + --batch-size 1 \ + --num-workers 2 \ + --llm-model-name-or-path $llm_path \ + --flow-matching-model-path $model_path \ + --decoder-dim 768 --nhead 12 --num-decoder-layers 18 \ + --use-cosyvoice-semantic-token True \ + --vocoder-dir $vocoder \ + --split-name $split --top-k 50 --top-p 0.95 --temperature 0.8 \ + --tokenizer-dir Qwen/Qwen2.5-0.5B-Instruct """ import argparse @@ -81,16 +85,16 @@ def get_args(): help="huggingface dataset split name", ) parser.add_argument( - "--output_dir", required=True, type=str, help="dir to save result" + "--output-dir", required=True, type=str, help="dir to save result" ) parser.add_argument( - "--batch_size", + "--batch-size", required=True, type=int, help="batch size (per-device) for inference", ) parser.add_argument( - "--num_workers", type=int, default=4, help="workers for dataloader" + "--num-workers", type=int, default=4, help="workers for dataloader" ) parser.add_argument( "--prefetch", type=int, default=5, help="prefetch for dataloader" ) @@ -119,6 +123,24 @@ def get_args(): type=str, help="flow matching model path", ) + parser.add_argument( + "--top-k", + type=int, + default=50, + help="top k for sampling", + ) + parser.add_argument( + "--top-p", + 
type=float, + default=0.95, + help="top p for sampling", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.8, + help="temperature for sampling", + ) add_model_arguments(parser) args = parser.parse_args() return args @@ -285,7 +307,11 @@ def main(): for batch in dataloader: generate_codes = model.inference_batch( - batch["input_ids"], batch["attention_mask"] + batch["input_ids"], + batch["attention_mask"], + top_k=args.top_k, + top_p=args.top_p, + temperature=args.temperature, ) flow_matching_input_tokens, total_mel_lens = [], [] for i, code in enumerate(generate_codes): diff --git a/egs/wenetspeech4tts/TTS/f5-tts/infer_llasa.py b/egs/wenetspeech4tts/TTS/f5-tts/infer_llasa.py deleted file mode 100644 index 6964a43be6..0000000000 --- a/egs/wenetspeech4tts/TTS/f5-tts/infer_llasa.py +++ /dev/null @@ -1,828 +0,0 @@ -#!/usr/bin/env python3 -# Modified from https://github.com/SWivid/F5-TTS/blob/main/src/f5_tts/eval/eval_infer_batch.py -""" -Usage: -# docker: ghcr.io/swivid/f5-tts:main -# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html -# pip install kaldialign lhotse tensorboard bigvganinference sentencepiece sherpa-onnx -# huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x -manifest=/path/seed_tts_eval/seedtts_testset/zh/meta.lst -python3 f5-tts/generate_averaged_model.py \ - --epoch 56 \ - --avg 14 --decoder-dim 768 --nhead 12 --num-decoder-layers 18 \ - --exp-dir exp/f5_small - -# command for text token input -accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18 - -# command for cosyvoice semantic token input -split=test_zh # seed_tts_eval test_zh -accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --split-name $split --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18 --use-cosyvoice-semantic-token True - -bash local/compute_wer.sh $output_dir $manifest -""" -import argparse -import logging -import math -import os -import random -import time -from pathlib import Path - -import datasets -import torch -import torch.nn.functional as F -import torchaudio -from accelerate import Accelerator -from bigvganinference import BigVGANInference -from model.cfm import CFM -from model.dit import DiT -from model.modules import MelSpec -from model.utils import convert_char_to_pinyin -from tqdm import tqdm -from train import ( - add_model_arguments, - get_model, - get_tokenizer, - interpolate_tokens, - load_F5_TTS_pretrained_checkpoint, -) - -from icefall.checkpoint import load_checkpoint -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--tokens", - type=str, - default="f5-tts/vocab.txt", - help="Path to the unique text tokens file", - ) - - parser.add_argument( - "--model-path", - type=str, - default="/home/yuekaiz/HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", - help="Path to the unique text tokens file", - ) - - parser.add_argument( - "--seed", - type=int, - default=0, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--nfe", - type=int, - default=16, - help="The number of steps for the neural ODE", - ) - - parser.add_argument( - "--manifest-file", - type=str, - default=None, - help="The manifest file in seed_tts_eval format", - ) - - 
parser.add_argument( - "--output-dir", - type=Path, - default="results", - help="The output directory to save the generated wavs", - ) - - parser.add_argument("-ss", "--swaysampling", default=-1, type=float) - - parser.add_argument( - "--interpolate-token", - type=str2bool, - default=True, - help="Interpolate semantic token to match mel frames for CosyVoice", - ) - - parser.add_argument( - "--use-cosyvoice-semantic-token", - type=str2bool, - default=False, - help="Whether to use cosyvoice semantic token to replace text token.", - ) - - parser.add_argument( - "--split-name", - type=str, - default="wenetspeech4tts", - choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"], - help="huggingface dataset split name", - ) - - add_model_arguments(parser) - return parser.parse_args() - - -def get_inference_prompt( - metainfo, - speed=1.0, - tokenizer="pinyin", - polyphone=True, - target_sample_rate=24000, - n_fft=1024, - win_length=1024, - n_mel_channels=100, - hop_length=256, - mel_spec_type="bigvgan", - target_rms=0.1, - use_truth_duration=False, - infer_batch_size=1, - num_buckets=200, - min_secs=3, - max_secs=40, -): - prompts_all = [] - - min_tokens = min_secs * target_sample_rate // hop_length - max_tokens = max_secs * target_sample_rate // hop_length - - batch_accum = [0] * num_buckets - utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = ( - [[] for _ in range(num_buckets)] for _ in range(6) - ) - - mel_spectrogram = MelSpec( - n_fft=n_fft, - hop_length=hop_length, - win_length=win_length, - n_mel_channels=n_mel_channels, - target_sample_rate=target_sample_rate, - mel_spec_type=mel_spec_type, - ) - - for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm( - metainfo, desc="Processing prompts..." - ): - # Audio - ref_audio, ref_sr = torchaudio.load(prompt_wav) - ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio))) - if ref_rms < target_rms: - ref_audio = ref_audio * target_rms / ref_rms - assert ( - ref_audio.shape[-1] > 5000 - ), f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue." - if ref_sr != target_sample_rate: - resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate) - ref_audio = resampler(ref_audio) - - # Text - if len(prompt_text[-1].encode("utf-8")) == 1: - prompt_text = prompt_text + " " - text = [prompt_text + gt_text] - if tokenizer == "pinyin": - text_list = convert_char_to_pinyin(text, polyphone=polyphone) - else: - text_list = text - - # Duration, mel frame length - ref_mel_len = ref_audio.shape[-1] // hop_length - if use_truth_duration: - gt_audio, gt_sr = torchaudio.load(gt_wav) - if gt_sr != target_sample_rate: - resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate) - gt_audio = resampler(gt_audio) - total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed) - - # # test vocoder resynthesis - # ref_audio = gt_audio - else: - ref_text_len = len(prompt_text.encode("utf-8")) - gen_text_len = len(gt_text.encode("utf-8")) - total_mel_len = ref_mel_len + int( - ref_mel_len / ref_text_len * gen_text_len / speed - ) - - # to mel spectrogram - ref_mel = mel_spectrogram(ref_audio) - ref_mel = ref_mel.squeeze(0) - - # deal with batch - assert infer_batch_size > 0, "infer_batch_size should be greater than 0." - assert ( - min_tokens <= total_mel_len <= max_tokens - ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]." 
- bucket_i = math.floor( - (total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets - ) - - utts[bucket_i].append(utt) - ref_rms_list[bucket_i].append(ref_rms) - ref_mels[bucket_i].append(ref_mel) - ref_mel_lens[bucket_i].append(ref_mel_len) - total_mel_lens[bucket_i].append(total_mel_len) - final_text_list[bucket_i].extend(text_list) - - batch_accum[bucket_i] += total_mel_len - - if batch_accum[bucket_i] >= infer_batch_size: - prompts_all.append( - ( - utts[bucket_i], - ref_rms_list[bucket_i], - padded_mel_batch(ref_mels[bucket_i]), - ref_mel_lens[bucket_i], - total_mel_lens[bucket_i], - final_text_list[bucket_i], - ) - ) - batch_accum[bucket_i] = 0 - ( - utts[bucket_i], - ref_rms_list[bucket_i], - ref_mels[bucket_i], - ref_mel_lens[bucket_i], - total_mel_lens[bucket_i], - final_text_list[bucket_i], - ) = ( - [], - [], - [], - [], - [], - [], - ) - - # add residual - for bucket_i, bucket_frames in enumerate(batch_accum): - if bucket_frames > 0: - prompts_all.append( - ( - utts[bucket_i], - ref_rms_list[bucket_i], - padded_mel_batch(ref_mels[bucket_i]), - ref_mel_lens[bucket_i], - total_mel_lens[bucket_i], - final_text_list[bucket_i], - ) - ) - # not only leave easy work for last workers - random.seed(666) - random.shuffle(prompts_all) - - return prompts_all - - -def get_inference_prompt_cosy_voice_huggingface( - dataset, - speed=1.0, - tokenizer="pinyin", - polyphone=True, - target_sample_rate=24000, - n_fft=1024, - win_length=1024, - n_mel_channels=100, - hop_length=256, - mel_spec_type="bigvgan", - target_rms=0.1, - use_truth_duration=False, - infer_batch_size=1, - num_buckets=200, - min_secs=3, - max_secs=40, - interpolate_token=False, -): - prompts_all = [] - - min_tokens = min_secs * target_sample_rate // hop_length - max_tokens = max_secs * target_sample_rate // hop_length - - batch_accum = [0] * num_buckets - utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = ( - [[] for _ in range(num_buckets)] for _ in range(6) - ) - - mel_spectrogram = MelSpec( - n_fft=n_fft, - hop_length=hop_length, - win_length=win_length, - n_mel_channels=n_mel_channels, - target_sample_rate=target_sample_rate, - mel_spec_type=mel_spec_type, - ) - - for i in range(len(dataset)): - utt = dataset[i]["id"] - ref_audio_org, ref_sr = ( - dataset[i]["prompt_audio"]["array"], - dataset[i]["prompt_audio"]["sampling_rate"], - ) - ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float() - audio_tokens = dataset[i]["target_audio_cosy2_tokens"] - prompt_audio_tokens = dataset[i]["prompt_audio_cosy2_tokens"] - - ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org))) - if ref_rms < target_rms: - ref_audio_org = ref_audio_org * target_rms / ref_rms - - if ref_sr != target_sample_rate: - resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate) - ref_audio = resampler(ref_audio_org) - else: - ref_audio = ref_audio_org - input_tokens = prompt_audio_tokens + audio_tokens - - if interpolate_token: - input_tokens = interpolate_tokens(input_tokens) - text_list = input_tokens - - # Duration, mel frame length - ref_mel_len = ref_audio.shape[-1] // hop_length - - total_mel_len = len(input_tokens) - if not interpolate_token: - total_mel_len = int(total_mel_len / 4 * 15) - - # to mel spectrogram - ref_mel = mel_spectrogram(ref_audio) - ref_mel = ref_mel.squeeze(0) - - # deal with batch - assert infer_batch_size > 0, "infer_batch_size should be greater than 0." 
- if total_mel_len > max_tokens: - print( - f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]." - ) - continue - assert ( - min_tokens <= total_mel_len <= max_tokens - ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]." - bucket_i = math.floor( - (total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets - ) - - utts[bucket_i].append(utt) - ref_rms_list[bucket_i].append(ref_rms) - ref_mels[bucket_i].append(ref_mel) - ref_mel_lens[bucket_i].append(ref_mel_len) - total_mel_lens[bucket_i].append(total_mel_len) - # final_text_list[bucket_i].extend(text_list) - final_text_list[bucket_i].append(text_list) - - batch_accum[bucket_i] += total_mel_len - - if batch_accum[bucket_i] >= infer_batch_size: - prompts_all.append( - ( - utts[bucket_i], - ref_rms_list[bucket_i], - padded_mel_batch(ref_mels[bucket_i]), - ref_mel_lens[bucket_i], - total_mel_lens[bucket_i], - final_text_list[bucket_i], - ) - ) - batch_accum[bucket_i] = 0 - ( - utts[bucket_i], - ref_rms_list[bucket_i], - ref_mels[bucket_i], - ref_mel_lens[bucket_i], - total_mel_lens[bucket_i], - final_text_list[bucket_i], - ) = ( - [], - [], - [], - [], - [], - [], - ) - - # add residual - for bucket_i, bucket_frames in enumerate(batch_accum): - if bucket_frames > 0: - prompts_all.append( - ( - utts[bucket_i], - ref_rms_list[bucket_i], - padded_mel_batch(ref_mels[bucket_i]), - ref_mel_lens[bucket_i], - total_mel_lens[bucket_i], - final_text_list[bucket_i], - ) - ) - # not only leave easy work for last workers - random.seed(666) - random.shuffle(prompts_all) - - return prompts_all - - -def inference_speech_token( - cosyvoice, - tts_text, - prompt_text, - prompt_speech_16k, - stream=False, - speed=1.0, - text_frontend=True, -): - tokens = [] - prompt_text = cosyvoice.frontend.text_normalize( - prompt_text, split=False, text_frontend=text_frontend - ) - for i in cosyvoice.frontend.text_normalize( - tts_text, split=True, text_frontend=text_frontend - ): - - tts_text_token, tts_text_token_len = cosyvoice.frontend._extract_text_token(i) - ( - prompt_text_token, - prompt_text_token_len, - ) = cosyvoice.frontend._extract_text_token(prompt_text) - speech_token, speech_token_len = cosyvoice.frontend._extract_speech_token( - prompt_speech_16k - ) - - for i in cosyvoice.model.llm.inference( - text=tts_text_token.to(cosyvoice.model.device), - text_len=torch.tensor([tts_text_token.shape[1]], dtype=torch.int32).to( - cosyvoice.model.device - ), - prompt_text=prompt_text_token.to(cosyvoice.model.device), - prompt_text_len=torch.tensor( - [prompt_text_token.shape[1]], dtype=torch.int32 - ).to(cosyvoice.model.device), - prompt_speech_token=speech_token.to(cosyvoice.model.device), - prompt_speech_token_len=torch.tensor( - [speech_token.shape[1]], dtype=torch.int32 - ).to(cosyvoice.model.device), - embedding=None, - ): - tokens.append(i) - return tokens, speech_token - - -def get_inference_prompt_cosy_voice( - metainfo, - speed=1.0, - tokenizer="pinyin", - polyphone=True, - target_sample_rate=24000, - n_fft=1024, - win_length=1024, - n_mel_channels=100, - hop_length=256, - mel_spec_type="bigvgan", - target_rms=0.1, - use_truth_duration=False, - infer_batch_size=1, - num_buckets=200, - min_secs=3, - max_secs=40, - interpolate_token=False, -): - - import sys - - # please change the path to the cosyvoice accordingly - sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS") - sys.path.append("/workspace/CosyVoice") - 
from cosyvoice.cli.cosyvoice import CosyVoice2 - - # please download the cosyvoice model first - cosyvoice = CosyVoice2( - "/workspace/CosyVoice2-0.5B", load_jit=False, load_trt=False, fp16=False - ) - - prompts_all = [] - - min_tokens = min_secs * target_sample_rate // hop_length - max_tokens = max_secs * target_sample_rate // hop_length - - batch_accum = [0] * num_buckets - utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = ( - [[] for _ in range(num_buckets)] for _ in range(6) - ) - - mel_spectrogram = MelSpec( - n_fft=n_fft, - hop_length=hop_length, - win_length=win_length, - n_mel_channels=n_mel_channels, - target_sample_rate=target_sample_rate, - mel_spec_type=mel_spec_type, - ) - - for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm( - metainfo, desc="Processing prompts..." - ): - # Audio - ref_audio_org, ref_sr = torchaudio.load(prompt_wav) - - # cosy voice - if ref_sr != 16000: - resampler = torchaudio.transforms.Resample(ref_sr, 16000) - ref_audio_16k = resampler(ref_audio_org) - else: - ref_audio_16k = ref_audio_org - audio_tokens, prompt_audio_tokens = inference_speech_token( - cosyvoice, gt_text, prompt_text, ref_audio_16k, stream=False - ) - - ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org))) - if ref_rms < target_rms: - ref_audio_org = ref_audio_org * target_rms / ref_rms - assert ( - ref_audio_org.shape[-1] > 5000 - ), f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue." - if ref_sr != target_sample_rate: - resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate) - ref_audio = resampler(ref_audio_org) - else: - ref_audio = ref_audio_org - - # Text - # if len(prompt_text[-1].encode("utf-8")) == 1: - # prompt_text = prompt_text + " " - # text = [prompt_text + gt_text] - # if tokenizer == "pinyin": - # text_list = convert_char_to_pinyin(text, polyphone=polyphone) - # else: - # text_list = text - - # concat two tensors: prompt audio tokens with audio tokens --> shape 1, prompt_audio_tokens + audio_tokens - # prompt_audio_tokens shape 1, prompt_audio_tokens - # audio_tokens shape 1, audio_tokens - prompt_audio_tokens = prompt_audio_tokens.squeeze().cpu().tolist() - input_tokens = prompt_audio_tokens + audio_tokens - - # convert it into a list - # input_tokens_list = input_tokens.squeeze().cpu().tolist() - if interpolate_token: - input_tokens = interpolate_tokens(input_tokens) - text_list = input_tokens - - # Duration, mel frame length - ref_mel_len = ref_audio.shape[-1] // hop_length - if use_truth_duration: - gt_audio, gt_sr = torchaudio.load(gt_wav) - if gt_sr != target_sample_rate: - resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate) - gt_audio = resampler(gt_audio) - total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed) - - # # test vocoder resynthesis - # ref_audio = gt_audio - else: - ref_text_len = len(prompt_text.encode("utf-8")) - gen_text_len = len(gt_text.encode("utf-8")) - total_mel_len_compute = ref_mel_len + int( - ref_mel_len / ref_text_len * gen_text_len / speed - ) - total_mel_len = len(input_tokens) - if not interpolate_token: - total_mel_len = int(total_mel_len / 4 * 15) - print( - f"total_mel_len_compute: {total_mel_len_compute}, total_mel_len: {total_mel_len}" - ) - - # to mel spectrogram - ref_mel = mel_spectrogram(ref_audio) - ref_mel = ref_mel.squeeze(0) - - # deal with batch - assert infer_batch_size > 0, "infer_batch_size should be greater than 0." 
- assert ( - min_tokens <= total_mel_len <= max_tokens - ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]." - bucket_i = math.floor( - (total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets - ) - - utts[bucket_i].append(utt) - ref_rms_list[bucket_i].append(ref_rms) - ref_mels[bucket_i].append(ref_mel) - ref_mel_lens[bucket_i].append(ref_mel_len) - total_mel_lens[bucket_i].append(total_mel_len) - # final_text_list[bucket_i].extend(text_list) - final_text_list[bucket_i].append(text_list) - - batch_accum[bucket_i] += total_mel_len - - if batch_accum[bucket_i] >= infer_batch_size: - prompts_all.append( - ( - utts[bucket_i], - ref_rms_list[bucket_i], - padded_mel_batch(ref_mels[bucket_i]), - ref_mel_lens[bucket_i], - total_mel_lens[bucket_i], - final_text_list[bucket_i], - ) - ) - batch_accum[bucket_i] = 0 - ( - utts[bucket_i], - ref_rms_list[bucket_i], - ref_mels[bucket_i], - ref_mel_lens[bucket_i], - total_mel_lens[bucket_i], - final_text_list[bucket_i], - ) = ( - [], - [], - [], - [], - [], - [], - ) - - # add residual - for bucket_i, bucket_frames in enumerate(batch_accum): - if bucket_frames > 0: - prompts_all.append( - ( - utts[bucket_i], - ref_rms_list[bucket_i], - padded_mel_batch(ref_mels[bucket_i]), - ref_mel_lens[bucket_i], - total_mel_lens[bucket_i], - final_text_list[bucket_i], - ) - ) - # not only leave easy work for last workers - random.seed(666) - random.shuffle(prompts_all) - - return prompts_all - - -def padded_mel_batch(ref_mels): - max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax() - padded_ref_mels = [] - for mel in ref_mels: - padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0) - padded_ref_mels.append(padded_ref_mel) - padded_ref_mels = torch.stack(padded_ref_mels) - padded_ref_mels = padded_ref_mels.permute(0, 2, 1) - return padded_ref_mels - - -def get_seedtts_testset_metainfo(metalst): - f = open(metalst) - lines = f.readlines() - f.close() - metainfo = [] - for line in lines: - assert len(line.strip().split("|")) == 4 - utt, prompt_text, prompt_wav, gt_text = line.strip().split("|") - utt = Path(utt).stem - gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav") - if not os.path.isabs(prompt_wav): - prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav) - metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav)) - return metainfo - - -def main(): - args = get_parser() - - accelerator = Accelerator() - device = f"cuda:{accelerator.process_index}" - if args.manifest_file: - metainfo = get_seedtts_testset_metainfo(args.manifest_file) - if not args.use_cosyvoice_semantic_token: - prompts_all = get_inference_prompt( - metainfo, - speed=1.0, - tokenizer="pinyin", - target_sample_rate=24_000, - n_mel_channels=100, - hop_length=256, - mel_spec_type="bigvgan", - target_rms=0.1, - use_truth_duration=False, - infer_batch_size=1, - ) - else: - prompts_all = get_inference_prompt_cosy_voice( - metainfo, - speed=1.0, - tokenizer="pinyin", - target_sample_rate=24_000, - n_mel_channels=100, - hop_length=256, - mel_spec_type="bigvgan", - target_rms=0.1, - use_truth_duration=False, - infer_batch_size=1, - interpolate_token=args.interpolate_token, - ) - else: - assert args.use_cosyvoice_semantic_token - dataset = datasets.load_dataset( - "yuekai/seed_tts_cosy2", - split=args.split_name, - trust_remote_code=True, - ) - prompts_all = get_inference_prompt_cosy_voice_huggingface( - dataset, - speed=1.0, - tokenizer="pinyin", - 
target_sample_rate=24_000, - n_mel_channels=100, - hop_length=256, - mel_spec_type="bigvgan", - target_rms=0.1, - use_truth_duration=False, - infer_batch_size=1, - interpolate_token=args.interpolate_token, - ) - - vocoder = BigVGANInference.from_pretrained( - "./bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False - ) - vocoder = vocoder.eval().to(device) - - model = get_model(args).eval().to(device) - checkpoint = torch.load(args.model_path, map_location="cpu") - if "ema_model_state_dict" in checkpoint or "model_state_dict" in checkpoint: - model = load_F5_TTS_pretrained_checkpoint(model, args.model_path) - else: - _ = load_checkpoint( - args.model_path, - model=model, - ) - - os.makedirs(args.output_dir, exist_ok=True) - - accelerator.wait_for_everyone() - start = time.time() - - with accelerator.split_between_processes(prompts_all) as prompts: - for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process): - ( - utts, - ref_rms_list, - ref_mels, - ref_mel_lens, - total_mel_lens, - final_text_list, - ) = prompt - ref_mels = ref_mels.to(device) - ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device) - total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device) - - if args.use_cosyvoice_semantic_token: - # concat final_text_list - max_len = max([len(tokens) for tokens in final_text_list]) - # pad tokens to the same length - for i, tokens in enumerate(final_text_list): - final_text_list[i] = torch.tensor( - tokens + [-1] * (max_len - len(tokens)), dtype=torch.long - ) - final_text_list = torch.stack(final_text_list).to(device) - - # Inference - with torch.inference_mode(): - generated, _ = model.sample( - cond=ref_mels, - text=final_text_list, - duration=total_mel_lens, - lens=ref_mel_lens, - steps=args.nfe, - cfg_strength=2.0, - sway_sampling_coef=args.swaysampling, - no_ref_audio=False, - seed=args.seed, - ) - for i, gen in enumerate(generated): - gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0) - gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32) - - generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu() - target_rms = 0.1 - target_sample_rate = 24_000 - if ref_rms_list[i] < target_rms: - generated_wave = generated_wave * ref_rms_list[i] / target_rms - torchaudio.save( - f"{args.output_dir}/{utts[i]}.wav", - generated_wave, - target_sample_rate, - ) - - accelerator.wait_for_everyone() - if accelerator.is_main_process: - timediff = time.time() - start - print(f"Done batch inference in {timediff / 60 :.2f} minutes.") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/wenetspeech4tts/TTS/f5-tts/llm_tts.py b/egs/wenetspeech4tts/TTS/f5-tts/llm_tts.py index bf878db513..1d0fdc5c8d 100644 --- a/egs/wenetspeech4tts/TTS/f5-tts/llm_tts.py +++ b/egs/wenetspeech4tts/TTS/f5-tts/llm_tts.py @@ -1,5 +1,6 @@ # Copyright (c) 2025 SparkAudio # 2025 Xinsheng Wang (w.xinshawn@gmail.com) +# 2025 Yuekai Zhang # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# https://github.com/SparkAudio/Spark-TTS/blob/main/cli/SparkTTS.py +# Modified from https://github.com/SparkAudio/Spark-TTS/blob/main/cli/SparkTTS.py import re from pathlib import Path @@ -39,7 +40,9 @@ def __init__( Args: model_dir (Path): Directory containing the model and config files. - device (torch.device): The device (CPU/GPU) to run the model on. + tokenizer_dir (Path): Directory containing the tokenizer files. + s3_tokenizer_name (str): Name of the tokenizer file on S3. + device (torch.device): Device to run the model on. """ self.device = device @@ -51,7 +54,9 @@ def __init__( ) tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir) - new_tokens = [f"<|s_{i}|>" for i in range(6561)] + [ + self.original_vocab_size = len(tokenizer) + self.cosyvoice2_token_vocab_size = 6561 + new_tokens = [f"<|s_{i}|>" for i in range(self.cosyvoice2_token_vocab_size)] + [ "<|SPEECH_GENERATION_START|>" ] num_added_tokens = tokenizer.add_tokens(new_tokens) @@ -67,42 +72,39 @@ def inference_batch( temperature: float = 0.8, top_k: float = 50, top_p: float = 0.95, + max_new_tokens: int = 1024, ) -> torch.Tensor: """ Performs inference to generate speech from text, incorporating prompt audio and/or text. Args: - text (str): The text input to be converted to speech. - prompt_speech_path (Path): Path to the audio file used as a prompt. - prompt_text (str, optional): Transcript of the prompt audio. - gender (str): female | male. - pitch (str): very_low | low | moderate | high | very_high - speed (str): very_low | low | moderate | high | very_high + input_ids (torch.Tensor): Input IDs for the model. + attention_mask (torch.Tensor): Attention mask for the model. temperature (float, optional): Sampling temperature for controlling randomness. Default is 0.8. top_k (float, optional): Top-k sampling parameter. Default is 50. top_p (float, optional): Top-p (nucleus) sampling parameter. Default is 0.95. + max_new_tokens (int, optional): Maximum number of tokens to generate. Default is 1024. Returns: torch.Tensor: Generated waveform as a tensor. 
""" - # Generate speech using the model generated_ids = self.model.generate( input_ids=input_ids.to(self.device), attention_mask=attention_mask.to(self.device), - max_new_tokens=1024, + max_new_tokens=max_new_tokens, do_sample=True, top_k=top_k, top_p=top_p, temperature=temperature, ) - results = [] generated_ids = generated_ids.cpu().tolist() for i in range(len(generated_ids)): assistant_index = generated_ids[i].index(self.assistant_index) padding_index = len(generated_ids[i]) + # WAR: harding coding assistant_index + 2, for the current template Assistant: \n result = generated_ids[i][assistant_index + 2 :] - result = [token - 151665 for token in result] + result = [token - self.original_vocab_size for token in result] result = [token for token in result if token >= 0] results.append(result) return results diff --git a/egs/wenetspeech4tts/TTS/f5-tts/train.py b/egs/wenetspeech4tts/TTS/f5-tts/train.py index 5333b3f277..343d0c65ca 100755 --- a/egs/wenetspeech4tts/TTS/f5-tts/train.py +++ b/egs/wenetspeech4tts/TTS/f5-tts/train.py @@ -118,6 +118,13 @@ def add_model_arguments(parser: argparse.ArgumentParser): help="Number of Decoder layers.", ) + parser.add_argument( + "--use-cosyvoice-semantic-token", + type=str2bool, + default=False, + help="Whether to use cosyvoice semantic token to replace text token.", + ) + def get_parser(): parser = argparse.ArgumentParser( @@ -313,13 +320,6 @@ def get_parser(): help="perform OOM check on dataloader batches before starting training.", ) - parser.add_argument( - "--use-cosyvoice-semantic-token", - type=str2bool, - default=False, - help="Whether to use cosyvoice semantic token to replace text token.", - ) - add_model_arguments(parser) return parser From bc6e113eddc4b1590734d213304234636b6a17c6 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 3 Mar 2025 06:03:41 +0000 Subject: [PATCH 6/8] remove run.sh --- egs/emilia/TTS/llasa_cosyvoice2_token/run.sh | 4 ---- 1 file changed, 4 deletions(-) delete mode 100644 egs/emilia/TTS/llasa_cosyvoice2_token/run.sh diff --git a/egs/emilia/TTS/llasa_cosyvoice2_token/run.sh b/egs/emilia/TTS/llasa_cosyvoice2_token/run.sh deleted file mode 100644 index a78bba96b0..0000000000 --- a/egs/emilia/TTS/llasa_cosyvoice2_token/run.sh +++ /dev/null @@ -1,4 +0,0 @@ - -WANDB_KEY=df59308c1f07be8338a87497523163014442d605 # TODO Set YOUR KEY! 
-wandb login ${WANDB_KEY} -torchrun --nproc_per_node=8 train.py config.json From c4731921e65ac1aa0787bae7c81e5793f5418e1a Mon Sep 17 00:00:00 2001 From: root Date: Mon, 3 Mar 2025 10:56:38 +0000 Subject: [PATCH 7/8] update results --- egs/emilia/TTS/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/emilia/TTS/README.md b/egs/emilia/TTS/README.md index 363ea38422..367a418990 100644 --- a/egs/emilia/TTS/README.md +++ b/egs/emilia/TTS/README.md @@ -3,7 +3,7 @@ |---------------------------------------|----------|-----------|--------| | pretrained cosyvoice2 llm | pretrained cosyvoice2 unet | 1.45% | See [paper](https://arxiv.org/abs/2412.10117)| | pretrained cosyvoice2 llm | f5-tts-small (wenetspeech4tts) | 1.79% (16 steps) | See [PR](https://github.com/k2-fsa/icefall/pull/1880)| -| llasa_cosyvoice2_token llm (Emilia 50k hours ZH) | f5-tts-small (wenetspeech4tts) | 1.89% (16 steps) | | +| llasa_cosyvoice2_token llm (Emilia 50k hours ZH) | f5-tts-small (wenetspeech4tts) | 1.81% (16 steps) | | # Introduction From 1653b76deb5bffec958d17cf5440ace4f776732f Mon Sep 17 00:00:00 2001 From: root Date: Tue, 4 Mar 2025 01:23:43 +0000 Subject: [PATCH 8/8] update readme and requirements --- egs/emilia/TTS/README.md | 18 ------------------ egs/wenetspeech4tts/TTS/README.md | 14 -------------- 2 files changed, 32 deletions(-) diff --git a/egs/emilia/TTS/README.md b/egs/emilia/TTS/README.md index 367a418990..d55ff10c30 100644 --- a/egs/emilia/TTS/README.md +++ b/egs/emilia/TTS/README.md @@ -12,22 +12,6 @@ hours of speech across six languages, covering a wide range of speaking styles t See https://arxiv.org/pdf/2407.05361. -> [!CAUTION] -> The next-gen Kaldi framework provides tools and models for generating high-quality, synthetic speech (Text-to-Speech, TTS). -> While these recipes has the potential to advance various fields such as accessibility, language education, and AI-driven solutions, it also carries certain ethical and legal responsibilities. -> -> By using this framework, you agree to the following: -> 1. Legal and Ethical Use: You shall not use this framework, or any models derived from it, for any unlawful or unethical purposes. This includes, but is not limited to: Creating voice clones without the explicit, informed consent of the individual whose voice is being cloned. Engaging in any form of identity theft, impersonation, or fraud using cloned voices. Violating any local, national, or international laws regarding privacy, intellectual property, or personal data. -> -> 2. Responsibility of Use: The users of this framework are solely responsible for ensuring that their use of voice cloning technologies complies with all applicable laws and ethical guidelines. We explicitly disclaim any liability for misuse of the technology. -> -> 3. Attribution and Use of Open-Source Components: This project is provided under the Apache 2.0 license. Users must adhere to the terms of this license and provide appropriate attribution when required. -> -> 4. No Warranty: This framework is provided “as-is,” without warranty of any kind, either express or implied. We do not guarantee that the use of this software will comply with legal requirements or that it will not infringe the rights of third parties. - - - - # Llasa (cosyvoice2 token) ./llasa_cosyvoice2_token contains the code for training qwen2.5-0.5b models to predict cosyvoice2 semantic tokens. 
@@ -48,9 +32,7 @@ The training command is given below: ``` # docker: ghcr.io/swivid/f5-tts:main -# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html # pip install -r llasa_cosyvoice2_token/requirements.txt -# pip install -r icefall/egs/wenetspeech4tts/TTS/f5-tts/requirements.txt WANDB_KEY=$your_wandb_key wandb login ${WANDB_KEY} diff --git a/egs/wenetspeech4tts/TTS/README.md b/egs/wenetspeech4tts/TTS/README.md index 9a48bd1969..f1c57d853d 100644 --- a/egs/wenetspeech4tts/TTS/README.md +++ b/egs/wenetspeech4tts/TTS/README.md @@ -9,20 +9,6 @@ [**WenetSpeech4TTS**](https://huggingface.co/datasets/Wenetspeech4TTS/WenetSpeech4TTS) is a multi-domain **Mandarin** corpus derived from the open-sourced [WenetSpeech](https://arxiv.org/abs/2110.03370) dataset. -> [!CAUTION] -> The next-gen Kaldi framework provides tools and models for generating high-quality, synthetic speech (Text-to-Speech, TTS). -> While these recipes has the potential to advance various fields such as accessibility, language education, and AI-driven solutions, it also carries certain ethical and legal responsibilities. -> -> By using this framework, you agree to the following: -> 1. Legal and Ethical Use: You shall not use this framework, or any models derived from it, for any unlawful or unethical purposes. This includes, but is not limited to: Creating voice clones without the explicit, informed consent of the individual whose voice is being cloned. Engaging in any form of identity theft, impersonation, or fraud using cloned voices. Violating any local, national, or international laws regarding privacy, intellectual property, or personal data. -> -> 2. Responsibility of Use: The users of this framework are solely responsible for ensuring that their use of voice cloning technologies complies with all applicable laws and ethical guidelines. We explicitly disclaim any liability for misuse of the technology. -> -> 3. Attribution and Use of Open-Source Components: This project is provided under the Apache 2.0 license. Users must adhere to the terms of this license and provide appropriate attribution when required. -> -> 4. No Warranty: This framework is provided “as-is,” without warranty of any kind, either express or implied. We do not guarantee that the use of this software will comply with legal requirements or that it will not infringe the rights of third parties. - - # [VALL-E](https://arxiv.org/abs/2301.02111) ./valle contains the code for training VALL-E TTS model.