Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions src/f5_tts/model/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,9 @@ def __init__(
elif self.logger == "tensorboard":
from torch.utils.tensorboard import SummaryWriter

self.writer = SummaryWriter(log_dir=f"runs/{wandb_run_name}")
self.writer = None
if self.accelerator.is_main_process:
self.writer = SummaryWriter(log_dir=f"runs/{wandb_run_name}")

self.model = model

Expand Down Expand Up @@ -392,9 +394,9 @@ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int
self.accelerator.log(
{"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_update
)
if self.logger == "tensorboard":
self.writer.add_scalar("loss", loss.item(), global_update)
self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_update)
if self.logger == "tensorboard" and self.accelerator.is_main_process:
self.writer.add_scalar("loss", loss.item(), global_update)
self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_update)

if global_update % self.last_per_updates == 0 and self.accelerator.sync_gradients:
self.save_checkpoint(global_update, last=True)
Expand Down
5 changes: 3 additions & 2 deletions src/f5_tts/train/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,12 @@ python src/f5_tts/train/datasets/prepare_libritts.py
python src/f5_tts/train/datasets/prepare_ljspeech.py
```

### 2. Create custom dataset with metadata.csv
### 2. Create custom dataset with CSV
Prepare a CSV with two columns using a required header: `audio_file|text`. Audio paths must be absolute.
For guidance, see [#57 here](https://github.com/SWivid/F5-TTS/discussions/57#discussioncomment-10959029).
Copy link

Copilot AI Jan 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The sentence "Use guidance see" is grammatically incorrect. Consider revising to "For guidance, see" or "See usage guidance at".

Suggested change
Use guidance see [#57 here](https://github.com/SWivid/F5-TTS/discussions/57#discussioncomment-10959029).
For guidance, see [#57 here](https://github.com/SWivid/F5-TTS/discussions/57#discussioncomment-10959029).

Copilot uses AI. Check for mistakes.

```bash
python src/f5_tts/train/datasets/prepare_csv_wavs.py
python src/f5_tts/train/datasets/prepare_csv_wavs.py /path/to/metadata.csv /path/to/output
```

## Training & Finetuning
Expand Down
140 changes: 78 additions & 62 deletions src/f5_tts/train/datasets/prepare_csv_wavs.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,22 @@
"""
Usage:
python prepare_csv_wavs.py /path/to/metadata.csv /output/dataset/path [--pretrain] [--workers N]

CSV format (header required, "|" delimiter):
audio_file|text
/path/to/wavs/audio_0001.wav|Yo! Hello? Hello?
/path/to/wavs/audio_0002.wav|Hi, how are you doing today? I want to go shopping and buy me some lemons.

Notes:
- audio_file must be an absolute path.
Copy link

Copilot AI Jan 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The module docstring states "audio_file must be an absolute path" but the actual code allows tilde expansion (expanduser). This creates an inconsistency between the documentation and implementation. If the intent is to allow tilde paths, the documentation should reflect this (e.g., "audio_file must be an absolute path or use tilde notation"). If not, the code should validate before expansion.

Suggested change
- audio_file must be an absolute path.
- audio_file should be a valid filesystem path (e.g., absolute, relative to the current working directory, or using ~ for the home directory).

Copilot uses AI. Check for mistakes.
"""

import concurrent.futures
import multiprocessing
import os
import shutil
import signal
import subprocess # For invoking ffprobe
import subprocess
import sys
from contextlib import contextmanager

Expand All @@ -16,6 +29,7 @@
from importlib.resources import files
from pathlib import Path

import soundfile as sf
Copy link

Copilot AI Jan 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new soundfile import is added at line 32 but there is no validation that soundfile is installed. While soundfile will be tried first in get_audio_duration, if it's not installed, every single audio file will trigger an exception and print a warning before falling back to ffprobe. This could result in significant console spam for large datasets. Consider adding an import check at module level or checking once at the start of processing whether soundfile is available.

Suggested change
import soundfile as sf
try:
import soundfile as sf
_SOUND_FILE_AVAILABLE = True
except ImportError:
sf = None # type: ignore[assignment]
_SOUND_FILE_AVAILABLE = False

Copilot uses AI. Check for mistakes.
import torchaudio
from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm
Expand All @@ -25,23 +39,19 @@

PRETRAINED_VOCAB_PATH = files("f5_tts").joinpath("../../data/Emilia_ZH_EN_pinyin/vocab.txt")


def is_csv_wavs_format(input_dataset_dir):
    """Return True if *input_dataset_dir* is a dataset directory in csv_wavs
    layout: it must contain a ``metadata.csv`` file and a ``wavs/`` subdirectory.
    """
    root = Path(input_dataset_dir)
    # `is_file()` / `is_dir()` already return False for non-existent paths,
    # so a separate existence check is unnecessary.
    if not (root / "metadata.csv").is_file():
        return False
    return (root / "wavs").is_dir()


# Configuration constants
BATCH_SIZE = 100 # Batch size for text conversion
MAX_WORKERS = max(1, multiprocessing.cpu_count() - 1) # Leave one CPU free
THREAD_NAME_PREFIX = "AudioProcessor"
CHUNK_SIZE = 100 # Number of files to process per worker batch

executor = None # Global executor for cleanup


def is_csv_wavs_format(input_path):
    """Return True when *input_path* points at an existing ``.csv`` file.

    ``~`` is expanded before checking, and the extension comparison is
    case-insensitive (``.CSV`` is accepted).
    """
    candidate = Path(input_path).expanduser()
    if not candidate.is_file():
        return False
    return candidate.suffix.lower() == ".csv"


@contextmanager
def graceful_exit():
"""Context manager for graceful shutdown on signals"""
Expand Down Expand Up @@ -82,22 +92,27 @@ def process_audio_file(audio_path, text, polyphone):
def batch_convert_texts(texts, polyphone, batch_size=BATCH_SIZE):
"""Convert a list of texts to pinyin in batches."""
converted_texts = []
for i in range(0, len(texts), batch_size):
for i in tqdm(
range(0, len(texts), batch_size),
total=(len(texts) + batch_size - 1) // batch_size,
desc="Converting texts to pinyin",
):
batch = texts[i : i + batch_size]
converted_batch = convert_char_to_pinyin(batch, polyphone=polyphone)
converted_texts.extend(converted_batch)
return converted_texts


def prepare_csv_wavs_dir(input_dir, num_workers=None):
def prepare_csv_wavs_dir(input_path, num_workers=None):
Copy link

Copilot AI Jan 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function is named prepare_csv_wavs_dir which suggests it expects a directory, but it now accepts a CSV file path and the parameter is named input_path. Consider renaming the function to better reflect its actual purpose, such as prepare_csv_wavs or prepare_csv_dataset, since it no longer operates on a directory containing metadata.csv and wavs subdirectory.

Copilot uses AI. Check for mistakes.
global executor
assert is_csv_wavs_format(input_dir), f"not csv_wavs format: {input_dir}"
input_dir = Path(input_dir)
metadata_path = input_dir / "metadata.csv"
audio_path_text_pairs = read_audio_text_pairs(metadata_path.as_posix())
if not is_csv_wavs_format(input_path):
raise ValueError(f"input must be a .csv file: {input_path}")
audio_path_text_pairs = read_audio_text_pairs(Path(input_path).expanduser().as_posix())

polyphone = True
total_files = len(audio_path_text_pairs)
if total_files == 0:
raise RuntimeError("No valid rows found in CSV.")

# Use provided worker count or calculate optimal number
worker_count = num_workers if num_workers is not None else min(MAX_WORKERS, total_files)
Expand Down Expand Up @@ -155,10 +170,12 @@ def prepare_csv_wavs_dir(input_dir, num_workers=None):


def get_audio_duration(audio_path, timeout=5):
"""
Get the duration of an audio file in seconds using ffmpeg's ffprobe.
Falls back to torchaudio.load() if ffprobe fails.
"""
"""Get the duration of an audio file in seconds with fallbacks."""
Copy link

Copilot AI Jan 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The timeout parameter is still defined in the function signature but is now only used for the ffprobe fallback (not for soundfile). The soundfile.info() call has no timeout, which means it could potentially hang indefinitely on corrupted or problematic audio files. Consider either removing the timeout parameter if it's no longer needed, or documenting that it only applies to the ffprobe fallback.

Suggested change
"""Get the duration of an audio file in seconds with fallbacks."""
"""
Get the duration of an audio file in seconds with fallbacks.
Note:
The ``timeout`` parameter applies only to the ``ffprobe`` subprocess
fallback, not to the initial ``soundfile.info()`` call (which has no
built-in timeout) or the ``torchaudio.info()`` fallback.
"""

Copilot uses AI. Check for mistakes.
try:
return sf.info(audio_path).duration
Comment on lines +174 to +175
Copy link

Copilot AI Jan 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new implementation tries soundfile first, then ffprobe, then torchaudio.info as fallbacks. However, soundfile.info() may load or partially decode the audio file, which could be slower than ffprobe for certain formats. The original implementation used ffprobe first (which is typically fast as it just reads metadata) before falling back to loading the actual audio. Consider whether soundfile.info() is actually faster than ffprobe for the expected audio formats, or if the fallback order should be reconsidered.

Copilot uses AI. Check for mistakes.
except Exception as e:
print(f"Warning: soundfile failed for {audio_path} with error: {e}. Falling back to ffprobe.")

try:
cmd = [
"ffprobe",
Expand All @@ -178,27 +195,39 @@ def get_audio_duration(audio_path, timeout=5):
return float(duration_str)
raise ValueError("Empty duration string from ffprobe.")
except (subprocess.TimeoutExpired, subprocess.SubprocessError, ValueError) as e:
print(f"Warning: ffprobe failed for {audio_path} with error: {e}. Falling back to torchaudio.")
try:
audio, sample_rate = torchaudio.load(audio_path)
return audio.shape[1] / sample_rate
except Exception as e:
raise RuntimeError(f"Both ffprobe and torchaudio failed for {audio_path}: {e}")
print(f"Warning: ffprobe failed for {audio_path} with error: {e}. Falling back to torchaudio.info.")

try:
info = torchaudio.info(audio_path)
if info.sample_rate > 0:
return info.num_frames / info.sample_rate
raise ValueError("Invalid sample_rate from torchaudio.info.")
except Exception as e:
raise RuntimeError(f"failed to get duration for {audio_path}: {e}")
Copy link

Copilot AI Jan 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error message uses lowercase "failed" but similar RuntimeError messages in Python typically start with an uppercase letter for consistency with standard exception formatting. Consider capitalizing the first letter: "Failed to get duration for".

Suggested change
raise RuntimeError(f"failed to get duration for {audio_path}: {e}")
raise RuntimeError(f"Failed to get duration for {audio_path}: {e}")

Copilot uses AI. Check for mistakes.


def read_audio_text_pairs(csv_file_path):
audio_text_pairs = []

parent = Path(csv_file_path).parent
with open(csv_file_path, mode="r", newline="", encoding="utf-8-sig") as csvfile:
csv_path = Path(csv_file_path).expanduser().absolute()
with open(csv_path.as_posix(), mode="r", newline="", encoding="utf-8-sig") as csvfile:
reader = csv.reader(csvfile, delimiter="|")
next(reader) # Skip the header row
for row in reader:
if len(row) >= 2:
audio_file = row[0].strip() # First column: audio file path
text = row[1].strip() # Second column: text
audio_file_path = parent / audio_file
audio_text_pairs.append((audio_file_path.as_posix(), text))
header = next(reader, None)
if header is None:
return audio_text_pairs
if len(header) < 2 or header[0].strip() != "audio_file" or header[1].strip() != "text":
raise ValueError("CSV header must be: audio_file|text")
for row_idx, row in enumerate(reader, start=2):
if len(row) < 2:
continue
audio_file = row[0].strip()
text = row[1].strip()
if not audio_file:
continue
Copy link

Copilot AI Jan 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code validates that audio_file is not empty (line 221-222) but does not validate that text is not empty. Empty text entries could cause issues downstream in the text processing pipeline. Consider adding validation for empty text values similar to the audio_file check.

Suggested change
continue
continue
if not text:
continue

Copilot uses AI. Check for mistakes.
audio_path = Path(audio_file).expanduser()
if not audio_path.is_absolute():
raise ValueError(f"audio_file must be an absolute path (row {row_idx}): {audio_file}")
Comment on lines +227 to +229
Copy link

Copilot AI Jan 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The absolute path validation occurs after calling expanduser() on the path. This means that paths like "~/audio.wav" will first be expanded to an absolute path (e.g., "/home/user/audio.wav") and will pass the is_absolute() check, even though the original path in the CSV was not absolute. This contradicts the stated requirement that audio paths in the CSV must be absolute. Consider checking if the path is absolute before calling expanduser() to enforce the documented requirement strictly.

Suggested change
audio_path = Path(audio_file).expanduser()
if not audio_path.is_absolute():
raise ValueError(f"audio_file must be an absolute path (row {row_idx}): {audio_file}")
audio_path = Path(audio_file)
if not audio_path.is_absolute():
raise ValueError(f"audio_file must be an absolute path (row {row_idx}): {audio_file}")
audio_path = audio_path.expanduser()

Copilot uses AI. Check for mistakes.
audio_text_pairs.append((audio_path.as_posix(), text))

return audio_text_pairs

Expand Down Expand Up @@ -242,35 +271,22 @@ def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True, num_workers
save_prepped_dataset(out_dir, sub_result, durations, vocab_set, is_finetune)


def get_args():
parser = argparse.ArgumentParser(description="Prepare and save dataset.")
parser.add_argument(
"inp_dir",
type=str,
help="Input CSV with header 'audio_file|text' and absolute wav paths.",
Copy link

Copilot AI Jan 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The help text states "absolute wav paths" but the code accepts any audio file format that soundfile, ffprobe, or torchaudio can handle (not just WAV files). Consider using more generic terminology like "absolute audio file paths" to avoid confusion.

Copilot uses AI. Check for mistakes.
)
parser.add_argument("out_dir", type=str, help="Output directory to save the prepared data.")
parser.add_argument("--pretrain", action="store_true", help="Enable for new pretrain, otherwise is a fine-tune")
parser.add_argument("--workers", type=int, help=f"Number of worker threads (default: {MAX_WORKERS})")
return parser.parse_args()


def cli():
try:
# Before processing, check if ffprobe is available.
if shutil.which("ffprobe") is None:
print(
"Warning: ffprobe is not available. Duration extraction will rely on torchaudio (which may be slower)."
)

# Usage examples in help text
parser = argparse.ArgumentParser(
description="Prepare and save dataset.",
epilog="""
Examples:
# For fine-tuning (default):
python prepare_csv_wavs.py /input/dataset/path /output/dataset/path

# For pre-training:
python prepare_csv_wavs.py /input/dataset/path /output/dataset/path --pretrain

# With custom worker count:
python prepare_csv_wavs.py /input/dataset/path /output/dataset/path --workers 4
""",
)
parser.add_argument("inp_dir", type=str, help="Input directory containing the data.")
parser.add_argument("out_dir", type=str, help="Output directory to save the prepared data.")
parser.add_argument("--pretrain", action="store_true", help="Enable for new pretrain, otherwise is a fine-tune")
parser.add_argument("--workers", type=int, help=f"Number of worker threads (default: {MAX_WORKERS})")
args = parser.parse_args()

args = get_args()
prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain, num_workers=args.workers)
except KeyboardInterrupt:
print("\nOperation cancelled by user. Cleaning up...")
Expand Down