Upleveling RNNT to new fs2 #1140

Open. Wants to merge 1 commit into base branch `main_w2v2_pretraining`.

src/fairseq2/recipes/asr/_common.py (15 changes: 10 additions & 5 deletions)

```diff
@@ -7,11 +7,9 @@
 from __future__ import annotations
 
 import math
-from typing import Any, TextIO, final
+from typing import Any, final, TextIO
 
 import torch
-from torch import Tensor
-from typing_extensions import override
 
 from fairseq2.data.text.tokenizers import TextTokenDecoder, TextTokenizer
 from fairseq2.gang import Gang
@@ -20,7 +18,11 @@
 from fairseq2.models.asr import AsrModel, AsrModelOutput
 from fairseq2.models.seq2seq import Seq2SeqBatch
 from fairseq2.models.sequence import SequenceBatch
+
+from fairseq2.models.wav2vec2.asr import Wav2Vec2AsrModel
 from fairseq2.recipes import BaseMetricBag, Model, UnitError
+from torch import Tensor
+from typing_extensions import override
 
 
 @final
@@ -41,7 +43,10 @@
     def __call__(
         self, batch: Seq2SeqBatch, metric_bag: AsrMetricBag
     ) -> tuple[Tensor, int]:
-        input_batch = SequenceBatch(batch.source_seqs, batch.source_padding_mask)
+        if isinstance(self._model, Wav2Vec2AsrModel):
+            input_batch = SequenceBatch(batch.source_seqs, batch.source_padding_mask)
+        else:
+            input_batch = batch
 
         output = self._forward(input_batch)
 
@@ -56,7 +61,7 @@
 
         return loss, batch.batch_size
 
-    def _forward(self, batch: SequenceBatch) -> AsrModelOutput:
+    def _forward(self, batch: SequenceBatch | Seq2SeqBatch) -> AsrModelOutput:
        return self._model.module(batch)  # type: ignore[no-any-return]
 
     @property
```
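
For context, here is a minimal, self-contained sketch of the batch dispatch this change introduces. The dataclasses and the model class are hypothetical stand-ins for the corresponding fairseq2 types; only the branching logic mirrors the diff. The apparent intent, given the PR title, is that a CTC-style wav2vec 2.0 model consumes only the audio side of a batch, while a transducer-style (RNN-T) model also needs the target sequence for its loss and therefore receives the full `Seq2SeqBatch`.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any


@dataclass
class Seq2SeqBatch:
    # Stand-in for fairseq2.models.seq2seq.Seq2SeqBatch: carries the source
    # (audio) side and the target (text) side of each example.
    source_seqs: Any
    source_padding_mask: Any
    target_seqs: Any


@dataclass
class SequenceBatch:
    # Stand-in for fairseq2.models.sequence.SequenceBatch: source side only.
    seqs: Any
    padding_mask: Any


class Wav2Vec2AsrModel:
    # Stand-in for the CTC-style wav2vec 2.0 ASR model.
    pass


def make_input_batch(model: object, batch: Seq2SeqBatch) -> SequenceBatch | Seq2SeqBatch:
    # Mirrors the new __call__ logic: wav2vec 2.0 ASR models get a plain
    # SequenceBatch built from the source side; any other ASR model receives
    # the Seq2SeqBatch unchanged.
    if isinstance(model, Wav2Vec2AsrModel):
        return SequenceBatch(batch.source_seqs, batch.source_padding_mask)

    return batch
```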

src/fairseq2/recipes/wav2vec2/asr/_train.py (12 changes: 6 additions & 6 deletions)

```diff
@@ -19,9 +19,9 @@
 from fairseq2.datasets.asr import GENERIC_ASR_DATASET_FAMILY, AsrDataset, AsrReadOptions
 from fairseq2.gang import Gang, GangError
 from fairseq2.logging import log
+from fairseq2.models.asr import AsrModel
 from fairseq2.models.seq2seq import Seq2SeqBatch
 from fairseq2.models.wav2vec2 import Wav2Vec2Model
-from fairseq2.models.wav2vec2.asr import Wav2Vec2AsrModel
 from fairseq2.nn.utils.module import freeze_parameters, share_parameters, to_device
 from fairseq2.optim import ADAMW_OPTIMIZER, AdamWConfig
 from fairseq2.optim.lr_scheduler import TRI_STAGE_LR, TriStageLRConfig
@@ -233,7 +233,7 @@
     seed += 1
 
     model = load_base_model(
-        Wav2Vec2AsrModel,
+        AsrModel,
         context,
         config.model,
         config.trainer,
@@ -242,7 +242,7 @@
         checkpoint_manager,
     )
 
-    module = cast(Wav2Vec2AsrModel, model.module)
+    module = cast(AsrModel, model.module)
 
     # If we start the training with an empty ASR model, use the weights of a
     # pretrained wav2vec 2.0 model.
```
```diff
@@ -258,11 +258,11 @@
 
         pt_module = cast(Wav2Vec2Model, pt_model.module)
 
         share_parameters(pt_module.encoder_frontend, module.encoder_frontend)
         share_parameters(pt_module.encoder, module.encoder)
 
         if module.masker is not None:
             share_parameters(pt_module.masker, module.masker)
 
         del pt_model
```

Check failures in src/fairseq2/recipes/wav2vec2/asr/_train.py (GitHub Actions / Lint Python / Lint):

- Line 261: Argument 2 to "share_parameters" has incompatible type "Tensor | Module"; expected Module
- Line 262: Argument 2 to "share_parameters" has incompatible type "Tensor | Module"; expected Module
- Line 265: Argument 2 to "share_parameters" has incompatible type "Tensor | Module"; expected Module

```diff
@@ -279,7 +279,7 @@
         ) from ex
 
     # We never train the feature extractor.
     freeze_parameters(module.encoder_frontend.feature_extractor)
 
     prepare_model(context, config.trainer, model)
```

Check failures in src/fairseq2/recipes/wav2vec2/asr/_train.py (GitHub Actions / Lint Python / Lint):

- Line 282: Argument 1 to "freeze_parameters" has incompatible type "Any | Tensor | Module"; expected Module | None
- Line 282: Item "Tensor" of "Tensor | Module" has no attribute "feature_extractor"
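
These failures all trace back to the `cast(Wav2Vec2AsrModel, model.module)` to `cast(AsrModel, model.module)` change above: `AsrModel` does not declare `encoder_frontend`, `encoder`, or `masker`, so mypy resolves those attribute accesses through `torch.nn.Module.__getattr__`, which is typed as returning `Tensor | Module`. One way to satisfy the checker, sketched below as a suggestion rather than the branch's actual fix, is to narrow the module back to `Wav2Vec2AsrModel` before the wav2vec2-specific setup:

```python
from fairseq2.models.asr import AsrModel
from fairseq2.models.wav2vec2 import Wav2Vec2Model
from fairseq2.models.wav2vec2.asr import Wav2Vec2AsrModel
from fairseq2.nn.utils.module import freeze_parameters, share_parameters


def share_pretrained_parameters(pt_module: Wav2Vec2Model, module: AsrModel) -> None:
    # Hypothetical helper, not part of the PR. The isinstance check narrows
    # `module` so mypy sees the attributes declared on Wav2Vec2AsrModel
    # instead of falling back to Module.__getattr__.
    if not isinstance(module, Wav2Vec2AsrModel):
        return

    share_parameters(pt_module.encoder_frontend, module.encoder_frontend)
    share_parameters(pt_module.encoder, module.encoder)

    if module.masker is not None:
        share_parameters(pt_module.masker, module.masker)

    # The feature extractor is never trained.
    freeze_parameters(module.encoder_frontend.feature_extractor)
```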

```diff
@@ -393,7 +393,7 @@
 
 @final
 class Wav2Vec2AsrTrainUnit(TrainUnit[Seq2SeqBatch]):
-    _module: Wav2Vec2AsrModel
+    _module: AsrModel
     _criterion: AsrCriterion
     _freeze_encoder_for_n_steps: int
     _frozen: bool
@@ -408,9 +408,9 @@
         """
         module = criterion.model.base_module
 
-        if not isinstance(module, Wav2Vec2AsrModel):
+        if not isinstance(module, AsrModel):
             raise TypeError(
-                f"`criterion.model.base_module` must be of type `{Wav2Vec2AsrModel}`, but is of type `{type(module)}` instead."
+                f"`criterion.model.base_module` must be of type `{AsrModel}`, but is of type `{type(module)}` instead."
             )
 
         self._module = module
```
```diff
@@ -435,11 +435,11 @@
         if step_nr == 1:
             log.info("Freezing the encoder for the first {} steps.", self._freeze_encoder_for_n_steps)  # fmt: skip
 
             freeze_parameters(module.encoder_frontend)
             freeze_parameters(module.encoder)
 
             if module.masker is not None:
                 freeze_parameters(module.masker)
 
             self._frozen = True
         else:
@@ -452,7 +452,7 @@
             freeze_parameters(module, False)
 
             # We never train the feature extractor.
             freeze_parameters(module.encoder_frontend.feature_extractor)
 
             self._frozen = False
```

Check failures in src/fairseq2/recipes/wav2vec2/asr/_train.py (GitHub Actions / Lint Python / Lint):

- Line 438: Argument 1 to "freeze_parameters" has incompatible type "Tensor | Module"; expected Module | None
- Line 439: Argument 1 to "freeze_parameters" has incompatible type "Tensor | Module"; expected Module | None
- Line 442: Argument 1 to "freeze_parameters" has incompatible type "Tensor | Module"; expected Module | None
- Line 455: Argument 1 to "freeze_parameters" has incompatible type "Any | Tensor | Module"; expected Module | None
- Line 455: Item "Tensor" of "Tensor | Module" has no attribute "feature_extractor"
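
The `set_step_nr` failures have the same root cause, since `_module` is now typed as `AsrModel`. A sketch of the analogous narrowing, under the assumption that non-wav2vec2 subclasses (e.g. an RNN-T model) have no encoder-freezing phase and should simply be left alone:

```python
from fairseq2.models.asr import AsrModel
from fairseq2.models.wav2vec2.asr import Wav2Vec2AsrModel
from fairseq2.nn.utils.module import freeze_parameters


def set_encoder_frozen(module: AsrModel, frozen: bool) -> None:
    # Hypothetical helper, not part of the PR: freeze or unfreeze the
    # wav2vec 2.0 encoder stack; other AsrModel subclasses no-op.
    if not isinstance(module, Wav2Vec2AsrModel):
        return

    if frozen:
        freeze_parameters(module.encoder_frontend)
        freeze_parameters(module.encoder)

        if module.masker is not None:
            freeze_parameters(module.masker)
    else:
        freeze_parameters(module, False)

        # The feature extractor is never trained.
        freeze_parameters(module.encoder_frontend.feature_extractor)
```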
