Skip to content

Commit acd77f8

Browse files
committed
Debug commit upleveled latest
1 parent 6662153 commit acd77f8

File tree

5 files changed

+82
-37
lines changed

5 files changed

+82
-37
lines changed

src/fairseq2/cli/_main.py

+27-23
Original file line numberDiff line numberDiff line change
@@ -8,47 +8,47 @@
88

99
import os
1010
import sys
11-
from signal import SIG_DFL, SIGINT, raise_signal, signal
11+
from signal import raise_signal, SIG_DFL, SIGINT, signal
1212

1313
import torch
14-
from torch.cuda import OutOfMemoryError
1514

1615
from fairseq2 import setup_fairseq2
17-
from fairseq2.cli.utils.rich import create_rich_progress_reporter
18-
from fairseq2.error import ContractError, InternalError
19-
from fairseq2.extensions import ExtensionError
20-
from fairseq2.logging import LoggingSetupError, log
21-
from fairseq2.setup import SetupError
22-
from fairseq2.utils.env import InvalidEnvironmentVariableError, get_rank
2316

2417
# isort: split
2518

2619
from fairseq2.cli._logging import setup_logging
2720
from fairseq2.cli._setup import setup_cli
21+
from fairseq2.cli.utils.rich import create_rich_progress_reporter
22+
from fairseq2.error import ContractError, InternalError
23+
from fairseq2.extensions import ExtensionError
24+
from fairseq2.logging import log, LoggingSetupError
25+
from fairseq2.setup import SetupError
26+
from fairseq2.utils.env import get_rank, InvalidEnvironmentVariableError
27+
from torch.cuda import OutOfMemoryError
2828

2929

3030
def main() -> None:
3131
"""Runs the command line fairseq2 program."""
3232
exit_code = 1
3333

34-
try:
35-
exit_code = _run()
36-
except KeyboardInterrupt:
37-
log.info("Command canceled!")
34+
# try:
35+
exit_code = _run()
36+
# except KeyboardInterrupt:
37+
# log.info("Command canceled!")
3838

39-
signal(SIGINT, SIG_DFL)
39+
# signal(SIGINT, SIG_DFL)
4040

41-
raise_signal(SIGINT)
42-
except OutOfMemoryError:
43-
s = torch.cuda.memory_summary()
41+
# raise_signal(SIGINT)
42+
# except OutOfMemoryError:
43+
# s = torch.cuda.memory_summary()
4444

45-
log.exception("CUDA out of memory. See logged memory stats.\n{}", s)
46-
except InternalError:
47-
log.exception("Command failed with an unexpected internal error. Please file a bug report.") # fmt: skip
48-
except ContractError:
49-
log.exception("Command failed with an unexpected internal error caused by an extension. Please file a bug report to the corresponding extension author.") # fmt: skip
50-
except Exception:
51-
log.exception("Command failed with an unexpected error. See the logged stack trace for details.") # fmt: skip
45+
# log.exception("CUDA out of memory. See logged memory stats.\n{}", s)
46+
# except InternalError:
47+
# log.exception("Command failed with an unexpected internal error. Please file a bug report.") # fmt: skip
48+
# except ContractError:
49+
# log.exception("Command failed with an unexpected internal error caused by an extension. Please file a bug report to the corresponding extension author.") # fmt: skip
50+
# except Exception:
51+
# log.exception("Command failed with an unexpected error. See the logged stack trace for details.") # fmt: skip
5252

5353
sys.exit(exit_code)
5454

@@ -84,3 +84,7 @@ def _run() -> int:
8484
return 1
8585

8686
return cli.run(context)
87+
88+
89+
if __name__ == "__main__":
90+
main()

src/fairseq2/nn/utils/module.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@
1313
from typing import Protocol, runtime_checkable
1414

1515
import torch
16-
from torch import Tensor
17-
from torch.nn import Module, Parameter
18-
from torch.nn.utils import remove_weight_norm # type: ignore[attr-defined]
1916

2017
from fairseq2.gang import Gang
2118
from fairseq2.logging import log
2219
from fairseq2.typing import CPU, Device
20+
from torch import Tensor
21+
from torch.nn import Module, Parameter
22+
from torch.nn.utils import remove_weight_norm # type: ignore[attr-defined]
2323

2424

