Skip to content

Set is_binarized to false for valid dataloaders. #1149

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main_w2v2_pretraining
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/fairseq2/checkpoint/_manager.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand All @@ -13,14 +13,12 @@
from os import scandir
from pathlib import Path
from shutil import Error
from typing import Protocol, TypeAlias, cast, final
from typing import cast, final, Protocol, TypeAlias

import torch
from torch import Tensor
from typing_extensions import override

from fairseq2.error import InternalError
from fairseq2.gang import Gang, GangError, Gangs, all_sum
from fairseq2.gang import all_sum, Gang, GangError, Gangs
from fairseq2.nn.data_parallel import load_with_sdp_gang
from fairseq2.typing import CPU
from fairseq2.utils.file import (
Expand All @@ -33,6 +31,8 @@
)
from fairseq2.utils.state import Stateful
from fairseq2.utils.threading import ThreadPool
from torch import Tensor
from typing_extensions import override


class CheckpointManager(ABC):
Expand Down Expand Up @@ -872,7 +872,7 @@
if keep_last_n is None and keep_best_n is None and keep_every_n_steps is None:
return False

step_numbers = self.get_step_numbers()
step_numbers = self.get_step_numbers(exclude_model_only=True)
if not step_numbers:
return False

Expand Down
1 change: 1 addition & 0 deletions src/fairseq2/recipes/wav2vec2/_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ def load_wav2vec2_trainer(
# Initialize the validation unit.
if config.dataset.valid_split is not None:

config.dataset.extras["is_binarized"] = False
read_options = SpeechReadOptions(
batching=batching,
dtype=config.trainer.dtype,
Expand Down
Loading