From fba9ae050287f30d36075ec79e8a3806c655d2b8 Mon Sep 17 00:00:00 2001
From: yaozengwei
Date: Sun, 1 May 2022 23:20:00 +0800
Subject: [PATCH 01/11] First upload of model average codes.
---
.../pruned_transducer_stateless3/__init__.py | 1 +
.../asr_datamodule.py | 1 +
.../beam_search.py | 1 +
.../pruned_transducer_stateless3/conformer.py | 1 +
.../pruned_transducer_stateless3/decode.py | 597 ++++++++++
.../pruned_transducer_stateless3/decoder.py | 1 +
.../encoder_interface.py | 1 +
.../pruned_transducer_stateless3/export.py | 1 +
.../pruned_transducer_stateless3/joiner.py | 1 +
.../ASR/pruned_transducer_stateless3/model.py | 1 +
.../ASR/pruned_transducer_stateless3/optim.py | 1 +
.../pruned_transducer_stateless3/scaling.py | 1 +
.../ASR/pruned_transducer_stateless3/train.py | 1048 +++++++++++++++++
icefall/checkpoint.py | 118 ++
14 files changed, 1774 insertions(+)
create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless3/__init__.py
create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py
create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless3/beam_search.py
create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless3/conformer.py
create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless3/decoder.py
create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless3/encoder_interface.py
create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless3/export.py
create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless3/joiner.py
create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless3/model.py
create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless3/optim.py
create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless3/scaling.py
create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless3/train.py
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/__init__.py b/egs/librispeech/ASR/pruned_transducer_stateless3/__init__.py
new file mode 120000
index 0000000000..b24e5e3572
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/__init__.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/__init__.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py
new file mode 120000
index 0000000000..a074d60850
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless3/beam_search.py
new file mode 120000
index 0000000000..8554e44ccf
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/beam_search.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/beam_search.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless3/conformer.py
new file mode 120000
index 0000000000..3b84b95739
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/conformer.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/conformer.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
new file mode 100755
index 0000000000..016393215b
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
@@ -0,0 +1,597 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./pruned_transducer_stateless2/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless2/exp \
+ --max-duration 100 \
+ --decoding-method greedy_search
+
+(2) beam search
+./pruned_transducer_stateless2/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless2/exp \
+ --max-duration 100 \
+ --decoding-method beam_search \
+ --beam-size 4
+
+(3) modified beam search
+./pruned_transducer_stateless2/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless2/exp \
+ --max-duration 100 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+(4) fast beam search
+./pruned_transducer_stateless2/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless2/exp \
+ --max-duration 1500 \
+ --decoding-method fast_beam_search \
+ --beam 4 \
+ --max-contexts 4 \
+ --max-states 8
+"""
+
+
+import argparse
+import logging
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from beam_search import (
+ beam_search,
+ fast_beam_search,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+)
+from train import get_params, get_transducer_model
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+ load_checkpoint_with_averaged_model,
+)
+from icefall.utils import (
+ AttributeDict,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=28,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 0.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=False,
+ help="Whether to load averaged model",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless2/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - fast_beam_search
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=4,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=4,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; "
+ "2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+ Used only when --decoding_method is greedy_search""",
+ )
+
+ return parser
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+ if greedy_search is used, it would be "greedy_search"
+ If beam search with a beam size of 7 is used, it would be
+ "beam_7"
+ - value: It contains the decoding result. `len(value)` equals to
+ batch size. `value[i]` is the decoding result for the i-th
+ utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
+ only when --decoding_method is fast_beam_search.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict.
+ """
+ device = model.device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ encoder_out, encoder_out_lens = model.encoder(
+ x=feature, x_lens=feature_lens
+ )
+ hyps = []
+
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif (
+ params.decoding_method == "greedy_search"
+ and params.max_sym_per_frame == 1
+ ):
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ beam=params.beam_size,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ else:
+ batch_size = encoder_out.size(0)
+
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyps.append(sp.decode(hyp).split())
+
+ if params.decoding_method == "greedy_search":
+ return {"greedy_search": hyps}
+ elif params.decoding_method == "fast_beam_search":
+ return {
+ (
+ f"beam_{params.beam}_"
+ f"max_contexts_{params.max_contexts}_"
+ f"max_states_{params.max_states}"
+ ): hyps
+ }
+ else:
+ return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
+ only when --decoding_method is fast_beam_search.
+ Returns:
+ Return a dict, whose key may be "greedy_search" if greedy search
+ is used, or it may be "beam_7" if beam size of 7 is used.
+ Its value is a list of tuples. Each tuple contains two elements:
+ The first is the reference transcript, and the second is the
+ predicted result.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ if params.decoding_method == "greedy_search":
+ log_interval = 100
+ else:
+ log_interval = 2
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ sp=sp,
+ decoding_graph=decoding_graph,
+ batch=batch,
+ )
+
+ for name, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for hyp_words, ref_text in zip(hyps, texts):
+ ref_words = ref_text.split()
+ this_batch.append((ref_words, hyp_words))
+
+ results[name].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(
+ f"batch {batch_str}, cuts processed until now is {num_cuts}"
+ )
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = (
+ params.res_dir
+ / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ LibriSpeechAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ assert params.decoding_method in (
+ "greedy_search",
+ "beam_search",
+ "fast_beam_search",
+ "modified_beam_search",
+ )
+ params.res_dir = params.exp_dir / params.decoding_method
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ if "fast_beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+ elif "beam_search" in params.decoding_method:
+ params.suffix += (
+ f"-{params.decoding_method}-beam-size-{params.beam_size}"
+ )
+ else:
+ params.suffix += f"-context-{params.context_size}"
+ params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+ # and is defined in local/train_bpe_model.py
+ params.blank_id = sp.piece_to_id("")
+ params.unk_id = sp.piece_to_id("")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(
+ params.exp_dir, iteration=-params.iter
+ )[: params.avg]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if start >= 0:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ assert params.iter == 0
+ if params.avg == 1:
+ filename = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ load_checkpoint_with_averaged_model(filename, model)
+ else:
+ assert params.avg > 1
+ start = params.epoch - params.avg
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(f"averaging {filename_start} and {filename_end}")
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+ model.device = device
+
+ if params.decoding_method == "fast_beam_search":
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ else:
+ decoding_graph = None
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ librispeech = LibriSpeechAsrDataModule(args)
+
+ test_clean_cuts = librispeech.test_clean_cuts()
+ test_other_cuts = librispeech.test_other_cuts()
+
+ test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
+ test_other_dl = librispeech.test_dataloaders(test_other_cuts)
+
+ test_sets = ["test-clean", "test-other"]
+ test_dl = [test_clean_dl, test_other_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dl):
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ sp=sp,
+ decoding_graph=decoding_graph,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decoder.py
new file mode 120000
index 0000000000..0793c5709c
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decoder.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/decoder.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/encoder_interface.py b/egs/librispeech/ASR/pruned_transducer_stateless3/encoder_interface.py
new file mode 120000
index 0000000000..b9aa0ae083
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/encoder_interface.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/encoder_interface.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py
new file mode 120000
index 0000000000..19c56a722d
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/export.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless3/joiner.py
new file mode 120000
index 0000000000..815fd4bb6f
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/joiner.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/joiner.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py
new file mode 120000
index 0000000000..ebb6d774d9
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/model.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless3/optim.py
new file mode 120000
index 0000000000..e2deb44925
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/optim.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/optim.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling.py
new file mode 120000
index 0000000000..09d802cc44
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/scaling.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py
new file mode 100755
index 0000000000..83d31c64e5
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py
@@ -0,0 +1,1048 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang
+# Mingshuang Luo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless2/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 0 \
+ --exp-dir pruned_transducer_stateless2/exp \
+ --full-libri 1 \
+ --max-duration 300
+
+# For mix precision training:
+
+./pruned_transducer_stateless2/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 0 \
+ --use-fp16 1 \
+ --exp-dir pruned_transducer_stateless2/exp \
+ --full-libri 1 \
+ --max-duration 550
+
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from conformer import Conformer
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import Transducer
+from optim import Eden, Eve
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import save_checkpoint_with_global_batch_idx
+from icefall.checkpoint import update_averaged_model
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+
+LRSchedulerType = Union[
+ torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=30,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=0,
+ help="""Resume training from from this epoch.
+ If it is positive, it will load checkpoint from
+ transducer_stateless2/exp/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless2/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--initial-lr",
+ type=float,
+ default=0.003,
+ help="""The initial learning rate. This value should not need to be
+ changed.""",
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=5000,
+ help="""Number of steps that affects how rapidly the learning rate decreases.
+ We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=6,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; "
+ "2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+ help="The prune range for rnnt loss, it means how many symbols(context)"
+ "we are using to compute the loss",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="The scale to smooth the loss with lm "
+ "(output of prediction network) part.",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+ help="The scale to smooth the loss with am (output of encoder network)"
+ "part.",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+ help="To get pruning ranges, we will calculate a simple version"
+ "loss(joiner is just addition), this simple loss also uses for"
+ "training (as a regularization item). We will scale the simple loss"
+ "with this parameter before adding to the final loss.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=8000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 0.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=20,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=100,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+ - batch_idx_train: Used to writing statistics to tensorboard. It
+ contains number of batches trained so far across
+ epochs.
+
+ - log_interval: Print training loss if batch_idx % log_interval` is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - encoder_dim: Hidden dim for multi-head attention model.
+
+ - num_decoder_layers: Number of decoder layer of transformer decoder.
+
+ - warm_step: The warm_step for Noam optimizer.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 3000, # For the 100h subset, use 800
+ # parameters for conformer
+ "feature_dim": 80,
+ "subsampling_factor": 4,
+ "encoder_dim": 512,
+ "nhead": 8,
+ "dim_feedforward": 2048,
+ "num_encoder_layers": 12,
+ # parameters for decoder
+ "decoder_dim": 512,
+ # parameters for joiner
+ "joiner_dim": 512,
+ # parameters for Noam
+ "model_warm_step": 3000, # arg given to model, not for lrate
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ # TODO: We can add an option to switch between Conformer and Transformer
+ encoder = Conformer(
+ num_features=params.feature_dim,
+ subsampling_factor=params.subsampling_factor,
+ d_model=params.encoder_dim,
+ nhead=params.nhead,
+ dim_feedforward=params.dim_feedforward,
+ num_encoder_layers=params.num_encoder_layers,
+ )
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=params.encoder_dim,
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+ encoder = get_encoder_model(params)
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+
+ model = Transducer(
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ encoder_dim=params.encoder_dim,
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is positive, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 0:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ if "cur_batch_idx" in saved_params:
+ params["cur_batch_idx"] = saved_params["cur_batch_idx"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+ The scaler used for mix precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ is_training: bool,
+ warmup: float = 1.0,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute CTC loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Conformer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ warmup: a floating point value which increases throughout training;
+ values >= 1.0 are fully warmed up and have all modules present.
+ """
+ device = model.device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ texts = batch["supervisions"]["text"]
+ y = sp.encode(texts, out_type=int)
+ y = k2.RaggedTensor(y).to(device)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ warmup=warmup,
+ )
+ # after the main warmup step, we keep pruned_loss_scale small
+ # for the same amount of time (model_warm_step), to avoid
+ # overwhelming the simple_loss and causing it to diverge,
+ # in case it had not fully learned the alignment yet.
+ pruned_loss_scale = (
+ 0.0
+ if warmup < 1.0
+ else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+ )
+ loss = (
+ params.simple_loss_scale * simple_loss
+ + pruned_loss_scale * pruned_loss
+ )
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (
+ (feature_lens // params.subsampling_factor).sum().item()
+ )
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: nn.Module,
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ sp: spm.SentencePieceProcessor,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+ The scaler used for mix precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ cur_batch_idx = params.get("cur_batch_idx", 0)
+
+ for batch_idx, batch in enumerate(train_dl):
+ if batch_idx < cur_batch_idx:
+ continue
+ cur_batch_idx = batch_idx
+
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ warmup=(params.batch_idx_train / params.model_warm_step),
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ scaler.scale(loss).backward()
+ scheduler.step_batch(params.batch_idx_train)
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ params.cur_batch_idx = batch_idx
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ del params.cur_batch_idx
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = scheduler.get_last_lr()[0]
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}"
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(
+ tb_writer, "train/tot_", params.batch_idx_train
+ )
+
+ if batch_idx > 0 and batch_idx % params.valid_interval == 0:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+ if params.full_libri is False:
+ params.valid_interval = 1600
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+ # is defined in local/train_bpe_model.py
+ params.blank_id = sp.piece_to_id("")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: nn.Module = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model)
+
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank])
+ model.device = device
+
+ if rank == 0:
+ model_avg.to(device)
+ model_avg.device = device
+
+ optimizer = Eve(model.parameters(), lr=params.initial_lr)
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 2 ** 22
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ librispeech = LibriSpeechAsrDataModule(args)
+
+ train_cuts = librispeech.train_clean_100_cuts()
+ if params.full_libri:
+ train_cuts += librispeech.train_clean_360_cuts()
+ train_cuts += librispeech.train_other_500_cuts()
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 20 seconds
+ #
+ # Caution: There is a reason to select 20.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ return 1.0 <= c.duration <= 20.0
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = librispeech.train_dataloaders(
+ train_cuts, sampler_state_dict=sampler_state_dict
+ )
+
+ valid_cuts = librispeech.dev_clean_cuts()
+ valid_cuts += librispeech.dev_other_cuts()
+ valid_dl = librispeech.valid_dataloaders(valid_cuts)
+
+ if not params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ sp=sp,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs):
+ scheduler.step_epoch(epoch)
+ fix_random_seed(params.seed + epoch)
+ train_dl.sampler.set_epoch(epoch)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sp=sp,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def scan_pessimistic_batches_for_oom(
+ model: nn.Module,
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ sp: spm.SentencePieceProcessor,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 0 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ # warmup = 0.0 is so that the derivs for the pruned loss stay zero
+ # (i.e. are not remembered by the decaying-average in adam), because
+ # we want to avoid these params being subject to shrinkage in adam.
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ warmup=0.0,
+ )
+ loss.backward()
+ optimizer.step()
+ optimizer.zero_grad()
+ except RuntimeError as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ raise
+
+
+def main():
+ parser = get_parser()
+ LibriSpeechAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py
index a4e71a148f..af8c1701db 100644
--- a/icefall/checkpoint.py
+++ b/icefall/checkpoint.py
@@ -25,6 +25,7 @@
import torch
import torch.nn as nn
from lhotse.dataset.sampling.base import CutSampler
+from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
@@ -37,6 +38,7 @@
def save_checkpoint(
filename: Path,
model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
params: Optional[Dict[str, Any]] = None,
optimizer: Optional[Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
@@ -51,6 +53,8 @@ def save_checkpoint(
The checkpoint filename.
model:
The model to be saved. We only save its `state_dict()`.
+ model_avg:
+ The stored model averaged from the start of training.
params:
User defined parameters, e.g., epoch, loss.
optimizer:
@@ -80,6 +84,9 @@ def save_checkpoint(
"sampler": sampler.state_dict() if sampler is not None else None,
}
+ if model_avg is not None:
+ checkpoint["model_avg"] = model_avg.state_dict()
+
if params:
for k, v in params.items():
assert k not in checkpoint
@@ -91,6 +98,7 @@ def save_checkpoint(
def load_checkpoint(
filename: Path,
model: nn.Module,
+ model_avg: Optional[nn.Module] = None,
optimizer: Optional[Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
scaler: Optional[GradScaler] = None,
@@ -118,6 +126,10 @@ def load_checkpoint(
checkpoint.pop("model")
+ if model_avg is not None and "model_avg" in checkpoint:
+ model_avg.load_state_dict(checkpoint["model_avg"], strict=strict)
+ checkpoint.pop("model_avg")
+
def load(name, obj):
s = checkpoint.get(name, None)
if obj and s:
@@ -181,6 +193,7 @@ def save_checkpoint_with_global_batch_idx(
out_dir: Path,
global_batch_idx: int,
model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
params: Optional[Dict[str, Any]] = None,
optimizer: Optional[Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
@@ -201,6 +214,8 @@ def save_checkpoint_with_global_batch_idx(
model:
The neural network model whose `state_dict` will be saved in the
checkpoint.
+ model_avg:
+ The stored model averaged from the start of training.
params:
A dict of training configurations to be saved.
optimizer:
@@ -223,6 +238,7 @@ def save_checkpoint_with_global_batch_idx(
save_checkpoint(
filename=filename,
model=model,
+ model_avg=model_avg,
params=params,
optimizer=optimizer,
scheduler=scheduler,
@@ -327,3 +343,105 @@ def remove_checkpoints(
to_remove = checkpoints[topk:]
for c in to_remove:
os.remove(c)
+
+
+def update_averaged_model(
+ params: Dict[str, Tensor],
+ model_cur: Union[nn.Module, DDP],
+ model_avg: nn.Module,
+) -> None:
+ """Update the averaged model,
+
+ Args:
+ params:
+ User defined parameters, e.g., epoch, loss.
+ model_cur:
+ The current model.
+ model_avg:
+ The stored model averaged from start of training to update.
+ """
+ weight_cur = params.average_period / params.batch_idx_train
+ weight_avg = 1 - weight_cur
+
+ if isinstance(model_cur, DDP):
+ model_cur = model_cur.module
+
+ cur = model_cur.state_dict()
+ avg = model_avg.state_dict()
+
+ uniqued: Dict[int, str] = dict()
+ for k, v in avg.items():
+ v_data_ptr = v.data_ptr()
+ if v_data_ptr in uniqued:
+ continue
+ uniqued[v_data_ptr] = k
+
+ uniqued_names = list(uniqued.values())
+ for k in uniqued_names:
+ avg[k] *= weight_avg
+ avg[k] += cur[k] * weight_cur
+
+
+def average_checkpoints_with_averaged_model(
+ filename_start: str,
+ filename_end: str,
+ device: torch.device = torch.device("cpu"),
+) -> Dict[str, Tensor]:
+ """Average model parameters over the range with given
+ start model(excluded) and end model.
+
+ Let start = batch_idx_train of model-start,
+ end = batch_idx_train of model-end,
+ Then the average model over epoch [start+1, start+2, ..., end] is
+ avg = (model_end * end - model_start * start) / (start - end)
+
+ The model index could be epoch number or checkpoint number.
+
+ Args:
+ filename_start:
+ Checkpoint filename of the start model. We assume it
+ is saved by :func:`save_checkpoint`.
+ filename_end:
+ Checkpoint filename of the end model. We assume it
+ is saved by :func:`save_checkpoint`.
+ device:
+ Move checkpoints to this device before averaging.
+ """
+ state_dict_start = torch.load(filename_start, map_location=device)
+ state_dict_end = torch.load(filename_end, map_location=device)
+
+ batch_idx_train_start = state_dict_start["batch_idx_train"]
+ batch_idx_train_end = state_dict_end["batch_idx_train"]
+ interval = batch_idx_train_end - batch_idx_train_start
+ weight_start = -batch_idx_train_start / interval
+ weight_end = batch_idx_train_end / interval
+
+ avg = state_dict_end["model_avg"]
+ model_start = state_dict_start["model_avg"]
+
+ # Identify shared parameters. Two parameters are said to be shared
+ # if they have the same data_ptr
+ uniqued: Dict[int, str] = dict()
+ for k, v in avg.items():
+ v_data_ptr = v.data_ptr()
+ if v_data_ptr in uniqued:
+ continue
+ uniqued[v_data_ptr] = k
+
+ uniqued_names = list(uniqued.values())
+ for k in uniqued_names:
+ avg[k] *= weight_end
+ avg[k] += model_start[k] * weight_start
+
+ return avg
+
+
+def load_checkpoint_with_averaged_model(
+ filename: str,
+ model: nn.Module,
+ strict: bool = True,
+) -> None:
+ """Load checkpoint with aaveraged model."""
+ logging.info(f"Loading checkpoint from {filename}, using averaged model")
+ checkpoint = torch.load(filename, map_location="cpu")
+ model.load_state_dict(checkpoint["model_avg"], strict=strict)
From 08b37e07a4bf10d2098d1055acf4a3304e437b75 Mon Sep 17 00:00:00 2001
From: yaozengwei
Date: Mon, 2 May 2022 00:50:32 +0800
Subject: [PATCH 02/11] minor fix
---
.../ASR/pruned_transducer_stateless3/decode.py | 15 ++++++++-------
1 file changed, 8 insertions(+), 7 deletions(-)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
index 016393215b..34125e9d62 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
@@ -81,7 +81,6 @@
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
- load_checkpoint_with_averaged_model,
)
from icefall.utils import (
AttributeDict,
@@ -481,6 +480,9 @@ def main():
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
@@ -534,15 +536,14 @@ def main():
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
assert params.iter == 0
- if params.avg == 1:
- filename = f"{params.exp_dir}/epoch-{params.epoch}.pt"
- load_checkpoint_with_averaged_model(filename, model)
- else:
- assert params.avg > 1
+ if True:
start = params.epoch - params.avg
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
- logging.info(f"averaging {filename_start} and {filename_end}")
+ logging.info(
+ f"averaging modes over range with {filename_start} (excluded) "
+ f"and {filename_end}"
+ )
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
From aea8a03e009122f371f2c696c7baf228276ff076 Mon Sep 17 00:00:00 2001
From: yaozengwei
Date: Mon, 2 May 2022 12:17:43 +0800
Subject: [PATCH 03/11] update decode file
---
.../pruned_transducer_stateless3/decode.py | 29 +++++++++----------
icefall/checkpoint.py | 14 ++-------
2 files changed, 16 insertions(+), 27 deletions(-)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
index 34125e9d62..a6fe0336c5 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
@@ -536,22 +536,21 @@ def main():
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
assert params.iter == 0
- if True:
- start = params.epoch - params.avg
- filename_start = f"{params.exp_dir}/epoch-{start}.pt"
- filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
- logging.info(
- f"averaging modes over range with {filename_start} (excluded) "
- f"and {filename_end}"
- )
- model.to(device)
- model.load_state_dict(
- average_checkpoints_with_averaged_model(
- filename_start=filename_start,
- filename_end=filename_end,
- device=device,
- )
+ start = params.epoch - params.avg
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"averaging modes over range with {filename_start} (excluded) "
+ f"and {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
)
+ )
model.to(device)
model.eval()
diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py
index af8c1701db..c7b09c8ac6 100644
--- a/icefall/checkpoint.py
+++ b/icefall/checkpoint.py
@@ -416,8 +416,9 @@ def average_checkpoints_with_averaged_model(
weight_start = -batch_idx_train_start / interval
weight_end = batch_idx_train_end / interval
- avg = state_dict_end["model_avg"]
+ model_end = state_dict_end["model_avg"]
model_start = state_dict_start["model_avg"]
+ avg = model_end
# Identify shared parameters. Two parameters are said to be shared
# if they have the same data_ptr
@@ -434,14 +435,3 @@ def average_checkpoints_with_averaged_model(
avg[k] += model_start[k] * weight_start
return avg
-
-
-def load_checkpoint_with_averaged_model(
- filename: str,
- model: nn.Module,
- strict: bool = True,
-) -> None:
- """Load checkpoint with aaveraged model."""
- logging.info(f"Loading checkpoint from {filename}, using averaged model")
- checkpoint = torch.load(filename, map_location="cpu")
- model.load_state_dict(checkpoint["model_avg"], strict=strict)
From 36c241e59feebe82936f89645ecadadb66d3a9f0 Mon Sep 17 00:00:00 2001
From: yaozengwei
Date: Mon, 2 May 2022 12:22:24 +0800
Subject: [PATCH 04/11] update .flake8
---
.flake8 | 1 +
1 file changed, 1 insertion(+)
diff --git a/.flake8 b/.flake8
index cd55ded739..b1ed0a662f 100644
--- a/.flake8
+++ b/.flake8
@@ -9,6 +9,7 @@ per-file-ignores =
egs/tedlium3/ASR/*/conformer.py: E501,
egs/gigaspeech/ASR/*/conformer.py: E501,
egs/librispeech/ASR/pruned_transducer_stateless2/*.py: E501,
+ egs/librispeech/ASR/pruned_transducer_stateless3/*.py: E501,
# invalid escape sequence (cause by tex formular), W605
icefall/utils.py: E501, W605
From 44d75e22c98c4276a494e2b5e6329fc90c36e78e Mon Sep 17 00:00:00 2001
From: yaozengwei
Date: Wed, 4 May 2022 21:37:55 +0800
Subject: [PATCH 05/11] rename pruned_transducer_stateless3 to
pruned_transducer_stateless4
---
.../__init__.py | 0
.../asr_datamodule.py | 0
.../beam_search.py | 0
.../conformer.py | 0
.../decode.py | 0
.../decoder.py | 0
.../encoder_interface.py | 0
.../export.py | 0
.../joiner.py | 0
.../model.py | 0
.../optim.py | 0
.../scaling.py | 0
.../train.py | 6 ++++--
13 files changed, 4 insertions(+), 2 deletions(-)
rename egs/librispeech/ASR/{pruned_transducer_stateless3 => pruned_transducer_stateless4}/__init__.py (100%)
rename egs/librispeech/ASR/{pruned_transducer_stateless3 => pruned_transducer_stateless4}/asr_datamodule.py (100%)
rename egs/librispeech/ASR/{pruned_transducer_stateless3 => pruned_transducer_stateless4}/beam_search.py (100%)
rename egs/librispeech/ASR/{pruned_transducer_stateless3 => pruned_transducer_stateless4}/conformer.py (100%)
rename egs/librispeech/ASR/{pruned_transducer_stateless3 => pruned_transducer_stateless4}/decode.py (100%)
rename egs/librispeech/ASR/{pruned_transducer_stateless3 => pruned_transducer_stateless4}/decoder.py (100%)
rename egs/librispeech/ASR/{pruned_transducer_stateless3 => pruned_transducer_stateless4}/encoder_interface.py (100%)
rename egs/librispeech/ASR/{pruned_transducer_stateless3 => pruned_transducer_stateless4}/export.py (100%)
rename egs/librispeech/ASR/{pruned_transducer_stateless3 => pruned_transducer_stateless4}/joiner.py (100%)
rename egs/librispeech/ASR/{pruned_transducer_stateless3 => pruned_transducer_stateless4}/model.py (100%)
rename egs/librispeech/ASR/{pruned_transducer_stateless3 => pruned_transducer_stateless4}/optim.py (100%)
rename egs/librispeech/ASR/{pruned_transducer_stateless3 => pruned_transducer_stateless4}/scaling.py (100%)
rename egs/librispeech/ASR/{pruned_transducer_stateless3 => pruned_transducer_stateless4}/train.py (99%)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/__init__.py b/egs/librispeech/ASR/pruned_transducer_stateless4/__init__.py
similarity index 100%
rename from egs/librispeech/ASR/pruned_transducer_stateless3/__init__.py
rename to egs/librispeech/ASR/pruned_transducer_stateless4/__init__.py
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless4/asr_datamodule.py
similarity index 100%
rename from egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py
rename to egs/librispeech/ASR/pruned_transducer_stateless4/asr_datamodule.py
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless4/beam_search.py
similarity index 100%
rename from egs/librispeech/ASR/pruned_transducer_stateless3/beam_search.py
rename to egs/librispeech/ASR/pruned_transducer_stateless4/beam_search.py
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless4/conformer.py
similarity index 100%
rename from egs/librispeech/ASR/pruned_transducer_stateless3/conformer.py
rename to egs/librispeech/ASR/pruned_transducer_stateless4/conformer.py
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
similarity index 100%
rename from egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
rename to egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decoder.py
similarity index 100%
rename from egs/librispeech/ASR/pruned_transducer_stateless3/decoder.py
rename to egs/librispeech/ASR/pruned_transducer_stateless4/decoder.py
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/encoder_interface.py b/egs/librispeech/ASR/pruned_transducer_stateless4/encoder_interface.py
similarity index 100%
rename from egs/librispeech/ASR/pruned_transducer_stateless3/encoder_interface.py
rename to egs/librispeech/ASR/pruned_transducer_stateless4/encoder_interface.py
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py
similarity index 100%
rename from egs/librispeech/ASR/pruned_transducer_stateless3/export.py
rename to egs/librispeech/ASR/pruned_transducer_stateless4/export.py
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless4/joiner.py
similarity index 100%
rename from egs/librispeech/ASR/pruned_transducer_stateless3/joiner.py
rename to egs/librispeech/ASR/pruned_transducer_stateless4/joiner.py
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py b/egs/librispeech/ASR/pruned_transducer_stateless4/model.py
similarity index 100%
rename from egs/librispeech/ASR/pruned_transducer_stateless3/model.py
rename to egs/librispeech/ASR/pruned_transducer_stateless4/model.py
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py
similarity index 100%
rename from egs/librispeech/ASR/pruned_transducer_stateless3/optim.py
rename to egs/librispeech/ASR/pruned_transducer_stateless4/optim.py
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless4/scaling.py
similarity index 100%
rename from egs/librispeech/ASR/pruned_transducer_stateless3/scaling.py
rename to egs/librispeech/ASR/pruned_transducer_stateless4/scaling.py
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
similarity index 99%
rename from egs/librispeech/ASR/pruned_transducer_stateless3/train.py
rename to egs/librispeech/ASR/pruned_transducer_stateless4/train.py
index 83d31c64e5..33478c6301 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
@@ -74,8 +74,10 @@
from icefall import diagnostics
from icefall.checkpoint import load_checkpoint, remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
-from icefall.checkpoint import save_checkpoint_with_global_batch_idx
-from icefall.checkpoint import update_averaged_model
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
From ff3c0d5d86dc0a63544756c461cdbd8ad497873a Mon Sep 17 00:00:00 2001
From: yaozengwei
Date: Thu, 5 May 2022 11:46:49 +0800
Subject: [PATCH 06/11] change epoch number counter starting from 1 instead of
0
---
.../ASR/pruned_transducer_stateless4/train.py | 28 ++++++++++---------
1 file changed, 15 insertions(+), 13 deletions(-)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
index 33478c6301..a08ebad545 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
@@ -21,20 +21,20 @@
export CUDA_VISIBLE_DEVICES="0,1,2,3"
-./pruned_transducer_stateless2/train.py \
+./pruned_transducer_stateless4/train.py \
--world-size 4 \
--num-epochs 30 \
- --start-epoch 0 \
+ --start-epoch 1 \
--exp-dir pruned_transducer_stateless2/exp \
--full-libri 1 \
--max-duration 300
# For mix precision training:
-./pruned_transducer_stateless2/train.py \
+./pruned_transducer_stateless4/train.py \
--world-size 4 \
--num-epochs 30 \
- --start-epoch 0 \
+ --start-epoch 1 \
--use-fp16 1 \
--exp-dir pruned_transducer_stateless2/exp \
--full-libri 1 \
@@ -123,7 +123,7 @@ def get_parser():
parser.add_argument(
"--start-epoch",
type=int,
- default=0,
+ default=1,
help="""Resume training from from this epoch.
If it is positive, it will load checkpoint from
transducer_stateless2/exp/epoch-{start_epoch-1}.pt
@@ -418,7 +418,7 @@ def load_checkpoint_if_available(
If params.start_batch is positive, it will load the checkpoint from
`params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
- params.start_epoch is positive, it will load the checkpoint from
+ params.start_epoch is larger than 1, it will load the checkpoint from
`params.start_epoch - 1`.
Apart from loading state dict for `model` and `optimizer` it also updates
@@ -430,6 +430,8 @@ def load_checkpoint_if_available(
The return value of :func:`get_params`.
model:
The training model.
+ model_avg:
+ The stored model averaged from the start of training.
optimizer:
The optimizer that we are using.
scheduler:
@@ -439,7 +441,7 @@ def load_checkpoint_if_available(
"""
if params.start_batch > 0:
filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
- elif params.start_epoch > 0:
+ elif params.start_epoch > 1:
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
else:
return None
@@ -849,7 +851,7 @@ def run(rank, world_size, args):
logging.info(f"Number of model parameters: {num_param}")
assert params.save_every_n >= params.average_period
- model_avg: nn.Module = None
+ model_avg: Optional[nn.Module] = None
if rank == 0:
# model_avg is only used with rank 0
model_avg = copy.deepcopy(model)
@@ -939,10 +941,10 @@ def remove_short_and_long_utt(c: Cut):
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
- for epoch in range(params.start_epoch, params.num_epochs):
- scheduler.step_epoch(epoch)
- fix_random_seed(params.seed + epoch)
- train_dl.sampler.set_epoch(epoch)
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
if tb_writer is not None:
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
@@ -996,7 +998,7 @@ def scan_pessimistic_batches_for_oom(
from lhotse.dataset import find_pessimistic_batches
logging.info(
- "Sanity check -- see if any of the batches in epoch 0 would cause OOM."
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
)
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
for criterion, cuts in batches.items():
From a0592e0d0fee7241b731f9c21d2641936a417662 Mon Sep 17 00:00:00 2001
From: yaozengwei
Date: Thu, 5 May 2022 19:10:06 +0800
Subject: [PATCH 07/11] minor fix of pruned_transducer_stateless4/train.py
---
egs/librispeech/ASR/pruned_transducer_stateless4/train.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
index a08ebad545..568f41cef6 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
@@ -126,7 +126,7 @@ def get_parser():
default=1,
help="""Resume training from from this epoch.
If it is positive, it will load checkpoint from
- transducer_stateless2/exp/epoch-{start_epoch-1}.pt
+ exp-dir/epoch-{start_epoch-1}.pt
""",
)
From 8bf2fef1e0bf0ad7f9b275dd7769903ce1fd8254 Mon Sep 17 00:00:00 2001
From: yaozengwei
Date: Thu, 5 May 2022 19:10:51 +0800
Subject: [PATCH 08/11] refactor the checkpoint.py
---
icefall/checkpoint.py | 69 ++++++++++++++++++++++++++++---------------
1 file changed, 45 insertions(+), 24 deletions(-)
diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py
index c7b09c8ac6..3ad346a1cd 100644
--- a/icefall/checkpoint.py
+++ b/icefall/checkpoint.py
@@ -127,6 +127,7 @@ def load_checkpoint(
checkpoint.pop("model")
if model_avg is not None and "model_avg" in checkpoint:
+ logging.info("Loading averaged model")
model_avg.load_state_dict(checkpoint["model_avg"], strict=strict)
checkpoint.pop("model_avg")
@@ -350,7 +351,9 @@ def update_averaged_model(
model_cur: Union[nn.Module, DDP],
model_avg: nn.Module,
) -> None:
- """Update the averaged model,
+ """Update the averaged model:
+ model_avg = model_cur * (average_period / batch_idx_train)
+ + model_avg * ((batch_idx_train - average_period) / batch_idx_train)
Args:
params:
@@ -358,7 +361,7 @@ def update_averaged_model(
model_cur:
The current model.
model_avg:
- The stored model averaged from start of training to update.
+ The averaged model to be updated.
"""
weight_cur = params.average_period / params.batch_idx_train
weight_avg = 1 - weight_cur
@@ -369,17 +372,12 @@ def update_averaged_model(
cur = model_cur.state_dict()
avg = model_avg.state_dict()
- uniqued: Dict[int, str] = dict()
- for k, v in avg.items():
- v_data_ptr = v.data_ptr()
- if v_data_ptr in uniqued:
- continue
- uniqued[v_data_ptr] = k
-
- uniqued_names = list(uniqued.values())
- for k in uniqued_names:
- avg[k] *= weight_avg
- avg[k] += cur[k] * weight_cur
+ average_state_dict(
+ state_dict_1=avg,
+ state_dict_2=cur,
+ weight_1=weight_avg,
+ weight_2=weight_cur,
+ )
def average_checkpoints_with_averaged_model(
@@ -388,12 +386,12 @@ def average_checkpoints_with_averaged_model(
device: torch.device = torch.device("cpu"),
) -> Dict[str, Tensor]:
"""Average model parameters over the range with given
- start model(excluded) and end model.
+ start model (excluded) and end model.
- Let start = batch_idx_train of model-start,
- end = batch_idx_train of model-end,
- Then the average model over epoch [start+1, start+2, ..., end] is
- avg = (model_end * end - model_start * start) / (start - end)
+ Let start = batch_idx_train of model-start;
+ end = batch_idx_train of model-end.
+ Then the average model over range from start (excluded) to end is
+ avg = (model_end * end - model_start * start) / (start - end).
The model index could be epoch number or checkpoint number.
@@ -413,17 +411,41 @@ def average_checkpoints_with_averaged_model(
batch_idx_train_start = state_dict_start["batch_idx_train"]
batch_idx_train_end = state_dict_end["batch_idx_train"]
interval = batch_idx_train_end - batch_idx_train_start
- weight_start = -batch_idx_train_start / interval
weight_end = batch_idx_train_end / interval
+ weight_start = 1 - weight_end
model_end = state_dict_end["model_avg"]
model_start = state_dict_start["model_avg"]
avg = model_end
+ # scale the weight to avoid overflow
+ average_state_dict(
+ state_dict_1=avg,
+ state_dict_2=model_start,
+ weight_1=1.0,
+ weight_2=weight_start / weight_end,
+ scaling_factor=weight_end,
+ )
+
+ return avg
+
+
+def average_state_dict(
+ state_dict_1: Dict[str, Tensor],
+ state_dict_2: Dict[str, Tensor],
+ weight_1: float,
+ weight_2: float,
+ scaling_factor: float = 1.0,
+) -> Dict[str, Tensor]:
+ """Average two state_dict with given weights:
+ state_dict_1 = (state_dict_1 * weight_1 + state_dict_2 + weight_2)
+ * scaling_factor
+ It is an in-place operation on state_dict_1 itself.
+ """
# Identify shared parameters. Two parameters are said to be shared
# if they have the same data_ptr
uniqued: Dict[int, str] = dict()
- for k, v in avg.items():
+ for k, v in state_dict_1.items():
v_data_ptr = v.data_ptr()
if v_data_ptr in uniqued:
continue
@@ -431,7 +453,6 @@ def average_checkpoints_with_averaged_model(
uniqued_names = list(uniqued.values())
for k in uniqued_names:
- avg[k] *= weight_end
- avg[k] += model_start[k] * weight_start
-
- return avg
+ state_dict_1[k] *= weight_1
+ state_dict_1[k] += state_dict_2[k] * weight_2
+ state_dict_1[k] *= scaling_factor
From 22ecc567cbfc0d71ad3829246f7afd816571a4ba Mon Sep 17 00:00:00 2001
From: yaozengwei
Date: Thu, 5 May 2022 19:13:45 +0800
Subject: [PATCH 09/11] minor fix, update docs, and modify the epoch number to
count from 1 in the pruned_transducer_stateless4/decode.py
---
.../pruned_transducer_stateless4/decode.py | 37 +++++++++++--------
1 file changed, 21 insertions(+), 16 deletions(-)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
index a6fe0336c5..e868878e67 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
@@ -18,16 +18,16 @@
"""
Usage:
(1) greedy search
-./pruned_transducer_stateless2/decode.py \
- --epoch 28 \
+./pruned_transducer_stateless4/decode.py \
+ --epoch 30 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 100 \
--decoding-method greedy_search
(2) beam search
-./pruned_transducer_stateless2/decode.py \
- --epoch 28 \
+./pruned_transducer_stateless4/decode.py \
+ --epoch 30 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 100 \
@@ -35,8 +35,8 @@
--beam-size 4
(3) modified beam search
-./pruned_transducer_stateless2/decode.py \
- --epoch 28 \
+./pruned_transducer_stateless4/decode.py \
+ --epoch 30 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 100 \
@@ -44,8 +44,8 @@
--beam-size 4
(4) fast beam search
-./pruned_transducer_stateless2/decode.py \
- --epoch 28 \
+./pruned_transducer_stateless4/decode.py \
+ --epoch 30 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 1500 \
@@ -99,9 +99,9 @@ def get_parser():
parser.add_argument(
"--epoch",
type=int,
- default=28,
+ default=30,
help="""It specifies the checkpoint to use for decoding.
- Note: Epoch counts from 0.
+ Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
@@ -128,13 +128,17 @@ def get_parser():
"--use-averaged-model",
type=str2bool,
default=False,
- help="Whether to load averaged model",
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
- default="pruned_transducer_stateless2/exp",
+ default="pruned_transducer_stateless4/exp",
help="The experiment dir",
)
@@ -529,19 +533,20 @@ def main():
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
- if start >= 0:
+ if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
- assert params.iter == 0
+ assert params.iter == 0 and params.avg > 0
start = params.epoch - params.avg
+ assert start >= 1
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
- f"averaging modes over range with {filename_start} (excluded) "
- f"and {filename_end}"
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
From 5c07402af8f1be6448ab5116a186556ca01cbb17 Mon Sep 17 00:00:00 2001
From: yaozengwei
Date: Thu, 5 May 2022 19:44:19 +0800
Subject: [PATCH 10/11] update author info
---
egs/librispeech/ASR/pruned_transducer_stateless4/decode.py | 3 ++-
egs/librispeech/ASR/pruned_transducer_stateless4/train.py | 5 +++--
icefall/checkpoint.py | 3 ++-
3 files changed, 7 insertions(+), 4 deletions(-)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
index e868878e67..e066629052 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python3
#
-# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
index 568f41cef6..147bcf658f 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
@@ -1,7 +1,8 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
-# Wei Kang
-# Mingshuang Luo)
+# Wei Kang,
+# Mingshuang Luo,)
+# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py
index 3ad346a1cd..77c47fc940 100644
--- a/icefall/checkpoint.py
+++ b/icefall/checkpoint.py
@@ -1,4 +1,5 @@
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
+# Zengwei Yao)
#
# See ../../LICENSE for clarification regarding multiple authors
#
From 4f18d52c8ceea7432c9c6dcfe9a72c9e2387aba5 Mon Sep 17 00:00:00 2001
From: yaozengwei
Date: Thu, 5 May 2022 21:16:32 +0800
Subject: [PATCH 11/11] add docs of the scaling in function
average_checkpoints_with_averaged_model
---
icefall/checkpoint.py | 18 +++++++++++++++---
1 file changed, 15 insertions(+), 3 deletions(-)
diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py
index 77c47fc940..5b562ccc87 100644
--- a/icefall/checkpoint.py
+++ b/icefall/checkpoint.py
@@ -390,9 +390,20 @@ def average_checkpoints_with_averaged_model(
start model (excluded) and end model.
Let start = batch_idx_train of model-start;
- end = batch_idx_train of model-end.
+ end = batch_idx_train of model-end;
+ interval = end - start.
Then the average model over range from start (excluded) to end is
- avg = (model_end * end - model_start * start) / (start - end).
+ (1) avg = (model_end * end - model_start * start) / interval.
+ It can be written as
+ (2) avg = model_end * weight_end + model_start * weight_start,
+ where weight_end = end / interval,
+ weight_start = -start / interval = 1 - weight_end.
+ Since the terms `weight_end` and `weight_start` would be large
+ if the model has been trained for lots of batches, which would cause
+ overflow when multiplying the model parameters.
+ To avoid this, we rewrite (2) as:
+ (3) avg = (model_end + model_start * (weight_start / weight_end))
+ * weight_end
The model index could be epoch number or checkpoint number.
@@ -412,6 +423,7 @@ def average_checkpoints_with_averaged_model(
batch_idx_train_start = state_dict_start["batch_idx_train"]
batch_idx_train_end = state_dict_end["batch_idx_train"]
interval = batch_idx_train_end - batch_idx_train_start
+ assert interval > 0, interval
weight_end = batch_idx_train_end / interval
weight_start = 1 - weight_end
@@ -439,7 +451,7 @@ def average_state_dict(
scaling_factor: float = 1.0,
) -> Dict[str, Tensor]:
"""Average two state_dict with given weights:
- state_dict_1 = (state_dict_1 * weight_1 + state_dict_2 + weight_2)
+ state_dict_1 = (state_dict_1 * weight_1 + state_dict_2 * weight_2)
* scaling_factor
It is an in-place operation on state_dict_1 itself.
"""