Add BestRQ pretraining #873

Open · wants to merge 10 commits into base: main
250 changes: 242 additions & 8 deletions src/fairseq2/datasets/speech.py
@@ -1,4 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
@@ -8,19 +8,34 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Final, Literal, cast, final

import torch
from torch import Tensor
from torch.nn.functional import layer_norm
from typing_extensions import override

from fairseq2.data import (
    Collater,
    DataPipeline,
    DataPipelineBuilder,
    FileMapper,
    SequenceData,
    create_bucket_sizes,
    read_sequence,
)
from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
from fairseq2.data.text import StrSplitter, read_text
from fairseq2.datasets.batching import Batching, LengthBatching, StaticBatching
from fairseq2.datasets.config import DataReadOptions
from fairseq2.datasets.data_reader import DataPipelineReader, DataReader
from fairseq2.datasets.error import DatasetError  # assumed import path for DatasetError
from fairseq2.datasets.hub import DatasetHubAccessor
from fairseq2.gang import Gang
from fairseq2.models.sequence import SequenceBatch
from fairseq2.nn.padding import get_seqs_and_padding_mask
from fairseq2.typing import DataType


class SpeechDataset(ABC):
@@ -60,7 +75,7 @@


@dataclass
class SpeechReadOptions(DataReadOptions):
dtype: DataType = torch.float32
"""The data type of the decoded audio sequences."""

@@ -70,14 +85,41 @@

GENERIC_SPEECH_DATASET_FAMILY: Final = "generic_speech"

# TODO: Infer the number of parallel calls instead of hard-coding it.
npc = 10  # Parallelism used by the pipeline's `map()` operations.

@final
class GenericSpeechDataset(SpeechDataset):
"""Represents a generic manifest-based Speech dataset."""

_manifest_dir: Path
_splits: set[str]

def __init__(self, manifest_dir: Path, splits: set[str]) -> None:
"""
:param manifest_dir:
The directory under which the manifest files reside.
:param splits:
The available splits.
"""
self._manifest_dir = manifest_dir
self._splits = splits

@classmethod
def from_path(cls, path: Path) -> GenericSpeechDataset:
"""Load a :class:`GenericSpeechDataset` from ``path``."""
path = path.expanduser().resolve()
if not path.is_dir():
return GenericSpeechDataset(manifest_dir=path.parent, splits={path.stem})

try:
splits = {f.stem for f in path.glob("*.tsv")}
except OSError as ex:
raise RuntimeError(
"The splits cannot be determined. See nested exception for details."
) from ex

return GenericSpeechDataset(path, splits)
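# Illustrative only (not part of this diff; paths are hypothetical):
#   GenericSpeechDataset.from_path(Path("/data/speech/train.tsv"))  -> splits == {"train"}
#   GenericSpeechDataset.from_path(Path("/data/speech"))            -> one split per *.tsv file found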

@override
def create_reader(
@@ -87,13 +129,205 @@
min_audio_len: int,
max_audio_len: int,
batching: Batching,
options: SpeechReadOptions | None = None,
*,
dtype: DataType = torch.float32,
normalize_audio: bool = False,
example_shuffle_window: int = 1,
batch_shuffle_window: int = 1,
drop_remainder: bool = False,
sync_batches: bool = True,
sync_mode: Literal["until_first", "until_last"] = "until_first",
max_num_batches: int | None = None,
num_accumulate: int = 1,
num_prefetch: int = 1,
seed: int = 2,
cached_fd_count: int = 1000,
use_fbank: bool = False,
num_fbank_channels: int = 80,
fbank_stride: int = 2,
**extras: Any,
) -> DataPipelineReader[SequenceBatch]:
"""
:param cached_fd_count:
The maximum number of file descriptors to keep open while reading
audio files.
"""
if split not in self._splits:
raise ValueError(
f"`split` must be one of the following splits, but is '{split}' instead: {', '.join(sorted(self._splits))}"
)

audio_dir = self._retrieve_data_directory(split)

builder = self._read_manifest(split)

# Shuffle examples. Must be consistent across all processes.
if example_shuffle_window != 1:
builder.shuffle(example_shuffle_window, seed)

seed += 1

# Shard.
builder.shard(gang.rank, gang.size, allow_uneven=True)

seed += gang.rank

if isinstance(batching, LengthBatching):
# Bucket by the audio length.
bucket_sizes = create_bucket_sizes(
max_seq_len=max_audio_len,
min_seq_len=min_audio_len,
max_num_elements=batching.max_num_elements,
num_seqs_multiple_of=8,
)

builder.bucket_by_length(
bucket_sizes,
selector="audio_size",
min_data_len=min_audio_len,
skip_below_min_examples=True,
skip_above_max_examples=True,
drop_remainder=drop_remainder,
)
elif isinstance(batching, StaticBatching):
# Filter out out-of-range audios.
def skip(example: dict[str, Any]) -> bool:
audio_len = cast(int, example["audio_size"])

return audio_len >= min_audio_len and audio_len <= max_audio_len

builder.filter(skip)

# Bucket `batch_size` examples.
builder.bucket(batching.batch_size, drop_remainder=drop_remainder)
else:
raise RuntimeError(f"`{batching}` is not supported.")

# Shuffle buckets.
if batch_shuffle_window != 1:
builder.shuffle(batch_shuffle_window, seed)

seed += 1

# Memory map audio files.
file_mapper = FileMapper(audio_dir, cached_fd_count=cached_fd_count)

builder.map(file_mapper, selector="[*].audio")

# Decode audio.
audio_decoder = AudioDecoder(dtype=torch.float32 if normalize_audio else dtype)

builder.map(audio_decoder, selector="[*].audio.data")

# TODO(balioglu): Check/adjust sample size

# Normalize audio if requested.
def normalize(waveform: Tensor) -> Tensor:
with torch.no_grad():
waveform = layer_norm(waveform, waveform.shape)

return waveform.to(dtype)

if normalize_audio:
builder.map(normalize, selector="[*].audio.data.waveform")

if use_fbank:
fbank_converter = WaveformToFbankConverter(
num_mel_bins=num_fbank_channels,
waveform_scale=1.0,
channel_last=True,
standardize=False,
keep_waveform=True,
)
builder.map(fbank_converter, selector="[*].audio.data")

collater = Collater(pad_value=0)

builder.map(collater, num_parallel_calls=npc)

# Return only the first `max_num_batches`.
if max_num_batches is not None:
builder.take(max_num_batches)

# Prefetch `num_prefetch` batches in background.
builder.prefetch(num_prefetch)

# Wrap examples with `SequenceBatch`.
def to_batch(example: dict[str, Any]) -> SequenceBatch:
if use_fbank:
source_data = cast(SequenceData, example["audio"]["data"]["fbank"])
else:
source_data = cast(SequenceData, example["audio"]["data"]["waveform"])

source_seqs, source_padding_mask = get_seqs_and_padding_mask(
source_data, gang.device
)
return SequenceBatch(
seqs=source_seqs,
padding_mask=source_padding_mask,
example=example,
)

pipeline = builder.map(to_batch).and_return()

return DataPipelineReader[SequenceBatch](
pipeline,
gang,
num_accumulate=num_accumulate,
drop_remainder=drop_remainder,
sync_batches=sync_batches,
sync_mode=sync_mode,
)

def _retrieve_data_directory(self, split: str) -> Path:
manifest_file = self._manifest_dir.joinpath(f"{split}.tsv")

try:
with manifest_file.open() as fp:
line = fp.readline().rstrip()
except OSError as ex:
raise DatasetError(
f"{manifest_file} cannot be read. See nested exception for details."
) from ex

try:
return Path(line)
except ValueError:
raise DatasetError(
f"The first line of {manifest_file} must point to a data directory."
) from None

def _read_manifest(self, split: str) -> DataPipelineBuilder:
def read_tsv_file() -> DataPipelineBuilder:
tsv_file = self._manifest_dir.joinpath(f"{split}.tsv")

builder = read_text(tsv_file, rtrim=True, memory_map=True)

builder.skip(1) # Path to the data directory.

field_splitter = StrSplitter(names=["audio", "audio_size"])

builder.map(field_splitter, num_parallel_calls=npc)

return builder

tsv_pipeline = read_tsv_file().and_return()

builder = DataPipeline.zip([tsv_pipeline], flatten=True)

# Cast audio size to integer.
builder.map(int, selector="audio_size")

# TODO(balioglu): Use `cache()` op.
manifest = list(builder.and_return())

return read_sequence(manifest)
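# For reference, the manifest layout implied by the parsing above (paths are
# hypothetical): the first line of `<split>.tsv` holds the data directory, and
# every following line holds a tab-separated "<audio> <audio_size>" pair, e.g.
#
#   /datasets/speech/audio
#   clip_0001.flac	51200
#   clip_0002.flac	68608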

@override
def splits(self) -> set[str]:
return self._splits


get_speech_dataset_hub = DatasetHubAccessor(SpeechDataset)
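A minimal usage sketch of the reader added above, for review purposes only (not part of the diff); the manifest location, gang setup, and length limits below are assumptions.

from pathlib import Path

import torch

from fairseq2.datasets.batching import LengthBatching
from fairseq2.datasets.speech import GenericSpeechDataset
from fairseq2.gang import FakeGang

dataset = GenericSpeechDataset.from_path(Path("/datasets/speech/manifests"))

reader = dataset.create_reader(
    "train",
    FakeGang(device=torch.device("cpu")),  # single-process gang, for illustration only
    min_audio_len=32_000,   # ~2 s at 16 kHz
    max_audio_len=320_000,  # ~20 s at 16 kHz
    batching=LengthBatching(max_num_elements=1_600_000),
    normalize_audio=True,
)

for batches in reader:      # each step yields `num_accumulate` batches
    for batch in batches:
        ...                 # batch.seqs, batch.padding_mask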
37 changes: 37 additions & 0 deletions src/fairseq2/models/bestrq/__init__.py
@@ -0,0 +1,37 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from fairseq2.models.bestrq.factory import BESTRQ_FAMILY as BESTRQ_FAMILY
from fairseq2.models.bestrq.factory import BestRQBuilder as BestRQBuilder
from fairseq2.models.bestrq.factory import BestRQConfig as BestRQConfig
from fairseq2.models.bestrq.factory import (
create_bestrq_model as create_bestrq_model,
)
from fairseq2.models.bestrq.factory import bestrq_arch as bestrq_arch
from fairseq2.models.bestrq.factory import bestrq_archs as bestrq_archs
from fairseq2.models.bestrq.factory import (
bestrq_encoder_arch as bestrq_encoder_arch,
)
from fairseq2.models.bestrq.factory import (
bestrq_encoder_archs as bestrq_encoder_archs,
)
from fairseq2.models.bestrq.masker import RandomNoiseMasker as RandomNoiseMasker
from fairseq2.models.bestrq.model import BestRQFeatures as BestRQFeatures
from fairseq2.models.bestrq.model import BestRQLoss as BestRQLoss
from fairseq2.models.bestrq.model import BestRQModel as BestRQModel
from fairseq2.models.bestrq.model import BestRQOutput as BestRQOutput
from fairseq2.models.bestrq.quantizer import (
    MultiRandomVectorQuantizerOutput as MultiRandomVectorQuantizerOutput,
)

# isort: split

from fairseq2.dependency import DependencyContainer
from fairseq2.models.bestrq.archs import register_archs


def register_bestrq(container: DependencyContainer) -> None:
register_archs()
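For context, a short sketch of how the symbols re-exported above might be consumed (illustrative; the exact `create_bestrq_model` call signature is an assumption, not confirmed by this diff):

from fairseq2.models.bestrq import BestRQConfig, create_bestrq_model

config = BestRQConfig()              # defaults correspond to the "base" arch registered below
model = create_bestrq_model(config)  # assumed to accept a config; may also take dtype/device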
29 changes: 29 additions & 0 deletions src/fairseq2/models/bestrq/archs.py
@@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from fairseq2.models.bestrq.factory import (
BestRQConfig,
bestrq_arch,
bestrq_encoder_arch,
)
from fairseq2.models.wav2vec2.factory import (
Wav2Vec2EncoderConfig,
)
from fairseq2.nn.transformer import TransformerNormOrder


def register_archs() -> None:
@bestrq_arch("base")
def _base() -> BestRQConfig:
return BestRQConfig()

@bestrq_encoder_arch("base")
def _base_encoder() -> Wav2Vec2EncoderConfig:
config = _base()

return config.encoder_config
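As an illustration of the registration pattern above (hypothetical, not part of this PR), an additional architecture would be registered the same way inside `register_archs()`:

@bestrq_arch("large")
def _large() -> BestRQConfig:
    config = _base()
    # Override config fields here; the available BestRQConfig fields are not
    # shown in this diff, so the overrides are left as a placeholder.
    return config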