Upleveling RNNT to new fs2 #1140
Merged · 1 commit · May 9, 2025
36 changes: 27 additions & 9 deletions src/fairseq2/recipes/asr/_common.py
@@ -7,11 +7,9 @@
 from __future__ import annotations

 import math
-from typing import Any, TextIO, final
+from typing import Any, Dict, final, TextIO

 import torch
-from torch import Tensor
-from typing_extensions import override

 from fairseq2.data.text.tokenizers import TextTokenDecoder, TextTokenizer
 from fairseq2.gang import Gang
@@ -20,7 +18,10 @@
 from fairseq2.models.asr import AsrModel, AsrModelOutput
 from fairseq2.models.seq2seq import Seq2SeqBatch
 from fairseq2.models.sequence import SequenceBatch
+from fairseq2.models.wav2vec2.asr import Wav2Vec2AsrModel
 from fairseq2.recipes import BaseMetricBag, Model, UnitError
+from torch import Tensor
+from typing_extensions import override


 @final
@@ -41,22 +42,29 @@ def __init__(self, model: Model, scorer: AsrScorer | None = None) -> None:
     def __call__(
         self, batch: Seq2SeqBatch, metric_bag: AsrMetricBag
     ) -> tuple[Tensor, int]:
-        input_batch = SequenceBatch(batch.source_seqs, batch.source_padding_mask)
+        if isinstance(self._model.module, Wav2Vec2AsrModel):

Contributor: Double checking ... does this type check work as expected? I remember a while back we chatted about this and had to introduce a hack. I think I got an error when doing CTC training on top of this code change.

Author: Yes, I just fixed it yesterday, following up on our old chat.

+            input_batch = SequenceBatch(batch.source_seqs, batch.source_padding_mask)
+        else:

Contributor: When is this codepath supposed to be used? The RNNT model doesn't use Seq2SeqBatch?

+            input_batch = batch

         output = self._forward(input_batch)

-        loss = output.compute_loss(batch.target_seqs, batch.target_padding_mask)
+        loss, extra_metrics = output.compute_loss(
+            batch.target_seqs, batch.target_padding_mask
+        )

         metric_bag.update_ctc_loss(batch, loss)

         metric_bag.update_batch_metrics(batch)

+        metric_bag.update_extra_metrics(batch, extra_metrics)
+
         if self._scorer is not None:
             self._scorer(batch, output, metric_bag)

         return loss, batch.batch_size

-    def _forward(self, batch: SequenceBatch) -> AsrModelOutput:
+    def _forward(self, batch: SequenceBatch | Seq2SeqBatch) -> AsrModelOutput:
         return self._model.module(batch)  # type: ignore[no-any-return]

     @property
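To make the dispatch in `__call__` easier to follow, here is a minimal stand-alone sketch of the same logic. The helper name `make_input_batch` is invented for illustration and is not part of this PR; the assumption that non-wav2vec2 models (such as RNNT) consume the full `Seq2SeqBatch` is taken from the diff and the discussion above.

```python
from __future__ import annotations

from fairseq2.models.seq2seq import Seq2SeqBatch
from fairseq2.models.sequence import SequenceBatch
from fairseq2.models.wav2vec2.asr import Wav2Vec2AsrModel


def make_input_batch(
    module: object, batch: Seq2SeqBatch
) -> SequenceBatch | Seq2SeqBatch:
    """Illustrative helper mirroring the dispatch in AsrCriterion.__call__."""
    if isinstance(module, Wav2Vec2AsrModel):
        # wav2vec2 ASR encodes only the source audio; the targets are used
        # afterwards to compute the CTC loss.
        return SequenceBatch(batch.source_seqs, batch.source_padding_mask)

    # Other ASR models (e.g. the RNNT model this PR targets) also condition
    # on the target sequence during training, so they receive the whole
    # Seq2SeqBatch.
    return batch
```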
@@ -150,11 +158,11 @@ class AsrMetricBag(BaseMetricBag):
     def __init__(self, gang: Gang, train: bool = True) -> None:
         super().__init__(gang, train=train)

-        d = gang.device
+        self.device = gang.device

-        self.register_metric("ctc_loss", Mean(device=d), persistent=False)
+        self.register_metric("ctc_loss", Mean(device=self.device), persistent=False)

-        self.register_metric("wer", WerMetric(device=d), persistent=False)
+        self.register_metric("wer", WerMetric(device=self.device), persistent=False)

     @torch.inference_mode()
     def update_ctc_loss(self, batch: Seq2SeqBatch, loss: Tensor) -> None:
@@ -178,6 +186,16 @@ def update_batch_metrics(self, batch: Seq2SeqBatch) -> None:
         self.total_num_examples.update(num_examples)
         self.total_num_elements.update(num_elements)

+    @torch.inference_mode()
+    def update_extra_metrics(
+        self, batch: Seq2SeqBatch, extra_metrics: Dict[str, Tensor]
+    ) -> None:
+        n = batch.batch_size
+        for k in extra_metrics:
+            if k not in self.metrics:
+                self.register_metric(k, Mean(device=self.device), persistent=False)
+            self.metrics[k].update(extra_metrics[k].detach() / n, weight=n)
+
     @override
     def process_metric_values(self, values: dict[str, Any]) -> None:
         super().process_metric_values(values)
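A short usage sketch of the new `update_extra_metrics`; the `gang` and `batch` objects are assumed to come from the surrounding recipe, and the metric keys are invented:

```python
import torch

# `gang` and `batch` are assumed to exist as in the training recipe; the
# keys "rnnt_loss" and "blank_rate" are illustrative only.
bag = AsrMetricBag(gang)

extra = {"rnnt_loss": torch.tensor(12.0), "blank_rate": torch.tensor(0.8)}

# The first call lazily registers a non-persistent Mean for each unseen key
# on the device cached in __init__ (which is why `d` became `self.device`
# above); later calls only accumulate. Each value is divided by the batch
# size and weighted by it, yielding a running per-example mean.
bag.update_extra_metrics(batch, extra)
bag.update_extra_metrics(batch, extra)  # keys already registered; just updates
```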
22 changes: 11 additions & 11 deletions src/fairseq2/recipes/wav2vec2/asr/_train.py
@@ -8,18 +8,20 @@
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import cast, final, Literal
+from typing import Literal, cast, final

 import torch
+from torch import Tensor
+from typing_extensions import override

 from fairseq2.context import RuntimeContext
 from fairseq2.datasets import LengthBatching, SyncMode
-from fairseq2.datasets.asr import AsrDataset, AsrReadOptions, GENERIC_ASR_DATASET_FAMILY
+from fairseq2.datasets.asr import GENERIC_ASR_DATASET_FAMILY, AsrDataset, AsrReadOptions
 from fairseq2.gang import Gang, GangError
 from fairseq2.logging import log
+from fairseq2.models.asr import AsrModel
 from fairseq2.models.seq2seq import Seq2SeqBatch
 from fairseq2.models.wav2vec2 import Wav2Vec2Model
-from fairseq2.models.wav2vec2.asr import Wav2Vec2AsrModel
 from fairseq2.nn.utils.module import freeze_parameters, share_parameters, to_device
 from fairseq2.optim import ADAMW_OPTIMIZER, AdamWConfig
 from fairseq2.optim.lr_scheduler import TRI_STAGE_LR, TriStageLRConfig
@@ -54,15 +56,13 @@
 )
 from fairseq2.recipes.utils.log import log_model
 from fairseq2.recipes.wav2vec2.batch_weighted_datareader import (
-    BatchMixtureDataset,
     MIXTURE_DATASET_FAMILY,
+    BatchMixtureDataset,
 )
 from fairseq2.typing import CPU
 from fairseq2.utils.rng import manual_seed
 from fairseq2.utils.structured import structure
 from fairseq2.utils.validation import validate
-from torch import Tensor
-from typing_extensions import override


 @dataclass(kw_only=True)
@@ -237,7 +237,7 @@ def load_wav2vec2_asr_trainer(
     seed += 1

     model = load_base_model(
-        Wav2Vec2AsrModel,
+        AsrModel,
         context,
         config.model,
         config.trainer,
@@ -246,7 +246,7 @@
         checkpoint_manager,
     )

-    module = cast(Wav2Vec2AsrModel, model.module)
+    module = cast(AsrModel, model.module)

     # If we start the training with an empty ASR model, use the weights of a
     # pretrained wav2vec 2.0 model.
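One note for readers of this hunk: `typing.cast` is a no-op at runtime, so widening it to `AsrModel` only changes what the static type checker sees; the runtime guarantee still comes from the `isinstance` check in `Wav2Vec2AsrTrainUnit.__init__` below. A minimal, self-contained illustration:

```python
from typing import cast


class Base: ...


class Sub(Base): ...


obj: object = Sub()

# cast() merely tells the type checker to treat obj as Base; nothing is
# verified or converted at runtime.
as_base = cast(Base, obj)
assert as_base is obj
```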
@@ -416,7 +416,7 @@

 @final
 class Wav2Vec2AsrTrainUnit(TrainUnit[Seq2SeqBatch]):

Contributor: Nit: should this be called AsrTrainUnit now? Not a biggie.

-    _module: Wav2Vec2AsrModel
+    _module: AsrModel
     _criterion: AsrCriterion
     _freeze_encoder_for_n_steps: int
     _frozen: bool
@@ -431,9 +431,9 @@ def __init__(
         """
         module = criterion.model.base_module

-        if not isinstance(module, Wav2Vec2AsrModel):
+        if not isinstance(module, AsrModel):
             raise TypeError(
-                f"`criterion.model.base_module` must be of type `{Wav2Vec2AsrModel}`, but is of type `{type(module)}` instead."
+                f"`criterion.model.base_module` must be of type `{AsrModel}`, but is of type `{type(module)}` instead."
             )

         self._module = module
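Why existing wav2vec2 recipes should keep working: by the logic of this PR (the model is loaded via `load_base_model(AsrModel, ...)` above), `Wav2Vec2AsrModel` is a subclass of `AsrModel`, so the widened guard still accepts it while also admitting other ASR models such as RNNT. A sketch of that assumption:

```python
from fairseq2.models.asr import AsrModel
from fairseq2.models.wav2vec2.asr import Wav2Vec2AsrModel

# Assumption implied by this PR: wav2vec2's ASR model is an AsrModel
# subclass, so isinstance(module, AsrModel) holds for it as before.
assert issubclass(Wav2Vec2AsrModel, AsrModel)
```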