Gilkeren/debug2 #1146

Open

wants to merge 7 commits into base: main_w2v2_pretraining
12 changes: 12 additions & 0 deletions src/fairseq2/assets/cards/datasets/librilight.yaml
@@ -7,3 +7,15 @@
name: librilight_asr_10h
dataset_family: generic_asr
data: /checkpoint/mms/shared/assets/debug/librilight

---

name: librilight_asr_10h_unlabeled
dataset_family: generic_speech
data: /checkpoint/mms/shared/assets/debug/librilight

---

name: librilight_asr_10h_inference
dataset_family: speech_inference
data: /checkpoint/mms/shared/assets/debug/librilight
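
The three cards above register the same Libri-Light debug directory under three dataset families: `generic_asr` (labeled), `generic_speech` (unlabeled), and `speech_inference`. Below is a minimal sketch of resolving the new cards by name; it assumes the `default_asset_store` entry point and `AssetCard.field(...).as_(...)` accessors behave as in released fairseq2 versions, and is not taken from this PR.

```python
# Hypothetical check that the new cards resolve; `default_asset_store` and the
# field accessors are assumptions based on released fairseq2 versions.
from fairseq2 import setup_fairseq2
from fairseq2.assets import default_asset_store

setup_fairseq2()  # register asset sources and extensions

for name in (
    "librilight_asr_10h",
    "librilight_asr_10h_unlabeled",
    "librilight_asr_10h_inference",
):
    card = default_asset_store.retrieve_card(name)

    # `dataset_family` selects the loader; `data` is the shared manifest directory.
    family = card.field("dataset_family").as_(str)
    data_dir = card.field("data").as_(str)

    print(f"{name}: family={family} data={data_dir}")
```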
50 changes: 27 additions & 23 deletions src/fairseq2/cli/_main.py
@@ -8,47 +8,47 @@

import os
import sys
from signal import SIG_DFL, SIGINT, raise_signal, signal
from signal import raise_signal, SIG_DFL, SIGINT, signal

Check failure on line 11 in src/fairseq2/cli/_main.py (GitHub Actions / Lint Python / Lint): 'signal.signal' imported but unused
Check failure on line 11 in src/fairseq2/cli/_main.py (GitHub Actions / Lint Python / Lint): 'signal.SIGINT' imported but unused
Check failure on line 11 in src/fairseq2/cli/_main.py (GitHub Actions / Lint Python / Lint): 'signal.SIG_DFL' imported but unused
Check failure on line 11 in src/fairseq2/cli/_main.py (GitHub Actions / Lint Python / Lint): 'signal.raise_signal' imported but unused

import torch

Check failure on line 13 in src/fairseq2/cli/_main.py (GitHub Actions / Lint Python / Lint): 'torch' imported but unused
from torch.cuda import OutOfMemoryError

from fairseq2 import setup_fairseq2
from fairseq2.cli.utils.rich import create_rich_progress_reporter
from fairseq2.error import ContractError, InternalError
from fairseq2.extensions import ExtensionError
from fairseq2.logging import LoggingSetupError, log
from fairseq2.setup import SetupError
from fairseq2.utils.env import InvalidEnvironmentVariableError, get_rank

# isort: split

from fairseq2.cli._logging import setup_logging
from fairseq2.cli._setup import setup_cli
from fairseq2.cli.utils.rich import create_rich_progress_reporter
from fairseq2.error import ContractError, InternalError

Check failure on line 22 in src/fairseq2/cli/_main.py (GitHub Actions / Lint Python / Lint): 'fairseq2.error.InternalError' imported but unused
Check failure on line 22 in src/fairseq2/cli/_main.py (GitHub Actions / Lint Python / Lint): 'fairseq2.error.ContractError' imported but unused
from fairseq2.extensions import ExtensionError
from fairseq2.logging import log, LoggingSetupError
from fairseq2.setup import SetupError
from fairseq2.utils.env import get_rank, InvalidEnvironmentVariableError
from torch.cuda import OutOfMemoryError

Check failure on line 27 in src/fairseq2/cli/_main.py (GitHub Actions / Lint Python / Lint): 'torch.cuda.OutOfMemoryError' imported but unused


def main() -> None:
"""Runs the command line fairseq2 program."""
exit_code = 1

try:
exit_code = _run()
except KeyboardInterrupt:
log.info("Command canceled!")
# try:
exit_code = _run()
# except KeyboardInterrupt:
# log.info("Command canceled!")

signal(SIGINT, SIG_DFL)
# signal(SIGINT, SIG_DFL)

raise_signal(SIGINT)
except OutOfMemoryError:
s = torch.cuda.memory_summary()
# raise_signal(SIGINT)
# except OutOfMemoryError:
# s = torch.cuda.memory_summary()

log.exception("CUDA out of memory. See logged memory stats.\n{}", s)
except InternalError:
log.exception("Command failed with an unexpected internal error. Please file a bug report.") # fmt: skip
except ContractError:
log.exception("Command failed with an unexpected internal error caused by an extension. Please file a bug report to the corresponding extension author.") # fmt: skip
except Exception:
log.exception("Command failed with an unexpected error. See the logged stack trace for details.") # fmt: skip
# log.exception("CUDA out of memory. See logged memory stats.\n{}", s)
# except InternalError:
# log.exception("Command failed with an unexpected internal error. Please file a bug report.") # fmt: skip
# except ContractError:
# log.exception("Command failed with an unexpected internal error caused by an extension. Please file a bug report to the corresponding extension author.") # fmt: skip
# except Exception:
# log.exception("Command failed with an unexpected error. See the logged stack trace for details.") # fmt: skip

sys.exit(exit_code)

@@ -84,3 +84,7 @@
return 1

return cli.run(context)


if __name__ == "__main__":
main()
24 changes: 12 additions & 12 deletions src/fairseq2/cli/_setup.py
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
@@ -7,6 +7,10 @@
from __future__ import annotations

from fairseq2.chatbots import UnknownChatbotError

# isort: split

from fairseq2.cli._cli import Cli
from fairseq2.cli.commands.assets import ListAssetsHandler, ShowAssetHandler
from fairseq2.cli.commands.chatbot import RunChatbotHandler
from fairseq2.cli.commands.llama import (
@@ -61,37 +65,33 @@
from fairseq2.recipes.lm import (
InstructionFinetuneConfig,
LMLossEvalConfig,
POFinetuneConfig,
TextGenerateConfig,
load_instruction_finetuner,
load_lm_loss_evaluator,
load_po_finetuner,
load_text_generator,
POFinetuneConfig,
TextGenerateConfig,
)
from fairseq2.recipes.mt import (
MTEvalConfig,
MTTrainConfig,
TextTranslateConfig,
load_mt_evaluator,
load_mt_trainer,
load_text_translator,
MTEvalConfig,
MTTrainConfig,
TextTranslateConfig,
)
from fairseq2.recipes.wav2vec2 import (
Wav2Vec2EvalConfig,
Wav2Vec2TrainConfig,
load_wav2vec2_evaluator,
load_wav2vec2_trainer,
Wav2Vec2EvalConfig,
Wav2Vec2TrainConfig,
)
from fairseq2.recipes.wav2vec2.asr import (
Wav2Vec2AsrTrainConfig,
load_wav2vec2_asr_trainer,
Wav2Vec2AsrTrainConfig,
)
from fairseq2.utils.validation import ValidationError

# isort: split

from fairseq2.cli._cli import Cli


def setup_cli(context: RuntimeContext) -> Cli:
from fairseq2 import __version__
8 changes: 4 additions & 4 deletions src/fairseq2/datasets/_data_reader.py
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
@@ -8,18 +8,18 @@

from abc import ABC, abstractmethod
from collections.abc import Iterator, Mapping
from typing import TypeVar, final
from typing import final, TypeVar

import torch
from typing_extensions import Self, override

from fairseq2.data import DataPipeline, DataPipelineError
from fairseq2.gang import Gang, GangError, all_sum
from fairseq2.logging import log

# isort: split

from fairseq2.datasets._config import DataReadOptions, SyncMode
from fairseq2.gang import all_sum, Gang, GangError
from fairseq2.logging import log
from typing_extensions import override, Self

BatchT_co = TypeVar("BatchT_co", covariant=True)

115 changes: 105 additions & 10 deletions src/fairseq2/datasets/asr.py
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
@@ -8,26 +8,25 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass

from functools import partial

Check failure on line 12 in src/fairseq2/datasets/asr.py (GitHub Actions / Lint Python / Lint): 'functools.partial' imported but unused
from pathlib import Path
from typing import Any, Final, cast, final
from typing import Any, cast, Final, final, List

Check failure on line 14 in src/fairseq2/datasets/asr.py (GitHub Actions / Lint Python / Lint): 'typing.final' imported but unused

import torch
from torch import Tensor
from torch.nn.functional import layer_norm
from typing_extensions import override

from fairseq2.data import (
CollateOptionsOverride,
Collater,
create_bucket_sizes,
DataPipeline,
DataPipelineBuilder,
FileMapper,
SequenceData,
create_bucket_sizes,
read_sequence,
SequenceData,
)
from fairseq2.data.audio import AudioDecoder
from fairseq2.data.text import StrSplitter, read_text
from fairseq2.data.text import read_text, StrSplitter
from fairseq2.data.text.tokenizers import TextTokenizer
from fairseq2.datasets import (
DataPipelineReader,
@@ -40,11 +39,15 @@
StaticBatching,
UnknownSplitError,
)
from fairseq2.datasets.speech import SpeechDataset, SpeechReadOptions
from fairseq2.error import NotSupportedError
from fairseq2.gang import Gang
from fairseq2.models.seq2seq import Seq2SeqBatch
from fairseq2.nn.padding import get_seqs_and_padding_mask
from fairseq2.typing import DataType
from torch import Tensor
from torch.nn.functional import layer_norm
from typing_extensions import override


@dataclass(kw_only=True)
@@ -135,7 +138,7 @@
return GenericAsrDataset(name, path, splits)

@override
def create_reader(
def create_pipeline(

Check failure on line 141 in src/fairseq2/datasets/asr.py (GitHub Actions / Lint Python / Lint): Method "create_pipeline" is marked as an override, but no base method was found with this name
self,
split: str,
tokenizer: TextTokenizer,
@@ -278,10 +281,24 @@
example,
)

pipeline = builder.map(to_batch).and_return()
pipeline = builder.map(to_batch)
return pipeline

Check failure on line 285 in src/fairseq2/datasets/asr.py (GitHub Actions / Lint Python / Lint): Incompatible return value type (got "DataPipelineBuilder", expected "DataPipelineReader[Seq2SeqBatch]")

@override
def create_reader(
self,
split: str,
tokenizer: TextTokenizer,
gang: Gang,
min_audio_len: int,
max_audio_len: int,
options: AsrReadOptions | None = None,
) -> DataPipelineReader[Seq2SeqBatch]:
pipeline = self.create_pipeline(
split, tokenizer, gang, min_audio_len, max_audio_len, options
)
return DataPipelineReader[Seq2SeqBatch](
self._name, split, pipeline, gang, options
self._name, split, pipeline.and_return(), gang, options

Check failure on line 301 in src/fairseq2/datasets/asr.py (GitHub Actions / Lint Python / Lint): Argument 5 to "DataPipelineReader" has incompatible type "AsrReadOptions | None"; expected "DataReadOptions"
Check failure on line 301 in src/fairseq2/datasets/asr.py (GitHub Actions / Lint Python / Lint): "DataPipelineReader[Seq2SeqBatch]" has no attribute "and_return"
)

def _retrieve_data_directory(self, split: str) -> Path:
@@ -339,4 +356,82 @@
return self._splits


from fairseq2.datasets.speech import SpeechDataset


class MixingAsrDataset:
"""Represents a mixture of labeled and unlabeled datasets."""

def __init__(
self,
datasets: List[AsrDataset | SpeechDataset],
sampling_weights: List[float],
) -> None:
"""
:param datasets:
A list of AsrDatset objects.
"""
self._datasets = datasets
self._sampling_weights = sampling_weights
assert len(datasets) == len(sampling_weights)

def create_reader(
self,
split: str,
tokenizer: TextTokenizer,
gang: Gang,
min_audio_len: int,
max_audio_len: int,
options: AsrReadOptions | None = None,
) -> DataPipelineReader[Seq2SeqBatch]:

def add_name(name, batch):

Check failure on line 388 in src/fairseq2/datasets/asr.py (GitHub Actions / Lint Python / Lint): Function is missing a type annotation
batch.example = {"ds_name": name}
return batch

pipelines = []
# assert not options.normalize_audio
speech_options = SpeechReadOptions(
batching=options.batching,

Check failure on line 395 in src/fairseq2/datasets/asr.py (GitHub Actions / Lint Python / Lint): Item "None" of "AsrReadOptions | None" has no attribute "batching"
dtype=options.dtype,
normalize_audio=options.normalize_audio,
batch_shuffle_window=options.batch_shuffle_window,
num_accumulate=options.num_accumulate,
num_prefetch=options.num_prefetch,
example_shuffle_window=options.example_shuffle_window,
use_fbank=False,
no_padding=True,
npc=10,
seed=options.seed,
extras=options.extras,
)
for ds in self._datasets:
if isinstance(ds, AsrDataset):
pipeline = ds.create_pipeline(
split, tokenizer, gang, min_audio_len, max_audio_len, options
)
elif isinstance(ds, SpeechDataset):
pipeline = ds.create_pipeline(
split, gang, min_audio_len, max_audio_len, speech_options
)

# Add name
# pipeline.map(partial(add_name, ds._name))

# Repeat
pipeline = pipeline.repeat().and_return()
pipelines.append(pipeline)

# Mix. Use a different seed in each rank, to make sure batches contain a
# variety of data sources.
mixing_pipeline = DataPipeline.sample(
pipelines, weights=self._sampling_weights, seed=options.seed + gang.rank
).and_return()

# Return reader
return DataPipelineReader[Seq2SeqBatch](
"mixing_asr_dataset", split, mixing_pipeline, gang, options
)
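
`MixingAsrDataset` builds one pipeline per dataset, repeats each one, and samples between them with per-rank seeds. A minimal usage sketch follows; the labeled and unlabeled dataset objects, tokenizer, and gang are assumed to come from the surrounding recipe, and the batch size, weights, and audio-length bounds are placeholders. Only `MixingAsrDataset`, `AsrReadOptions`, `StaticBatching`, and the other imported names come from this diff.

```python
# Hypothetical recipe-side wiring for MixingAsrDataset; values are placeholders.
from fairseq2.data.text.tokenizers import TextTokenizer
from fairseq2.datasets import DataPipelineReader, StaticBatching
from fairseq2.datasets.asr import AsrDataset, AsrReadOptions, MixingAsrDataset
from fairseq2.datasets.speech import SpeechDataset
from fairseq2.gang import Gang
from fairseq2.models.seq2seq import Seq2SeqBatch


def build_mixed_reader(
    labeled_ds: AsrDataset,
    unlabeled_ds: SpeechDataset,
    tokenizer: TextTokenizer,
    gang: Gang,
) -> DataPipelineReader[Seq2SeqBatch]:
    # Draw 80% of batches from the labeled ASR set, 20% from the unlabeled speech set.
    mixture = MixingAsrDataset(
        datasets=[labeled_ds, unlabeled_ds],
        sampling_weights=[0.8, 0.2],
    )

    options = AsrReadOptions(batching=StaticBatching(8), num_prefetch=4)

    # create_reader() repeats each per-dataset pipeline and mixes them with
    # DataPipeline.sample(), seeding each rank differently (options.seed + gang.rank).
    return mixture.create_reader(
        "train",
        tokenizer,
        gang,
        min_audio_len=32_000,   # placeholder: 2 s at 16 kHz
        max_audio_len=480_000,  # placeholder: 30 s at 16 kHz
        options=options,
    )
```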


get_asr_dataset_hub = DatasetHubAccessor(AsrDataset)