2525
@runtime_checkable
@@ -464,6 +464,21 @@ def load_state_dict(
464464
``state_dict`` does not contain any keys corresponding to descendants that are set to ``None``
465465
via :meth:`Module.register_module()`.
466466
"""
467+
# Key mapping
468+
need_mapping = False
469+
sample_key = list(state_dict.keys())[0]
470+
if (
471+
sample_key.startswith("module.")
472+
and not sample_key in module.state_dict().keys()
473+
):
474+
mapped_key = sample_key[7:]
475+
if mapped_key in module.state_dict().keys():
476+
need_mapping = True
477+
478+
if need_mapping:
479+
key_mapping = lambda key: key[7:] if key.startswith("module.") else key
480+
state_dict = {key_mapping(key): value for key, value in state_dict.items()}
481+
467482
module.load_state_dict(state_dict, strict=strict)
468483

469484
unexpected_keys = []

src/fairseq2/recipes/_validator.py

+22-10
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,14 @@
66

77
from __future__ import annotations
88

9+
import socket
10+
911
from abc import ABC, abstractmethod
1012
from collections.abc import Sequence
1113
from contextlib import nullcontext
12-
from typing import Generic, TypeVar, final
14+
from typing import final, Generic, TypeVar
1315

1416
import torch
15-
from torch import Tensor
16-
from torch.profiler import record_function
17-
from typing_extensions import override
1817

1918
from fairseq2.checkpoint import CheckpointError, CheckpointManager, CheckpointSaveError
2019
from fairseq2.datasets import DataReader, DataReadError
@@ -25,17 +24,20 @@
2524
from fairseq2.metrics import MetricBagError, MetricDescriptor
2625
from fairseq2.metrics.recorders import MetricRecorder, MetricRecordError
2726
from fairseq2.profilers import Profiler
28-
from fairseq2.typing import CPU, ContextManager, DataType
29-
from fairseq2.utils.device_stat import DeviceStatTracker
30-
from fairseq2.utils.progress import ProgressReporter, ProgressTask
31-
from fairseq2.utils.rng import RngBag
32-
from fairseq2.utils.stopwatch import Stopwatch
3327

3428
# isort: split
3529

3630
from fairseq2.recipes._error import RecipeError, UnitError
3731
from fairseq2.recipes._evaluator import EvalUnit
3832
from fairseq2.recipes._metrics import extend_batch_metrics
33+
from fairseq2.typing import ContextManager, CPU, DataType
34+
from fairseq2.utils.device_stat import DeviceStatTracker
35+
from fairseq2.utils.progress import ProgressReporter, ProgressTask
36+
from fairseq2.utils.rng import RngBag
37+
from fairseq2.utils.stopwatch import Stopwatch
38+
from torch import Tensor
39+
from torch.profiler import record_function
40+
from typing_extensions import override
3941

4042

4143
class Validator(ABC):
@@ -243,7 +245,14 @@ def _run_unit(
243245
f"The {s} unit has failed. See the nested exception for details."
244246
) from ex
245247

248+
machine_name = socket.gethostname()
249+
if machine_name.startswith("devvm"):
250+
_max_num_valid_steps = 5
251+
else:
252+
_max_num_valid_steps = 50000000000
253+
c = 0
246254
while not eod:
255+
log.info(f"s1: Running validation step {c}.")
247256
try:
248257
self._checkpoint_manager.maybe_complete_async_checkpoint()
249258
except CheckpointSaveError as ex:
@@ -252,10 +261,13 @@ def _run_unit(
252261
) from ex
253262

254263
batches = self._read_next_batches(unit, data_reader)
255-
if batches is None:
264+
log.info(f"s2: Read batches step {c}.")
265+
if batches is None or c == _max_num_valid_steps:
256266
eod = True
257267
else:
258268
self._run_step(unit, batches, progress_task)
269+
log.info(f"s7: Done with step {c}.")
270+
c += 1
259271

260272
with self._compute_watch:
261273
with record_function("finalize"):

src/fairseq2/recipes/asr/_common.py

+13
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,15 @@
77
from __future__ import annotations
88

99
import math
10+
import re
1011
from typing import Any, Dict, final, TextIO
1112

1213
import torch
1314

1415
from fairseq2.data.text.tokenizers import TextTokenDecoder, TextTokenizer
1516
from fairseq2.gang import Gang
17+
18+
from fairseq2.logging import log
1619
from fairseq2.metrics import Mean
1720
from fairseq2.metrics.text import WerMetric
1821
from fairseq2.models.asr import AsrModel, AsrModelOutput
@@ -57,8 +60,10 @@ def __call__(
5760
)
5861
input_batch = batch
5962

63+
log.info(f"s3: calling forward")
6064
output = self._forward(input_batch)
6165

66+
log.info(f"s4: calling loss")
6267
loss, extra_metrics = output.compute_loss(
6368
batch.target_seqs, batch.target_padding_mask
6469
)
@@ -68,8 +73,11 @@ def __call__(
6873
metric_bag.update_batch_metrics(batch)
6974

7075
metric_bag.update_extra_metrics(batch, extra_metrics)
76+
77+
log.info(f"s5: calling scorer")
7178
if self._scorer is not None:
7279
self._scorer(batch, output, metric_bag)
80+
log.info(f"s6: done scorer")
7381

7482
return loss, batch.batch_size
7583

@@ -132,6 +140,11 @@ def __call__(
132140
refs = [self._text_decoder(s) for s in ref_seqs]
133141
hyps = [self._text_decoder(s) for s in hyp_seqs]
134142

143+
for r, h in zip(refs, hyps):
144+
if torch.rand([]) < 0.01 or bool(re.search(r"[\u0590-\u05FF]", r)):
145+
log.info(f"Reference: {r}")
146+
log.info(f"Hypothesis: {h}")
147+
135148
metric_bag.wer.update(
136149
refs, ref_seqs, ref_padding_mask, hyps, hyp_seqs, hyp_padding_mask
137150
)

src/fairseq2/recipes/wav2vec2/asr/_train.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ class Wav2Vec2AsrTrainConfig:
122122
validate_after_n_steps=10_000,
123123
validate_every_n_steps=1_000,
124124
publish_metrics_every_n_steps=200,
125+
keep_last_n_checkpoints=1,
125126
)
126127
)
127128

@@ -266,7 +267,7 @@ def load_wav2vec2_asr_trainer(
266267

267268
# If we start the training with an empty ASR model, use the weights of a
268269
# pretrained wav2vec 2.0 model.
269-
if model.is_empty_initialized:
270+
if model.is_empty_initialized and config.pretrained_model.name:
270271
pt_model = load_reference_model(
271272
Wav2Vec2Model,
272273
context,

0 commit comments

Comments
 (0)