-
Notifications
You must be signed in to change notification settings - Fork 115
Upleveling RNNT to new fs2 #1140
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,11 +7,9 @@ | |
from __future__ import annotations | ||
|
||
import math | ||
from typing import Any, TextIO, final | ||
from typing import Any, Dict, final, TextIO | ||
|
||
import torch | ||
from torch import Tensor | ||
from typing_extensions import override | ||
|
||
from fairseq2.data.text.tokenizers import TextTokenDecoder, TextTokenizer | ||
from fairseq2.gang import Gang | ||
|
@@ -20,7 +18,10 @@ | |
from fairseq2.models.asr import AsrModel, AsrModelOutput | ||
from fairseq2.models.seq2seq import Seq2SeqBatch | ||
from fairseq2.models.sequence import SequenceBatch | ||
from fairseq2.models.wav2vec2.asr import Wav2Vec2AsrModel | ||
from fairseq2.recipes import BaseMetricBag, Model, UnitError | ||
from torch import Tensor | ||
from typing_extensions import override | ||
|
||
|
||
@final | ||
|
@@ -41,22 +42,29 @@ def __init__(self, model: Model, scorer: AsrScorer | None = None) -> None: | |
def __call__( | ||
self, batch: Seq2SeqBatch, metric_bag: AsrMetricBag | ||
) -> tuple[Tensor, int]: | ||
input_batch = SequenceBatch(batch.source_seqs, batch.source_padding_mask) | ||
if isinstance(self._model.module, Wav2Vec2AsrModel): | ||
input_batch = SequenceBatch(batch.source_seqs, batch.source_padding_mask) | ||
else: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. when is this codepath supposed to be used? the RNNT model doesn't use |
||
input_batch = batch | ||
|
||
output = self._forward(input_batch) | ||
|
||
loss = output.compute_loss(batch.target_seqs, batch.target_padding_mask) | ||
loss, extra_metrics = output.compute_loss( | ||
batch.target_seqs, batch.target_padding_mask | ||
) | ||
|
||
metric_bag.update_ctc_loss(batch, loss) | ||
|
||
metric_bag.update_batch_metrics(batch) | ||
|
||
metric_bag.update_extra_metrics(batch, extra_metrics) | ||
|
||
if self._scorer is not None: | ||
self._scorer(batch, output, metric_bag) | ||
|
||
return loss, batch.batch_size | ||
|
||
def _forward(self, batch: SequenceBatch) -> AsrModelOutput: | ||
def _forward(self, batch: SequenceBatch | Seq2SeqBatch) -> AsrModelOutput: | ||
return self._model.module(batch) # type: ignore[no-any-return] | ||
|
||
@property | ||
|
@@ -150,11 +158,11 @@ class AsrMetricBag(BaseMetricBag): | |
def __init__(self, gang: Gang, train: bool = True) -> None: | ||
super().__init__(gang, train=train) | ||
|
||
d = gang.device | ||
self.device = gang.device | ||
|
||
self.register_metric("ctc_loss", Mean(device=d), persistent=False) | ||
self.register_metric("ctc_loss", Mean(device=self.device), persistent=False) | ||
|
||
self.register_metric("wer", WerMetric(device=d), persistent=False) | ||
self.register_metric("wer", WerMetric(device=self.device), persistent=False) | ||
|
||
@torch.inference_mode() | ||
def update_ctc_loss(self, batch: Seq2SeqBatch, loss: Tensor) -> None: | ||
|
@@ -178,6 +186,16 @@ def update_batch_metrics(self, batch: Seq2SeqBatch) -> None: | |
self.total_num_examples.update(num_examples) | ||
self.total_num_elements.update(num_elements) | ||
|
||
@torch.inference_mode() | ||
def update_extra_metrics( | ||
self, batch: Seq2SeqBatch, extra_metrics: Dict[str, Tensor] | ||
) -> None: | ||
n = batch.batch_size | ||
for k in extra_metrics: | ||
if k not in self.metrics: | ||
self.register_metric(k, Mean(device=self.device), persistent=False) | ||
self.metrics[k].update(extra_metrics[k].detach() / n, weight=n) | ||
|
||
@override | ||
def process_metric_values(self, values: dict[str, Any]) -> None: | ||
super().process_metric_values(values) | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,18 +8,20 @@ | |
|
||
from dataclasses import dataclass, field | ||
from pathlib import Path | ||
from typing import cast, final, Literal | ||
from typing import Literal, cast, final | ||
|
||
import torch | ||
from torch import Tensor | ||
from typing_extensions import override | ||
|
||
from fairseq2.context import RuntimeContext | ||
from fairseq2.datasets import LengthBatching, SyncMode | ||
from fairseq2.datasets.asr import AsrDataset, AsrReadOptions, GENERIC_ASR_DATASET_FAMILY | ||
from fairseq2.datasets.asr import GENERIC_ASR_DATASET_FAMILY, AsrDataset, AsrReadOptions | ||
from fairseq2.gang import Gang, GangError | ||
from fairseq2.logging import log | ||
from fairseq2.models.asr import AsrModel | ||
from fairseq2.models.seq2seq import Seq2SeqBatch | ||
from fairseq2.models.wav2vec2 import Wav2Vec2Model | ||
from fairseq2.models.wav2vec2.asr import Wav2Vec2AsrModel | ||
from fairseq2.nn.utils.module import freeze_parameters, share_parameters, to_device | ||
from fairseq2.optim import ADAMW_OPTIMIZER, AdamWConfig | ||
from fairseq2.optim.lr_scheduler import TRI_STAGE_LR, TriStageLRConfig | ||
|
@@ -54,15 +56,13 @@ | |
) | ||
from fairseq2.recipes.utils.log import log_model | ||
from fairseq2.recipes.wav2vec2.batch_weighted_datareader import ( | ||
BatchMixtureDataset, | ||
MIXTURE_DATASET_FAMILY, | ||
BatchMixtureDataset, | ||
) | ||
from fairseq2.typing import CPU | ||
from fairseq2.utils.rng import manual_seed | ||
from fairseq2.utils.structured import structure | ||
from fairseq2.utils.validation import validate | ||
from torch import Tensor | ||
from typing_extensions import override | ||
|
||
|
||
@dataclass(kw_only=True) | ||
|
@@ -237,7 +237,7 @@ def load_wav2vec2_asr_trainer( | |
seed += 1 | ||
|
||
model = load_base_model( | ||
Wav2Vec2AsrModel, | ||
AsrModel, | ||
context, | ||
config.model, | ||
config.trainer, | ||
|
@@ -246,7 +246,7 @@ def load_wav2vec2_asr_trainer( | |
checkpoint_manager, | ||
) | ||
|
||
module = cast(Wav2Vec2AsrModel, model.module) | ||
module = cast(AsrModel, model.module) | ||
|
||
# If we start the training with an empty ASR model, use the weights of a | ||
# pretrained wav2vec 2.0 model. | ||
|
@@ -416,7 +416,7 @@ def load_wav2vec2_asr_trainer( | |
|
||
@final | ||
class Wav2Vec2AsrTrainUnit(TrainUnit[Seq2SeqBatch]): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: should this be called |
||
_module: Wav2Vec2AsrModel | ||
_module: AsrModel | ||
_criterion: AsrCriterion | ||
_freeze_encoder_for_n_steps: int | ||
_frozen: bool | ||
|
@@ -431,9 +431,9 @@ def __init__( | |
""" | ||
module = criterion.model.base_module | ||
|
||
if not isinstance(module, Wav2Vec2AsrModel): | ||
if not isinstance(module, AsrModel): | ||
raise TypeError( | ||
f"`criterion.model.base_module` must be of type `{Wav2Vec2AsrModel}`, but is of type `{type(module)}` instead." | ||
f"`criterion.model.base_module` must be of type `{AsrModel}`, but is of type `{type(module)}` instead." | ||
) | ||
|
||
self._module = module | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
double checking ... does this type check work as expected? i remember a while back we chatted on this and had to introduce a hack. i think i got an error when doing CTC training on top of this code change.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes I just fixed it yesterday, following up on our old chat