Skip to content

Mutualize audio readings #1144

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 13 commits into
base: main_w2v2_pretraining
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/fairseq2/data/parquet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,12 @@ class BasicDataLoadingConfig:

# default trivial config will load all columns
fragment_load_config: FragmentLoadingConfig = field(
default=lambda: FragmentLoadingConfig()
default_factory=lambda: FragmentLoadingConfig()
)

# default trivial config applies NO bucketing
table_bucketing_config: TableBucketingConfig = field(
default=lambda: TableBucketingConfig()
default_factory=lambda: TableBucketingConfig()
)


Expand Down
12 changes: 11 additions & 1 deletion src/fairseq2/data/parquet/fragment_loading/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import gc
import tempfile
from copy import deepcopy
from functools import partial
from pathlib import Path
from pickle import dumps, loads
from typing import List, Optional

import numpy as np
import pyarrow as pa
from retrying import retry

Expand Down Expand Up @@ -139,13 +141,21 @@ def __init__(self, config: FragmentLoadingConfig) -> None:

def apply(self, fragment_pipeline: DataPipelineBuilder) -> DataPipelineBuilder:
def load_fn(fragment: pa.dataset.ParquetFileFragment) -> pa.Table | None:
return SafeFragment(fragment).load(
safe_fragment = SafeFragment(fragment)
if np.random.rand() < 0.05:
gc.collect()
# pa.jemalloc_set_decay_ms(10)
pool = pa.default_memory_pool()
pool.release_unused()

table = safe_fragment.load(
columns=self.columns,
add_fragment_traces=self.config.add_fragment_traces,
use_threads=self.config.use_threads,
filters=self.filters,
add_partitioning_columns=True,
)
return table

if self.config.non_deterministic_read:
# keeping if above checks for back-compatibility
Expand Down
217 changes: 44 additions & 173 deletions src/fairseq2/datasets/asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,53 +7,36 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Final, cast, final
from functools import partial
from typing import Any, Dict, Final, cast

import torch
from torch import Tensor
from torch.nn.functional import layer_norm
from typing_extensions import override

from fairseq2.data import (
CollateOptionsOverride,
Collater,
DataPipeline,
DataPipelineBuilder,
FileMapper,
SequenceData,
create_bucket_sizes,
read_sequence,
)
from fairseq2.data.audio import AudioDecoder
from fairseq2.data.text import StrSplitter, read_text
from fairseq2.data.text.tokenizers import TextTokenizer
from fairseq2.datasets import (
DataPipelineReader,
DataReader,
DataReadError,
DataReadOptions,
DatasetHubAccessor,
DatasetLoadError,
LengthBatching,
StaticBatching,
UnknownSplitError,
)
from fairseq2.error import NotSupportedError
from fairseq2.datasets.speech import (
GenericSpeechDataset,
ManifestDatasetInterface,
SpeechReadOptions,
)
from fairseq2.gang import Gang
from fairseq2.models.seq2seq import Seq2SeqBatch
from fairseq2.nn.padding import get_seqs_and_padding_mask
from fairseq2.typing import DataType


@dataclass(kw_only=True)
class AsrReadOptions(DataReadOptions):
dtype: DataType = torch.float32
"""The data type of the decoded audio sequences."""

normalize_audio: bool = False
"""If ``True``, normalizes audio to have zero mean and unit variance."""
from fairseq2.typing import Device


class AsrDataset(ABC):
Expand All @@ -67,7 +50,7 @@ def create_reader(
gang: Gang,
min_audio_len: int,
max_audio_len: int,
options: AsrReadOptions | None = None,
options: SpeechReadOptions | None = None,
) -> DataReader[Seq2SeqBatch]:
"""Create a dataset reader.

Expand All @@ -92,47 +75,33 @@ def splits(self) -> set[str]:
"""Return the set of splits."""


# TODO: FIX, INFER
npc = 10


GENERIC_ASR_DATASET_FAMILY: Final = "generic_asr"

get_asr_dataset_hub = DatasetHubAccessor(AsrDataset)

# TODO: Work in progress!
class GenericAsrDataset(AsrDataset):
"""Represents a generic manifest-based ASR dataset."""

_name: str
_manifest_dir: Path
_splits: set[str]

def __init__(self, name: str, manifest_dir: Path, splits: set[str]) -> None:
"""
:param manifest_dir:
The directory under which the manifest files resides.
:param splits:
The available splits.
"""
self._name = name
self._manifest_dir = manifest_dir
self._splits = splits
class GenericAsrDataset(ManifestDatasetInterface, AsrDataset):
"""Represents a generic manifest-based ASR dataset."""

@staticmethod
def from_path(path: Path, name: str) -> GenericAsrDataset:
path = path.expanduser().resolve()

if not path.is_dir():
return GenericAsrDataset(name, manifest_dir=path.parent, splits={path.stem})
def to_batch(example: Dict[str, Any], device: Device | None = None) -> Seq2SeqBatch:
source_data = cast(SequenceData, example["audio_feature"])
target_data = cast(SequenceData, example["text"])

try:
splits = {f.stem for f in path.glob("*.tsv")}
except OSError as ex:
raise DatasetLoadError(
name, f"The splits under the '{path}' directory of the '{name}' dataset cannot be determined. See the nested exception for details." # fmt: skip
) from ex
source_seqs, source_padding_mask = get_seqs_and_padding_mask(
source_data, device=device
)
target_seqs, target_padding_mask = get_seqs_and_padding_mask(
target_data, device=device
)

return GenericAsrDataset(name, path, splits)
return Seq2SeqBatch(
source_seqs,
source_padding_mask,
target_seqs,
target_padding_mask,
example,
)

@override
def create_reader(
Expand All @@ -142,7 +111,7 @@ def create_reader(
gang: Gang,
min_audio_len: int,
max_audio_len: int,
options: AsrReadOptions | None = None,
options: SpeechReadOptions | None = None,
) -> DataPipelineReader[Seq2SeqBatch]:
"""
:param cached_fd_count:
Expand All @@ -153,12 +122,13 @@ def create_reader(
raise UnknownSplitError(self._name, split, self._splits)

if options is None:
options = AsrReadOptions()

options = SpeechReadOptions()
npc = options.npc
seed = options.seed

audio_dir = self._retrieve_data_directory(split)

audio_dir = GenericSpeechDataset._retrieve_data_directory(
self._manifest_dir, self._name, split
)
builder = self._read_manifest(split)

# Shuffle examples. Must be consistent across all processes.
Expand All @@ -171,79 +141,19 @@ def create_reader(
builder.shard(gang.rank, gang.size, allow_uneven=True)

seed += gang.rank

batching = options.batching

if isinstance(batching, LengthBatching):
# Bucket by the audio length.
bucket_sizes = create_bucket_sizes(
min_seq_len=min_audio_len,
max_seq_len=max_audio_len,
max_num_elements=batching.max_num_elements,
num_seqs_multiple_of=8,
)

builder.bucket_by_length(
bucket_sizes,
selector="audio_size",
min_data_len=min_audio_len,
skip_below_min_examples=True,
skip_above_max_examples=True,
drop_remainder=options.drop_remainder,
)
elif isinstance(batching, StaticBatching):
# Filter out out-of-range audios.
def skip(example: dict[str, object]) -> bool:
audio_len = cast(int, example["audio_size"])

return audio_len >= min_audio_len and audio_len <= max_audio_len

builder.filter(skip)

# Bucket `batch_size` examples.
builder.bucket(batching.batch_size, drop_remainder=options.drop_remainder)
else:
raise NotSupportedError(f"`{batching}` is not supported.")

# Shuffle buckets.
if options.batch_shuffle_window != 1:
builder.shuffle(options.batch_shuffle_window, seed)

# Bucketize examples by audio length.
builder = GenericSpeechDataset.add_bucketing_pipeline(
builder, options, max_audio_len, min_audio_len, seed, "audio_size"
)
# Read audios
seed += 1

# Memory map audio files.
cached_fd_count = options.extras.get("cached_fd_count", 1)
if not isinstance(cached_fd_count, int):
raise TypeError(
f"`options.extras['cached_fd_count']` must be of type `int`, but is of type `{type(cached_fd_count)}` instead."
)

file_mapper = FileMapper(audio_dir, cached_fd_count=cached_fd_count)

builder.map(file_mapper, selector="[*].audio")

# Decode audio.
audio_decoder = AudioDecoder(
dtype=torch.float32 if options.normalize_audio else options.dtype
builder = GenericSpeechDataset.add_audio_decoding(builder, options, audio_dir)
builder = GenericSpeechDataset.audio_post_process(
builder, options, GenericSpeechDataset.rename_feature
)

builder.map(audio_decoder, selector="[*].audio.data")

# TODO(balioglu): Check/adjust sample size

# Normalize audio if requested.
def normalize(waveform: Tensor) -> Tensor:
with torch.no_grad():
waveform = layer_norm(waveform, waveform.shape)

return waveform.to(options.dtype)

if options.normalize_audio:
builder.map(normalize, selector="[*].audio.data.waveform")

# Tokenize target text.
text_encoder = tokenizer.create_encoder()

builder.map(text_encoder, selector="[*].text", num_parallel_calls=npc)

# Collate bucketed examples into a batch.
Expand All @@ -263,45 +173,13 @@ def normalize(waveform: Tensor) -> Tensor:
builder.prefetch(options.num_prefetch)

# Wrap examples with `Seq2SeqBatch`.
def to_batch(example: dict[str, Any]) -> Seq2SeqBatch:
source_data = cast(SequenceData, example["audio"]["data"]["waveform"])
target_data = cast(SequenceData, example["text"])

source_seqs, source_padding_mask = get_seqs_and_padding_mask(source_data)
target_seqs, target_padding_mask = get_seqs_and_padding_mask(target_data)

return Seq2SeqBatch(
source_seqs,
source_padding_mask,
target_seqs,
target_padding_mask,
example,
)

pipeline = builder.map(to_batch).and_return()
pipeline = builder.map(partial(self.to_batch, device=gang.device)).and_return()

return DataPipelineReader[Seq2SeqBatch](
self._name, split, pipeline, gang, options
)

def _retrieve_data_directory(self, split: str) -> Path:
manifest_file = self._manifest_dir.joinpath(f"{split}.tsv")

try:
with manifest_file.open(encoding="utf-8") as fp:
line = fp.readline().rstrip()
except OSError as ex:
raise DataReadError(
self._name, split, f"The {manifest_file} manifest file cannot be read. See the nested exception for details." # fmt: skip
) from ex

try:
return Path(line)
except ValueError:
raise DataReadError(
self._name, split, f"The first line of the '{manifest_file}' manifest file must point to a data directory." # fmt: skip
) from None

def _read_manifest(self, split: str) -> DataPipelineBuilder:
def read_tsv_file() -> DataPipelineBuilder:
tsv_file = self._manifest_dir.joinpath(f"{split}.tsv")
Expand All @@ -312,7 +190,7 @@ def read_tsv_file() -> DataPipelineBuilder:

field_splitter = StrSplitter(names=["audio", "audio_size"])

builder.map(field_splitter, num_parallel_calls=npc)
builder.map(field_splitter)

return builder

Expand All @@ -333,10 +211,3 @@ def read_wrd_file() -> DataPipelineBuilder:
manifest = list(builder.and_return())

return read_sequence(manifest)

@override
def splits(self) -> set[str]:
return self._splits


get_asr_dataset_hub = DatasetHubAccessor(AsrDataset)
Loading
Loading