Skip to content

Gilkeren/mixing #1147

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main_w2v2_pretraining
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/fairseq2/assets/cards/datasets/librilight.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,9 @@
name: librilight_asr_10h
dataset_family: generic_asr
data: /checkpoint/mms/shared/assets/debug/librilight

---

name: librilight_asr_10h_unlabeled
dataset_family: generic_speech
data: /checkpoint/mms/shared/assets/debug/librilight
115 changes: 105 additions & 10 deletions src/fairseq2/datasets/asr.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand All @@ -8,26 +8,25 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass

from functools import partial
from pathlib import Path
from typing import Any, Final, cast, final
from typing import Any, cast, Final, final, List

Check failure on line 14 in src/fairseq2/datasets/asr.py

View workflow job for this annotation

GitHub Actions / Lint Python / Lint

'typing.final' imported but unused

import torch
from torch import Tensor
from torch.nn.functional import layer_norm
from typing_extensions import override

from fairseq2.data import (
CollateOptionsOverride,
Collater,
create_bucket_sizes,
DataPipeline,
DataPipelineBuilder,
FileMapper,
SequenceData,
create_bucket_sizes,
read_sequence,
SequenceData,
)
from fairseq2.data.audio import AudioDecoder
from fairseq2.data.text import StrSplitter, read_text
from fairseq2.data.text import read_text, StrSplitter
from fairseq2.data.text.tokenizers import TextTokenizer
from fairseq2.datasets import (
DataPipelineReader,
Expand All @@ -40,11 +39,15 @@
StaticBatching,
UnknownSplitError,
)
from fairseq2.datasets.speech import SpeechDataset, SpeechReadOptions
from fairseq2.error import NotSupportedError
from fairseq2.gang import Gang
from fairseq2.models.seq2seq import Seq2SeqBatch
from fairseq2.nn.padding import get_seqs_and_padding_mask
from fairseq2.typing import DataType
from torch import Tensor
from torch.nn.functional import layer_norm
from typing_extensions import override


@dataclass(kw_only=True)
Expand Down Expand Up @@ -135,7 +138,7 @@
return GenericAsrDataset(name, path, splits)

