diff --git a/src/fairseq2/datasets/speech.py b/src/fairseq2/datasets/speech.py
index f56b156bd..a67b08e46 100644
--- a/src/fairseq2/datasets/speech.py
+++ b/src/fairseq2/datasets/speech.py
@@ -8,32 +8,131 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
+from functools import partial, reduce
 from pathlib import Path
-from typing import Final, final
+from typing import Any, Dict, Final, List, final
 
+import numpy as np
 import torch
+from torch import Tensor
+from torch.nn.functional import layer_norm
 from typing_extensions import override
 
+from fairseq2.data import (
+    Collater,
+    DataPipelineBuilder,
+    FileMapper,
+    create_bucket_sizes,
+)
+from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
+from fairseq2.data.text import StrSplitter, read_text
 from fairseq2.datasets import (
     DataPipelineReader,
     DataReader,
+    DataReadError,
     DataReadOptions,
     DatasetHubAccessor,
+    DatasetLoadError,
+    LengthBatching,
+    StaticBatching,
+    UnknownSplitError,
 )
 from fairseq2.error import NotSupportedError
 from fairseq2.gang import Gang
+from fairseq2.logging import log
 from fairseq2.models.sequence import SequenceBatch
-from fairseq2.typing import DataType
+from fairseq2.nn.padding import get_seqs_and_padding_mask
+from fairseq2.typing import DataType, Device
+
+
+@torch.no_grad()
+def postprocess(waveform: Tensor, normalize_audio: bool, dtype: DataType) -> Tensor:
+    if waveform.dim() == 2:
+        # Reduce the channels in place to save memory.
+        size = waveform.size(1)
+        result = reduce(
+            torch.Tensor.add_, [waveform[:, i] for i in range(1, size)], waveform[:, 0]
+        )
+        waveform = result
+        waveform /= size
+
+    if normalize_audio:
+        waveform = layer_norm(waveform, waveform.shape)
+
+    return waveform.to(dtype)
+
+
+class AudioCropper:
+
+    audio_feature: str = "audio_feature"
+
+    def __init__(
+        self, max_audio_len: int, seed: int, crop_to_batch_minimal_size: bool = False
+    ) -> None:
+        self.rng: np.random.RandomState = np.random.RandomState(seed)
+        self.max_audio_len: int = max_audio_len
+        self.crop_to_batch_minimal_size: bool = crop_to_batch_minimal_size
+
+    def crop_audios_in_batch(self, batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+        if self.crop_to_batch_minimal_size:
+            min_audio_len_batch = min(
+                (item[self.audio_feature].size(0) for item in batch)
+            )
+            crop_size = min(self.max_audio_len, min_audio_len_batch)
+        else:
+            crop_size = self.max_audio_len
+
+        for item in batch:
+            audio = item[self.audio_feature]
+            audio_size = audio.size(0)
+            if audio_size > crop_size:
+                start = self.rng.randint(0, audio_size - crop_size + 1)
+                item[self.audio_feature] = audio[start : start + crop_size]
+        return batch
+
+
+def rename_feature(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    for example in batch:
+        if "fbank" in example["audio"]["data"]:
+            example["audio_feature"] = example["audio"]["data"].pop("fbank")
+        elif "waveform" in example["audio"]["data"]:
+            example["audio_feature"] = example["audio"]["data"].pop("waveform")
+    return batch
+
+
+def to_batch(
+    example: dict[str, Any], no_padding: bool, device: Device
+) -> SequenceBatch:
+    audio_feature = example["audio_feature"]
+    if no_padding:
+        seqs = audio_feature.to(device)
+        padding_mask = None
+    else:
+        seqs, padding_mask = get_seqs_and_padding_mask(audio_feature, device=device)
+
+    return SequenceBatch(seqs, padding_mask, example=example)
 
 
 @dataclass(kw_only=True)
 class SpeechReadOptions(DataReadOptions):
+
     dtype: DataType = torch.float32
     """The data type of the decoded audio sequences."""
 
     normalize_audio: bool = False
     """If ``True``, normalizes audio to have zero mean and unit variance."""
 
+    use_fbank: bool = False
+    """If ``True``, uses fbank features instead of the waveform."""
+
+    no_padding: bool = True
+    """If ``True``, all elements in the batch will be cropped to the length of
+    the shortest element, so no padding will be applied to the batch.
+    """
+
+    npc: int = 10
+    """The number of parallel calls to use in the pipeline."""
+
 
 class SpeechDataset(ABC):
     """Represents a speech dataset."""
 
@@ -45,7 +144,7 @@ def create_reader(
         gang: Gang,
         min_audio_len: int,
         max_audio_len: int,
-        options: SpeechReadOptions | None = None,
+        options: SpeechReadOptions = SpeechReadOptions(),
     ) -> DataReader[SequenceBatch]:
         """Create a dataset reader.
 
@@ -69,15 +168,96 @@ def splits(self) -> set[str]:
 
 
 GENERIC_SPEECH_DATASET_FAMILY: Final = "generic_speech"
 
+get_speech_dataset_hub = DatasetHubAccessor(SpeechDataset)
+
 
 @final
 class GenericSpeechDataset(SpeechDataset):
     """Represents a generic manifest-based Speech dataset."""
 
+    _name: str
+    _manifest_dir: Path
+    _splits: set[str]
+
+    def __init__(self, name: str, manifest_dir: Path, splits: set[str]) -> None:
+        """
+        :param manifest_dir:
+            The directory under which the manifest files reside.
+        :param splits:
+            The available splits.
+        """
+        self._name = name
+        self._manifest_dir = manifest_dir
+        self._splits = splits
+
     @staticmethod
     def from_path(path: Path, name: str) -> GenericSpeechDataset:
-        return GenericSpeechDataset()
+        path = path.expanduser().resolve()
+
+        if not path.is_dir():
+            return GenericSpeechDataset(
+                name, manifest_dir=path.parent, splits={path.stem}
+            )
+
+        try:
+            splits = {f.stem for f in path.glob("*.tsv")}
+        except OSError as ex:
+            raise DatasetLoadError(
+                name, f"The splits under the '{path}' directory of the '{name}' dataset cannot be determined. See the nested exception for details."  # fmt: skip
+            ) from ex
+
+        return GenericSpeechDataset(name, path, splits)
+
+    @override
+    def splits(self) -> set[str]:
+        return self._splits
+
+    def _retrieve_data_directory(self, split: str) -> Path:
+        manifest_file = self._manifest_dir.joinpath(f"{split}.tsv")
+
+        try:
+            with manifest_file.open(encoding="utf-8") as fp:
+                line = fp.readline().rstrip()
+        except OSError as ex:
+            raise DataReadError(
+                self._name, split, f"The {manifest_file} manifest file cannot be read. See the nested exception for details."  # fmt: skip
+            ) from ex
+
+        try:
+            return Path(line)
+        except ValueError:
+            raise DataReadError(
+                self._name, split, f"The first line of the '{manifest_file}' manifest file must point to a data directory."  # fmt: skip
+            ) from None
+
+    def _read_manifest(
+        self, split: str, max_audio_len: int, min_audio_len: int, audio_dir: Path | None
+    ) -> DataPipelineBuilder:
+        """
+        Only the ``min_audio_len`` filter is applied here; audios longer than
+        ``max_audio_len`` are cropped later by :class:`AudioCropper`.
+        """
+        tsv_file = self._manifest_dir.joinpath(f"{split}.tsv")
+
+        builder = read_text(
+            tsv_file, rtrim=True, memory_map=True, block_size=10 * 1024 * 1024
+        )
+
+        if audio_dir is not None:
+            builder.skip(1)  # Path to the data directory.
+
+        field_splitter = StrSplitter(names=["audio", "audio_size"])
+
+        builder.map(field_splitter)
+
+        builder.map(
+            lambda x: min(int(x), max_audio_len),
+            selector="audio_size",
+        )
+
+        builder.filter(lambda sample: sample["audio_size"] >= min_audio_len)
+
+        return builder
 
     @override
     def create_reader(
@@ -88,11 +268,142 @@
         max_audio_len: int,
         options: SpeechReadOptions | None = None,
     ) -> DataPipelineReader[SequenceBatch]:
-        raise NotSupportedError("not supported yet.")
 
-    @override
-    def splits(self) -> set[str]:
-        return set()
+        if split not in self._splits:
+            raise UnknownSplitError(self._name, split, self._splits)
 
+        if options is None:
+            options = SpeechReadOptions()
 
-get_speech_dataset_hub = DatasetHubAccessor(SpeechDataset)
+        log.info(
+            f"Creating a reader for the <{split}> split of the <{self._name}>"
+            f" dataset with the following options:\n {options}."
+        )
+
+        seed = options.seed
+        npc = options.npc
+        no_padding = options.no_padding
+
+        audio_dir = self._retrieve_data_directory(split)
+        builder = self._read_manifest(split, max_audio_len, min_audio_len, audio_dir)
+
+        if options.example_shuffle_window != 1:
+            # builder.prefetch(options.example_shuffle_window * options.num_prefetch)
+            builder.shuffle(options.example_shuffle_window, seed)
+            seed += 1
+
+        # Shard.
+        builder.shard(gang.rank, gang.size, allow_uneven=True)
+        seed += gang.rank
+
+        batching = options.batching
+
+        if isinstance(batching, LengthBatching):
+            # Bucket by the audio length.
+            max_num_elements = batching.max_num_elements
+            num_seqs_multiple_of = options.extras.get("num_seqs_multiple_of", 8)
+            assert isinstance(
+                num_seqs_multiple_of, int
+            ), "num_seqs_multiple_of must be an integer"
+            assert num_seqs_multiple_of > 0, "num_seqs_multiple_of must be positive"
+
+            if max_num_elements % max_audio_len != 0:
+                max_num_elements = (max_num_elements // max_audio_len) * max_audio_len
+                log.warning(f"`max_num_elements` is rounded down to {max_num_elements}.")
+
+            bucket_sizes = create_bucket_sizes(
+                min_seq_len=min_audio_len,
+                max_seq_len=max_audio_len,
+                max_num_elements=max_num_elements,
+                num_seqs_multiple_of=num_seqs_multiple_of,
+            )
+
+            builder.bucket_by_length(
+                bucket_sizes,
+                selector="audio_size",
+                min_data_len=min_audio_len,
+                skip_below_min_examples=True,  # should be a no-op given the filter above
+                skip_above_max_examples=True,  # should be a no-op given the clamping above
+                drop_remainder=options.drop_remainder,
+            )
+        elif isinstance(batching, StaticBatching):
+            if options.no_padding:
+                raise NotSupportedError(
+                    "no_padding is not supported for static batching"
+                )
+            builder.bucket(batching.batch_size, drop_remainder=options.drop_remainder)
+        else:
+            raise NotSupportedError(f"`{batching}` is not supported.")
+
+        # Shuffle buckets.
+        if options.batch_shuffle_window != 1:
+            builder.shuffle(options.batch_shuffle_window, seed)
+            seed += 1
+
+        # Memory map audio files.
+        cached_fd_count = options.extras.get("cached_fd_count", 100)
+        if not isinstance(cached_fd_count, int):
+            raise TypeError(
+                f"`options.extras['cached_fd_count']` must be of type `int`, but is of type `{type(cached_fd_count)}` instead."
+            )
+
+        file_mapper = FileMapper(audio_dir, cached_fd_count=cached_fd_count)
+
+        builder.map(file_mapper, selector="[*].audio")
+
+        # Decode audio.
+        audio_decoder = AudioDecoder(
+            dtype=torch.float32 if options.normalize_audio else options.dtype
+        )
+        builder.map(audio_decoder, selector="[*].audio.data", num_parallel_calls=npc)
+
+        if options.use_fbank:
+            fbank_converter = WaveformToFbankConverter(
+                num_mel_bins=80,
+                waveform_scale=2**15,
+                channel_last=True,
+                standardize=True,
+                dtype=options.dtype,
+            )
+
+            builder.map(
+                fbank_converter, selector="[*].audio.data", num_parallel_calls=npc
+            )
+        else:
+            builder.map(
+                partial(
+                    postprocess,
+                    normalize_audio=options.normalize_audio,
+                    dtype=options.dtype,
+                ),
+                selector="[*].audio.data.waveform",
+            )
+
+        # Move the audio feature to the top level of each example.
+        builder.map(rename_feature)
+
+        # Crop long audios to `max_audio_len`.
+        audio_cropper = AudioCropper(
+            max_audio_len,
+            seed=seed,
+            crop_to_batch_minimal_size=no_padding,
+        )
+        builder.map(audio_cropper.crop_audios_in_batch)
+
+        # Collate batched examples into a batch.
+        collater = Collater(pad_value=None if no_padding else 0)
+        builder.map(collater)
+
+        # Return only the first `max_num_batches`.
+        if options.max_num_batches is not None:
+            builder.take(options.max_num_batches)
+
+        builder.prefetch(options.num_prefetch)
+
+        pipeline = builder.map(
+            partial(to_batch, no_padding=no_padding, device=gang.device)
+        ).and_return()
+
+        return DataPipelineReader[SequenceBatch](
+            self._name, split, pipeline, gang, options
+        )
diff --git a/src/fairseq2/datasets/speech_parquet.py b/src/fairseq2/datasets/speech_parquet.py
new file mode 100644
index 000000000..edf652b59
--- /dev/null
+++ b/src/fairseq2/datasets/speech_parquet.py
@@ -0,0 +1,329 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from functools import partial
+from pathlib import Path
+from typing import Any, Dict, Final, List, Set, final
+
+import numpy as np
+import pyarrow as pa
+import pyarrow.parquet as pq
+import torch
+from numpy.typing import NDArray
+from typing_extensions import override
+
+from fairseq2.data import (
+    Collater,
+    MemoryBlock,
+    create_bucket_sizes,
+    read_sequence,
+)
+from fairseq2.data.audio import (
+    AudioDecoder,
+    AudioDecoderOutput,
+    WaveformToFbankConverter,
+)
+from fairseq2.data.parquet import (
+    FragmentLoadingConfig,
+    FragmentStreamingConfig,
+    NamedColumns,
+    ParquetFragmentLoader,
+    ParquetFragmentStreamer,
+)
+from fairseq2.datasets import (
+    DataPipelineReader,
+    LengthBatching,
+    StaticBatching,
+    UnknownSplitError,
+)
+from fairseq2.datasets.speech import (
+    AudioCropper,
+    SpeechDataset,
+    SpeechReadOptions,
+    postprocess,
+    to_batch,
+)
+from fairseq2.error import NotSupportedError
+from fairseq2.gang import Gang
+from fairseq2.logging import log
+from fairseq2.models.sequence import SequenceBatch
+
+
+def rename_feature(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    for example in batch:
+        if "fbank" in example["audio"]:
+            example["audio_feature"] = example["audio"].pop("fbank")
+        elif "waveform" in example["audio"]:
+            example["audio_feature"] = example["audio"].pop("waveform")
+    return batch
+
+
+PARQUET_SPEECH_DATASET_FAMILY: Final = "generic_parquet_speech"
+
+
+@dataclass
+class DefaultAudioSchema(NamedColumns):
+    audio: str = "audio_bytes"
+    length: str = "length"
+    size: str | None = "size"  # size of the audio in bytes: len(audio_bytes)
+    extra_columns: List[str] | None = None
+
+
+@final
+class GenericSpeechParquetDataset(SpeechDataset):
+    """Represents a generic Parquet-based speech dataset."""
+
+    _name: str
+    _dataset: pq.ParquetDataset
+    _splits: set[str]
+    split_column: str = "split"
+
+    nb_samples_per_fragment = 1000  # FIXME: detect it from the dataset metadata
+    max_num_batches: int = 1000
+    max_num_examples: int = 2_000_000
+    pa_cpu_count: int = 20
+
+    def __init__(self, name: str, dataset: pq.ParquetDataset, splits: set[str]) -> None:
+        self._dataset = dataset
+        self._splits = splits
+        self._name = name
+
+    @classmethod
+    def from_path(
+        cls,
+        path: Path | str | List[str | Path],
+        name: str,
+        filesystem: Any | None = None,
+    ) -> GenericSpeechParquetDataset:
+
+        # from stopes.fb_config import get_filesystem_from_path
+        # path, filesystem = get_filesystem_from_path(path)  # type: ignore
+        dataset = pq.ParquetDataset(path, filesystem=filesystem)  # type: ignore
+
+        assert isinstance(dataset, pq.ParquetDataset)
+        partition_columns: List[str] = []
+        if dataset.partitioning is not None:
+            partition_columns = dataset.partitioning.schema.names
+
+        splits: Set[str] = set()
+        if dataset.partitioning is not None and cls.split_column in partition_columns:
+            idx = partition_columns.index(cls.split_column)
+            _splits = dataset.partitioning.dictionaries[idx]
+            if _splits is None:
+                splits = set()
+            else:
+                splits = set(_splits.to_pylist())
+
+        return GenericSpeechParquetDataset(name, dataset, splits)
+
+    @override
+    def splits(self) -> set[str]:
+        return self._splits
+
+    @override
+    def create_reader(
+        self,
+        split: str,
+        gang: Gang,
+        min_audio_len: int,
+        max_audio_len: int,
+        options: SpeechReadOptions | None = None,
+    ) -> DataPipelineReader[SequenceBatch]:
+        assert min_audio_len <= max_audio_len, "min_audio_len must be <= max_audio_len"
+
+        if split not in self._splits:
+            raise UnknownSplitError(self._name, split, self._splits)
+
+        if options is None:
+            options = SpeechReadOptions()
+
+        log.info(
+            f"Creating a reader for the <{split}> split of the <{self._name}>"
+            f" dataset with the following options:\n {options}."
+        )
+
+        seed = options.seed
+        npc = options.npc
+        no_padding = options.no_padding
+
+        pa_cpu_count = int(options.extras.get("pa_cpu_count", self.pa_cpu_count))  # type: ignore
+        pa.set_cpu_count(pa_cpu_count)
+        pa.set_io_thread_count(pa_cpu_count)
+
+        # Stream Parquet fragments.
+        partition_filters = options.extras.get("partition_filters", None)
+        parquet_files: List[str] = self._dataset.files  # type: ignore
+        fragment_config = FragmentStreamingConfig(
+            parquet_path=parquet_files,
+            filesystem=self._dataset.filesystem,
+            nb_epochs=(None if split == "train" else 1),
+            partition_filters=partition_filters,  # type: ignore
+            split_to_row_groups=True,
+            files_circular_shift=True,
+            seed=seed,
+            fragment_shuffle_window=max(
+                100, options.example_shuffle_window // self.nb_samples_per_fragment
+            ),
+        )
+        fragment_builder = ParquetFragmentStreamer(
+            config=fragment_config
+        ).build_pipeline(rank=gang.rank, world_size=gang.size)
+
+        seed += gang.rank
+
+        # Give less compute to the fragment loader than to the audio decoder.
+        num_parallel_fragments = options.extras.get(
+            "num_parallel_fragments", max(npc // 3, 1)
+        )
+        assert isinstance(num_parallel_fragments, int)
+        assert num_parallel_fragments > 0, "num_parallel_fragments must be > 0"
+
+        columns = options.extras.get("columns", DefaultAudioSchema())
+        if columns is not None:
+            assert isinstance(columns, NamedColumns)
+
+        loading_config = FragmentLoadingConfig(
+            columns=columns,
+            add_fragment_traces=False,
+            num_parallel_fragments=num_parallel_fragments,
+            nb_prefetch=options.num_prefetch,
+            non_deterministic_read=True,
+            drop_null=False,
+        )
+
+        # Load the fragments into memory.
+        builder = ParquetFragmentLoader(config=loading_config).apply(fragment_builder)
+
+        # Dispatch each table into individual examples.
+        builder = builder.yield_from(
+            lambda table: read_sequence(
+                table.to_pandas().to_dict(orient="records")
+            ).and_return()
+        )
+
+        # Drop short audios and clamp the recorded length to `max_audio_len`.
+        builder = builder.filter(lambda x: int(x["length"]) >= min_audio_len)
+        builder = builder.map(lambda x: min(max_audio_len, x), selector="length")
+
+        # Shuffle examples in memory.
+        if options.example_shuffle_window != 1:
+            # FIXME: make it configurable; we need some upper bound to avoid CPU OOM
+            example_shuffle_window = min(
+                options.example_shuffle_window, self.max_num_examples
+            )
+            builder = builder.prefetch(int(2 * example_shuffle_window))
+            builder = builder.shuffle(example_shuffle_window, seed=seed)
+            seed += 1
+
+        batching = options.batching
+
+        if isinstance(batching, LengthBatching):
+            # Bucket by the audio length.
+            max_num_elements = batching.max_num_elements
+
+            num_seqs_multiple_of = options.extras.get("num_seqs_multiple_of", 8)
+            assert isinstance(
+                num_seqs_multiple_of, int
+            ), "num_seqs_multiple_of must be an integer"
+            assert num_seqs_multiple_of > 0, "num_seqs_multiple_of must be positive"
+
+            if max_num_elements % max_audio_len != 0:
+                max_num_elements = (max_num_elements // max_audio_len) * max_audio_len
+                log.warning(f"`max_num_elements` is rounded down to {max_num_elements}.")
+
+            bucket_sizes = create_bucket_sizes(
+                min_seq_len=min_audio_len,
+                max_seq_len=max_audio_len,
+                max_num_elements=max_num_elements,
+                num_seqs_multiple_of=num_seqs_multiple_of,
+            )
+
+            builder.bucket_by_length(
+                bucket_sizes,
+                selector="length",
+                min_data_len=min_audio_len,
+                skip_below_min_examples=True,  # should be a no-op given the filter above
+                skip_above_max_examples=True,  # should be a no-op given the clamping above
+                drop_remainder=options.drop_remainder,
+            )
+        elif isinstance(batching, StaticBatching):
+            if options.no_padding:
+                raise NotSupportedError(
+                    "no_padding is not supported for static batching"
+                )
+            builder.bucket(batching.batch_size, drop_remainder=options.drop_remainder)
+        else:
+            raise NotSupportedError(f"`{batching}` is not supported.")
+
+        # Shuffle buckets.
+        if options.batch_shuffle_window != 1:
+            batch_shuffle_window = min(
+                options.batch_shuffle_window, self.max_num_batches
+            )
+            builder.shuffle(batch_shuffle_window, seed)
+            seed += 1
+
+        # Decode audio.
+        audio_decoder = AudioDecoder(
+            dtype=torch.float32 if options.normalize_audio else options.dtype
+        )
+
+        def decoded_audio(_bytes: NDArray[np.int8]) -> AudioDecoderOutput:
+            return audio_decoder(MemoryBlock(_bytes.tobytes()))
+
+        builder.map(decoded_audio, selector="[*].audio", num_parallel_calls=npc)
+
+        if options.use_fbank:
+            fbank_converter = WaveformToFbankConverter(
+                num_mel_bins=80,
+                waveform_scale=2**15,
+                channel_last=True,
+                standardize=True,
+                dtype=options.dtype,
+            )
+
+            builder.map(fbank_converter, selector="[*].audio", num_parallel_calls=npc)
+        else:
+            builder.map(
+                partial(
+                    postprocess,
+                    normalize_audio=options.normalize_audio,
+                    dtype=options.dtype,
+                ),
+                selector="[*].audio.waveform",
+            )
+
+        # Move the audio feature to the top level of each example.
+        builder.map(rename_feature)
+
+        # Crop long audios to `max_audio_len`.
+        audio_cropper = AudioCropper(
+            max_audio_len,
+            seed=seed,
+            crop_to_batch_minimal_size=no_padding,
+        )
+        builder.map(audio_cropper.crop_audios_in_batch)
+
+        # Collate batched examples into a batch.
+        collater = Collater(pad_value=None if no_padding else 0)
+        builder.map(collater)
+
+        # Return only the first `max_num_batches`.
+        if options.max_num_batches is not None:
+            builder.take(options.max_num_batches)
+
+        builder.prefetch(options.num_prefetch)
+
+        pipeline = builder.map(
+            partial(to_batch, no_padding=no_padding, device=gang.device)
+        ).and_return()
+
+        return DataPipelineReader[SequenceBatch](
+            self._name, split, pipeline, gang, options, strict_state=False
+        )
diff --git a/src/fairseq2/recipes/wav2vec2/_train.py b/src/fairseq2/recipes/wav2vec2/_train.py
index 0c6f8cc9a..f39bd5c7a 100644
--- a/src/fairseq2/recipes/wav2vec2/_train.py
+++ b/src/fairseq2/recipes/wav2vec2/_train.py
@@ -48,19 +48,16 @@
     RegimeSection,
     TrainerSection,
 )
-from fairseq2.typing import CPU
-from fairseq2.utils.rng import manual_seed
-from fairseq2.utils.structured import structure
-from fairseq2.utils.validation import validate
-
-# isort: split
-
 from fairseq2.recipes.wav2vec2._common import (
     Wav2Vec2Criterion,
     Wav2Vec2LossSection,
     Wav2Vec2MetricBag,
 )
 from fairseq2.recipes.wav2vec2._eval import Wav2Vec2EvalUnit
+from fairseq2.typing import CPU
+from fairseq2.utils.rng import manual_seed
+from fairseq2.utils.structured import structure
+from fairseq2.utils.validation import validate
 
 
 @dataclass(kw_only=True)
@@ -141,22 +138,28 @@ class Wav2Vec2TrainDatasetSection(DatasetSection):
     normalize_audio: bool = False
     """If ``True``, normalizes audio to have zero mean and unit variance."""
 
-    example_shuffle_window: int = 0
+    use_fbank: bool = False
+
+    no_padding: bool = True
+
+    example_shuffle_window: int = 500_000
     """The size of the sliding window for shuffling examples."""
 
-    batch_shuffle_window: int = 0
+    batch_shuffle_window: int = 50_000
     """The size of the sliding window for shuffling batches."""
 
     num_prefetch: int = 4
     """The number of batches to prefetch in background."""
 
+    npc: int = 10
+
     extras: dict[str, object] = field(default_factory=dict)
     """The dataset-specific extra options."""
 
 
 def register_wav2vec2_train_configs(context: RuntimeContext) -> None:
     registry = context.get_config_registry(Wav2Vec2TrainConfig)
-
+    # wav2vec2_archs = context.get_config_registry(Wav2Vec2Config)
     preset = registry.decorator
 
     @preset("base_960h")
@@ -177,7 +180,7 @@ def large_960h() -> Wav2Vec2TrainConfig:
     config.model.arch = "large"
     config.model.config = {"encoder_config": {"first_pass_dropout_p": 0.1}}
     config.dataset.max_audio_len = 320_000
-    config.dataset.max_num_elements = 1_200_000
+    config.dataset.max_num_elements = 1_600_000
     config.optimizer.config.lr = 3e-04
     config.lr_scheduler.config.num_warmup_steps = 20_000
     config.regime.num_steps = 250_000
@@ -241,6 +244,10 @@ def load_wav2vec2_trainer(
         batch_shuffle_window=config.dataset.batch_shuffle_window,
         num_accumulate=config.trainer.gradient_accumulation,
         num_prefetch=config.dataset.num_prefetch,
+        example_shuffle_window=config.dataset.example_shuffle_window,
+        use_fbank=config.dataset.use_fbank,
+        no_padding=config.dataset.no_padding,
+        npc=config.dataset.npc,
         seed=seed,
         extras=config.dataset.extras,
     )
@@ -248,9 +255,9 @@ def load_wav2vec2_trainer(
     data_reader = dataset.create_reader(
         config.dataset.train_split,
         gangs.dp,
-        config.dataset.min_audio_len,
-        config.dataset.max_audio_len,
-        read_options,
+        min_audio_len=config.dataset.min_audio_len,
+        max_audio_len=config.dataset.max_audio_len,
+        options=read_options,
    )
 
     seed += 1
@@ -265,6 +272,10 @@ def load_wav2vec2_trainer(
         normalize_audio=config.dataset.normalize_audio,
         sync_mode=SyncMode.UNTIL_LAST,
         num_prefetch=config.dataset.num_prefetch,
+        example_shuffle_window=config.dataset.example_shuffle_window,
+        use_fbank=config.dataset.use_fbank,
+        no_padding=config.dataset.no_padding,
+        npc=config.dataset.npc,
         seed=seed,
         extras=config.dataset.extras,
     )
 
@@ -272,9 +283,9 @@ def load_wav2vec2_trainer(
     valid_data_reader = dataset.create_reader(
         config.dataset.valid_split,
         gangs.dp,
-        config.dataset.min_audio_len,
-        config.dataset.max_audio_len,
-        read_options,
+        min_audio_len=config.dataset.min_audio_len,
+        max_audio_len=config.dataset.max_audio_len,
+        options=read_options,
     )
 
     valid_units = [valid_unit]
diff --git a/src/fairseq2/setup/_datasets.py b/src/fairseq2/setup/_datasets.py
index 018cac325..2457f1e0b 100644
--- a/src/fairseq2/setup/_datasets.py
+++ b/src/fairseq2/setup/_datasets.py
@@ -24,6 +24,22 @@
     GenericPreferenceDataset,
 )
 from fairseq2.datasets.speech import GENERIC_SPEECH_DATASET_FAMILY, GenericSpeechDataset
+from fairseq2.logging import log
+
+try:
+    from fairseq2.datasets.speech_parquet import (
+        PARQUET_SPEECH_DATASET_FAMILY,
+        GenericSpeechParquetDataset,
+    )
+except ImportError as e:
+    PARQUET_SPEECH_DATASET_FAMILY = None  # type: ignore
+    GenericSpeechParquetDataset = None  # type: ignore
+    if "No module named 'pyarrow'" in str(e):
+        log.warning(
+            "PyArrow is not installed. Please install it with `pip install fairseq2[arrow]`."
+        )
+    else:
+        log.warning(f"Failed to import speech_parquet module: {e}")
 from fairseq2.datasets.text import GENERIC_TEXT_DATASET_FAMILY, GenericTextDataset
 from fairseq2.registry import Registry
 
@@ -61,6 +77,12 @@ def register_dataset_families(context: RuntimeContext) -> None:
         GenericSpeechDataset,
         GenericSpeechDataset.from_path,
     )
+    if PARQUET_SPEECH_DATASET_FAMILY is not None and GenericSpeechParquetDataset is not None:
+        registrar.register_family(
+            PARQUET_SPEECH_DATASET_FAMILY,
+            GenericSpeechParquetDataset,
+            GenericSpeechParquetDataset.from_path,
+        )
     registrar.register_family(
         GENERIC_TEXT_DATASET_FAMILY,
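
Usage sketch (not part of the patch): a minimal, hypothetical example of how the manifest-based reader and the new SpeechReadOptions fields added above might be wired up. The manifest path, dataset name, split name, and audio length limits are illustrative assumptions, not values taken from this diff; in a recipe the gang would come from the trainer (e.g. `gangs.dp`).

from pathlib import Path

import torch

from fairseq2.datasets import DataReader, LengthBatching
from fairseq2.datasets.speech import GenericSpeechDataset, SpeechReadOptions
from fairseq2.gang import Gang
from fairseq2.models.sequence import SequenceBatch


def make_train_reader(gang: Gang) -> DataReader[SequenceBatch]:
    # Hypothetical manifest directory containing `train.tsv`, `valid.tsv`, ...
    dataset = GenericSpeechDataset.from_path(Path("/data/manifests"), name="my_speech")

    options = SpeechReadOptions(
        batching=LengthBatching(max_num_elements=1_600_000),
        dtype=torch.float32,
        normalize_audio=False,
        use_fbank=False,  # keep raw waveforms; set True to extract 80-dim fbank
        no_padding=True,  # crop each bucket to its shortest element, no padding
        npc=10,  # parallel calls used by the decode/fbank map steps
        example_shuffle_window=500_000,
        batch_shuffle_window=50_000,
        num_prefetch=4,
        seed=2,
    )

    # Each item yielded by the reader is a list of `SequenceBatch` objects
    # (one per gradient-accumulation step).
    return dataset.create_reader(
        "train", gang, min_audio_len=32_000, max_audio_len=320_000, options=options
    )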