diff --git a/src/fairseq2/recipes/asr/_common.py b/src/fairseq2/recipes/asr/_common.py index fd2fac086..fa5a537ec 100644 --- a/src/fairseq2/recipes/asr/_common.py +++ b/src/fairseq2/recipes/asr/_common.py @@ -7,11 +7,9 @@ from __future__ import annotations import math -from typing import Any, TextIO, final +from typing import Any, Dict, final, TextIO import torch -from torch import Tensor -from typing_extensions import override from fairseq2.data.text.tokenizers import TextTokenDecoder, TextTokenizer from fairseq2.gang import Gang @@ -20,7 +18,10 @@ from fairseq2.models.asr import AsrModel, AsrModelOutput from fairseq2.models.seq2seq import Seq2SeqBatch from fairseq2.models.sequence import SequenceBatch +from fairseq2.models.wav2vec2.asr import Wav2Vec2AsrModel from fairseq2.recipes import BaseMetricBag, Model, UnitError +from torch import Tensor +from typing_extensions import override @final @@ -41,22 +42,29 @@ def __init__(self, model: Model, scorer: AsrScorer | None = None) -> None: def __call__( self, batch: Seq2SeqBatch, metric_bag: AsrMetricBag ) -> tuple[Tensor, int]: - input_batch = SequenceBatch(batch.source_seqs, batch.source_padding_mask) + if isinstance(self._model.module, Wav2Vec2AsrModel): + input_batch = SequenceBatch(batch.source_seqs, batch.source_padding_mask) + else: + input_batch = batch output = self._forward(input_batch) - loss = output.compute_loss(batch.target_seqs, batch.target_padding_mask) + loss, extra_metrics = output.compute_loss( + batch.target_seqs, batch.target_padding_mask + ) metric_bag.update_ctc_loss(batch, loss) metric_bag.update_batch_metrics(batch) + metric_bag.update_extra_metrics(batch, extra_metrics) + if self._scorer is not None: self._scorer(batch, output, metric_bag) return loss, batch.batch_size - def _forward(self, batch: SequenceBatch) -> AsrModelOutput: + def _forward(self, batch: SequenceBatch | Seq2SeqBatch) -> AsrModelOutput: return self._model.module(batch) # type: ignore[no-any-return] @property @@ -150,11 +158,11 @@ class AsrMetricBag(BaseMetricBag): def __init__(self, gang: Gang, train: bool = True) -> None: super().__init__(gang, train=train) - d = gang.device + self.device = gang.device - self.register_metric("ctc_loss", Mean(device=d), persistent=False) + self.register_metric("ctc_loss", Mean(device=self.device), persistent=False) - self.register_metric("wer", WerMetric(device=d), persistent=False) + self.register_metric("wer", WerMetric(device=self.device), persistent=False) @torch.inference_mode() def update_ctc_loss(self, batch: Seq2SeqBatch, loss: Tensor) -> None: @@ -178,6 +186,16 @@ def update_batch_metrics(self, batch: Seq2SeqBatch) -> None: self.total_num_examples.update(num_examples) self.total_num_elements.update(num_elements) + @torch.inference_mode() + def update_extra_metrics( + self, batch: Seq2SeqBatch, extra_metrics: Dict[str, Tensor] + ) -> None: + n = batch.batch_size + for k in extra_metrics: + if k not in self.metrics: + self.register_metric(k, Mean(device=self.device), persistent=False) + self.metrics[k].update(extra_metrics[k].detach() / n, weight=n) + @override def process_metric_values(self, values: dict[str, Any]) -> None: super().process_metric_values(values) diff --git a/src/fairseq2/recipes/wav2vec2/asr/_train.py b/src/fairseq2/recipes/wav2vec2/asr/_train.py index 33bb6e535..1fa64e57e 100644 --- a/src/fairseq2/recipes/wav2vec2/asr/_train.py +++ b/src/fairseq2/recipes/wav2vec2/asr/_train.py @@ -8,18 +8,20 @@ from dataclasses import dataclass, field from pathlib import Path -from typing import cast, final, Literal +from typing import Literal, cast, final import torch +from torch import Tensor +from typing_extensions import override from fairseq2.context import RuntimeContext from fairseq2.datasets import LengthBatching, SyncMode -from fairseq2.datasets.asr import AsrDataset, AsrReadOptions, GENERIC_ASR_DATASET_FAMILY +from fairseq2.datasets.asr import GENERIC_ASR_DATASET_FAMILY, AsrDataset, AsrReadOptions from fairseq2.gang import Gang, GangError from fairseq2.logging import log +from fairseq2.models.asr import AsrModel from fairseq2.models.seq2seq import Seq2SeqBatch from fairseq2.models.wav2vec2 import Wav2Vec2Model -from fairseq2.models.wav2vec2.asr import Wav2Vec2AsrModel from fairseq2.nn.utils.module import freeze_parameters, share_parameters, to_device from fairseq2.optim import ADAMW_OPTIMIZER, AdamWConfig from fairseq2.optim.lr_scheduler import TRI_STAGE_LR, TriStageLRConfig @@ -54,15 +56,13 @@ ) from fairseq2.recipes.utils.log import log_model from fairseq2.recipes.wav2vec2.batch_weighted_datareader import ( - BatchMixtureDataset, MIXTURE_DATASET_FAMILY, + BatchMixtureDataset, ) from fairseq2.typing import CPU from fairseq2.utils.rng import manual_seed from fairseq2.utils.structured import structure from fairseq2.utils.validation import validate -from torch import Tensor -from typing_extensions import override @dataclass(kw_only=True) @@ -237,7 +237,7 @@ def load_wav2vec2_asr_trainer( seed += 1 model = load_base_model( - Wav2Vec2AsrModel, + AsrModel, context, config.model, config.trainer, @@ -246,7 +246,7 @@ def load_wav2vec2_asr_trainer( checkpoint_manager, ) - module = cast(Wav2Vec2AsrModel, model.module) + module = cast(AsrModel, model.module) # If we start the training with an empty ASR model, use the weights of a # pretrained wav2vec 2.0 model. @@ -416,7 +416,7 @@ def load_wav2vec2_asr_trainer( @final class Wav2Vec2AsrTrainUnit(TrainUnit[Seq2SeqBatch]): - _module: Wav2Vec2AsrModel + _module: AsrModel _criterion: AsrCriterion _freeze_encoder_for_n_steps: int _frozen: bool @@ -431,9 +431,9 @@ def __init__( """ module = criterion.model.base_module - if not isinstance(module, Wav2Vec2AsrModel): + if not isinstance(module, AsrModel): raise TypeError( - f"`criterion.model.base_module` must be of type `{Wav2Vec2AsrModel}`, but is of type `{type(module)}` instead." + f"`criterion.model.base_module` must be of type `{AsrModel}`, but is of type `{type(module)}` instead." ) self._module = module