@override
def create_reader(
def create_pipeline(

Check failure on line 141 in src/fairseq2/datasets/asr.py

View workflow job for this annotation

GitHub Actions / Lint Python / Lint

Method "create_pipeline" is marked as an override, but no base method was found with this name
self,
split: str,
tokenizer: TextTokenizer,
Expand Down Expand Up @@ -278,10 +281,24 @@
example,
)

pipeline = builder.map(to_batch).and_return()
pipeline = builder.map(to_batch)
return pipeline

Check failure on line 285 in src/fairseq2/datasets/asr.py

View workflow job for this annotation

GitHub Actions / Lint Python / Lint

Incompatible return value type (got "DataPipelineBuilder", expected "DataPipelineReader[Seq2SeqBatch]")

@override
def create_reader(
    self,
    split: str,
    tokenizer: TextTokenizer,
    gang: Gang,
    min_audio_len: int,
    max_audio_len: int,
    options: AsrReadOptions | None = None,
) -> DataPipelineReader[Seq2SeqBatch]:
    """Create a reader over the ``split`` of this dataset.

    This is a thin wrapper around :meth:`create_pipeline`: it finalizes the
    pipeline that method builds and wraps it in a
    :class:`DataPipelineReader`.

    :param split: The name of the split to read.
    :param tokenizer: The tokenizer forwarded to :meth:`create_pipeline`.
    :param gang: The gang forwarded to :meth:`create_pipeline` and to the
        reader.
    :param min_audio_len: Forwarded to :meth:`create_pipeline`.
    :param max_audio_len: Forwarded to :meth:`create_pipeline`.
    :param options: The read options. If ``None``, default
        :class:`AsrReadOptions` are used.
    """
    # ``DataPipelineReader`` expects concrete options, never ``None``, so
    # resolve the defaults here and hand the same object to both the
    # pipeline and the reader.
    if options is None:
        options = AsrReadOptions()

    builder = self.create_pipeline(
        split, tokenizer, gang, min_audio_len, max_audio_len, options
    )

    # ``create_pipeline`` returns an un-finalized builder so that callers can
    # append further operations; the reader needs the finalized pipeline.
    pipeline = builder.and_return()

    return DataPipelineReader[Seq2SeqBatch](
        self._name, split, pipeline, gang, options
    )

def _retrieve_data_directory(self, split: str) -> Path:
Expand Down Expand Up @@ -339,4 +356,82 @@
return self._splits


from fairseq2.datasets.speech import SpeechDataset

Check failure on line 359 in src/fairseq2/datasets/asr.py

View workflow job for this annotation

GitHub Actions / Lint Python / Lint

redefinition of unused 'SpeechDataset' from line 42


class MixingAsrDataset:
    """Represents a mixture of labeled and unlabeled datasets.

    Batches are drawn from the per-dataset pipelines according to the given
    sampling weights. Every per-dataset pipeline is repeated, so the mixed
    pipeline has no natural end.
    """

    def __init__(
        self,
        datasets: List[AsrDataset | SpeechDataset],
        sampling_weights: List[float],
    ) -> None:
        """
        :param datasets:
            The datasets to mix: labeled :class:`AsrDataset` and/or unlabeled
            :class:`SpeechDataset` objects.
        :param sampling_weights:
            The sampling weight of each dataset, in the same order as
            ``datasets``.

        :raises ValueError:
            If ``datasets`` is empty or its length differs from the length of
            ``sampling_weights``.
        """
        # Validate with real exceptions; ``assert`` is stripped under ``-O``.
        if not datasets:
            raise ValueError("`datasets` must contain at least one dataset.")

        if len(datasets) != len(sampling_weights):
            raise ValueError(
                f"`datasets` and `sampling_weights` must have the same length, "
                f"but have lengths {len(datasets)} and "
                f"{len(sampling_weights)} instead."
            )

        self._datasets = datasets
        self._sampling_weights = sampling_weights

    def create_reader(
        self,
        split: str,
        tokenizer: TextTokenizer,
        gang: Gang,
        min_audio_len: int,
        max_audio_len: int,
        options: AsrReadOptions | None = None,
    ) -> DataPipelineReader[Seq2SeqBatch]:
        """Create a reader that samples batches from all mixed datasets.

        :param split: The name of the split to read from every dataset.
        :param tokenizer: The tokenizer forwarded to the labeled datasets.
        :param gang: The gang forwarded to every dataset; its rank also
            offsets the sampling seed.
        :param min_audio_len: Forwarded to every dataset.
        :param max_audio_len: Forwarded to every dataset.
        :param options: The read options. If ``None``, default
            :class:`AsrReadOptions` are used.

        :raises TypeError:
            If one of the datasets is neither an :class:`AsrDataset` nor a
            :class:`SpeechDataset`.
        """
        # The options are dereferenced below, so resolve the defaults first.
        if options is None:
            options = AsrReadOptions()

        def add_name(name: str, batch: Any) -> Any:
            # NOTE: this replaces whatever ``batch.example`` held before.
            batch.example = {"ds_name": name}

            return batch

        # Unlabeled datasets take ``SpeechReadOptions``; mirror the shared
        # settings from the ASR options and pin the speech-only ones.
        speech_options = SpeechReadOptions(
            batching=options.batching,
            dtype=options.dtype,
            normalize_audio=options.normalize_audio,
            batch_shuffle_window=options.batch_shuffle_window,
            num_accumulate=options.num_accumulate,
            num_prefetch=options.num_prefetch,
            example_shuffle_window=options.example_shuffle_window,
            use_fbank=False,
            no_padding=False,
            npc=10,
            seed=options.seed,
            extras=options.extras,
        )

        pipelines: List[DataPipeline] = []

        for ds in self._datasets:
            builder: DataPipelineBuilder

            if isinstance(ds, AsrDataset):
                builder = ds.create_pipeline(
                    split, tokenizer, gang, min_audio_len, max_audio_len, options
                )
            elif isinstance(ds, SpeechDataset):
                builder = ds.create_pipeline(
                    split, gang, min_audio_len, max_audio_len, speech_options
                )
            else:
                # Without this branch an unsupported dataset would surface as
                # an unrelated "unbound local variable" error further down.
                raise TypeError(
                    f"`datasets` must contain only `AsrDataset` and "
                    f"`SpeechDataset` objects, but contains an object of "
                    f"type `{type(ds)}`."
                )

            # Tag every batch with the name of the dataset it came from. The
            # result is re-bound explicitly instead of relying on ``map``
            # mutating the builder in place.
            builder = builder.map(partial(add_name, ds._name))

            # Repeat, so that no source runs dry while being sampled.
            pipelines.append(builder.repeat().and_return())

        # Mix. Use a different seed in each rank, to make sure batches contain a
        # variety of data sources.
        mixing_pipeline = DataPipeline.sample(
            pipelines, weights=self._sampling_weights, seed=options.seed + gang.rank
        ).and_return()

        return DataPipelineReader[Seq2SeqBatch](
            "mixing_asr_dataset", split, mixing_pipeline, gang, options
        )


get_asr_dataset_hub = DatasetHubAccessor(AsrDataset)
34 changes: 21 additions & 13 deletions src/fairseq2/datasets/speech.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand All @@ -10,22 +10,14 @@
from dataclasses import dataclass
from functools import partial, reduce
from pathlib import Path
from typing import Any, Dict, Final, List, final
from typing import Any, Dict, Final, final, List

Check failure on line 13 in src/fairseq2/datasets/speech.py

View workflow job for this annotation

GitHub Actions / Lint Python / Lint

'typing.final' imported but unused

import numpy as np
import torch
from torch import Tensor
from torch.nn.functional import layer_norm
from typing_extensions import override

from fairseq2.data import (
Collater,
DataPipelineBuilder,
FileMapper,
create_bucket_sizes,
)
from fairseq2.data import Collater, create_bucket_sizes, DataPipelineBuilder, FileMapper
from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
from fairseq2.data.text import StrSplitter, read_text
from fairseq2.data.text import read_text, StrSplitter
from fairseq2.datasets import (
DataPipelineReader,
DataReader,
Expand All @@ -43,6 +35,9 @@
from fairseq2.models.sequence import SequenceBatch
from fairseq2.nn.padding import get_seqs_and_padding_mask
from fairseq2.typing import DataType, Device
from torch import Tensor
from torch.nn.functional import layer_norm
from typing_extensions import override


@torch.no_grad()
Expand Down Expand Up @@ -376,7 +371,7 @@
return builder

@override
def create_reader(
def create_pipeline(

Check failure on line 374 in src/fairseq2/datasets/speech.py

View workflow job for this annotation

GitHub Actions / Lint Python / Lint

Method "create_pipeline" is marked as an override, but no base method was found with this name
self,
split: str,
gang: Gang,
Expand All @@ -397,7 +392,7 @@
)

seed = options.seed
npc = options.npc

Check failure on line 395 in src/fairseq2/datasets/speech.py

View workflow job for this annotation

GitHub Actions / Lint Python / Lint

local variable 'npc' is assigned to but never used
no_padding = options.no_padding

audio_dir = self._retrieve_data_directory(split)
Expand Down Expand Up @@ -431,8 +426,21 @@

pipeline = builder.map(
partial(to_batch, no_padding=no_padding, device=gang.device)
).and_return()
)
return pipeline

Check failure on line 430 in src/fairseq2/datasets/speech.py

View workflow job for this annotation

GitHub Actions / Lint Python / Lint

Incompatible return value type (got "DataPipelineBuilder", expected "DataPipelineReader[SequenceBatch]")

@override
def create_reader(
    self,
    split: str,
    gang: Gang,
    min_audio_len: int,
    max_audio_len: int,
    options: SpeechReadOptions | None = None,
) -> DataPipelineReader[SequenceBatch]:
    """Create a reader over the ``split`` of this dataset.

    This is a thin wrapper around :meth:`create_pipeline`: it finalizes the
    pipeline that method builds and wraps it in a
    :class:`DataPipelineReader`.

    :param split: The name of the split to read.
    :param gang: The gang forwarded to :meth:`create_pipeline` and to the
        reader.
    :param min_audio_len: Forwarded to :meth:`create_pipeline`.
    :param max_audio_len: Forwarded to :meth:`create_pipeline`.
    :param options: The read options. If ``None``, default
        :class:`SpeechReadOptions` are used.
    """
    # ``DataPipelineReader`` expects concrete options, never ``None``, so
    # resolve the defaults here and hand the same object to both the
    # pipeline and the reader.
    if options is None:
        options = SpeechReadOptions()

    builder = self.create_pipeline(
        split, gang, min_audio_len, max_audio_len, options
    )

    # ``create_pipeline`` no longer finalizes the builder, so it has to be
    # done here; the reader needs a ``DataPipeline``, not a builder.
    pipeline = builder.and_return()

    return DataPipelineReader[SequenceBatch](
        self._name, split, pipeline, gang, options
    )
14 changes: 6 additions & 8 deletions src/fairseq2/models/wav2vec2/asr/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand All @@ -9,23 +9,21 @@
from fairseq2.models.wav2vec2.asr._checkpoint import (
convert_wav2vec2_asr_checkpoint as convert_wav2vec2_asr_checkpoint,
)
from fairseq2.models.wav2vec2.asr._config import (
WAV2VEC2_ASR_MODEL_FAMILY as WAV2VEC2_ASR_MODEL_FAMILY,
)
from fairseq2.models.wav2vec2.asr._config import Wav2Vec2AsrConfig as Wav2Vec2AsrConfig
from fairseq2.models.wav2vec2.asr._config import (
register_wav2vec2_asr_configs as register_wav2vec2_asr_configs,
)
from fairseq2.models.wav2vec2.asr._factory import (
Wav2Vec2AsrFactory as Wav2Vec2AsrFactory,
WAV2VEC2_ASR_MODEL_FAMILY as WAV2VEC2_ASR_MODEL_FAMILY,
Wav2Vec2AsrConfig as Wav2Vec2AsrConfig,
)
from fairseq2.models.wav2vec2.asr._factory import (
create_wav2vec2_asr_model as create_wav2vec2_asr_model,
init_final_projection as init_final_projection,
Wav2Vec2AsrFactory as Wav2Vec2AsrFactory,
)
from fairseq2.models.wav2vec2.asr._model import Wav2Vec2AsrModel as Wav2Vec2AsrModel

init_final_projection
# isort: split

from fairseq2.models import ModelHubAccessor
from fairseq2.models.wav2vec2.asr._model import Wav2Vec2AsrModel as Wav2Vec2AsrModel

get_wav2vec2_asr_model_hub = ModelHubAccessor(Wav2Vec2AsrModel, Wav2Vec2AsrConfig)
24 changes: 19 additions & 5 deletions src/fairseq2/recipes/asr/_common.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand All @@ -7,11 +7,9 @@
from __future__ import annotations

import math
from typing import Any, TextIO, final
from typing import Any, final, TextIO

import torch
from torch import Tensor
from typing_extensions import override

from fairseq2.data.text.tokenizers import TextTokenDecoder, TextTokenizer
from fairseq2.gang import Gang
Expand All @@ -20,7 +18,11 @@
from fairseq2.models.asr import AsrModel, AsrModelOutput
from fairseq2.models.seq2seq import Seq2SeqBatch
from fairseq2.models.sequence import SequenceBatch

from fairseq2.models.wav2vec2.asr import Wav2Vec2AsrModel
from fairseq2.recipes import BaseMetricBag, Model, UnitError
from torch import Tensor
from typing_extensions import override


@final
Expand All @@ -41,7 +43,19 @@
def __call__(
self, batch: Seq2SeqBatch, metric_bag: AsrMetricBag
) -> tuple[Tensor, int]:
input_batch = SequenceBatch(batch.source_seqs, batch.source_padding_mask)
if isinstance(self._model, Wav2Vec2AsrModel):
input_batch = SequenceBatch(batch.source_seqs, batch.source_padding_mask)
else:
# Convert a SequenceBatch to a Seq2SeqBatch
if isinstance(batch, SequenceBatch):
batch = Seq2SeqBatch(
source_seqs=batch.seqs,
source_padding_mask=batch.padding_mask,
target_seqs=None,
target_padding_mask=None,
example=batch.example,
)
input_batch = batch

output = self._forward(input_batch)

Expand All @@ -56,7 +70,7 @@

return loss, batch.batch_size

def _forward(self, batch: SequenceBatch) -> AsrModelOutput:
def _forward(self, batch: SequenceBatch | Seq2SeqBatch) -> AsrModelOutput:
return self._model.module(batch) # type: ignore[no-any-return]

@property
Expand Down
Loading
Loading