
Commit 8635fb4

Fix decoding for gigaspeech in the libri + giga setup. (#345)
1 parent e1c3e98 commit 8635fb4

File tree

1 file changed: +104, -27 lines

egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py

Lines changed: 104 additions & 27 deletions
@@ -69,7 +69,8 @@
 from asr_datamodule import AsrDataModule
 from beam_search import (
     beam_search,
-    fast_beam_search,
+    fast_beam_search_nbest_oracle,
+    fast_beam_search_one_best,
     greedy_search,
     greedy_search_batch,
     modified_beam_search,
@@ -100,27 +101,28 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 0.
+        You can specify --avg to use more checkpoints for model averaging.""",
     )
+
     parser.add_argument(
-        "--avg",
+        "--iter",
         type=int,
-        default=15,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
     )
 
     parser.add_argument(
-        "--avg-last-n",
+        "--avg",
         type=int,
-        default=0,
-        help="""If positive, --epoch and --avg are ignored and it
-        will use the last n checkpoints exp_dir/checkpoint-xxx.pt
-        where xxx is the number of processed batches while
-        saving that checkpoint.
-        """,
+        default=15,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
     )
 
     parser.add_argument(
@@ -146,6 +148,7 @@ def get_parser():
           - beam_search
           - modified_beam_search
           - fast_beam_search
+          - fast_beam_search_nbest_oracle
         """,
     )
 
@@ -165,23 +168,24 @@ def get_parser():
         help="""A floating point value to calculate the cutoff score during beam
         search (i.e., `cutoff = max-score - beam`), which is the same as the
         `beam` in Kaldi.
-        Used only when --decoding-method is fast_beam_search""",
+        Used only when --decoding-method is
+        fast_beam_search or fast_beam_search_nbest_oracle""",
     )
 
     parser.add_argument(
         "--max-contexts",
         type=int,
         default=4,
         help="""Used only when --decoding-method is
-        fast_beam_search""",
+        fast_beam_search or fast_beam_search_nbest_oracle""",
     )
 
     parser.add_argument(
         "--max-states",
         type=int,
         default=8,
         help="""Used only when --decoding-method is
-        fast_beam_search""",
+        fast_beam_search or fast_beam_search_nbest_oracle""",
     )
 
     parser.add_argument(
@@ -199,6 +203,23 @@ def get_parser():
         Used only when --decoding_method is greedy_search""",
     )
 
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=100,
+        help="""Number of paths for computed nbest oracle WER
+        when the decoding method is fast_beam_search_nbest_oracle.
+        """,
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""Scale applied to lattice scores when computing nbest paths.
+        Used only when the decoding_method is fast_beam_search_nbest_oracle.
+        """,
+    )
     return parser


@@ -243,7 +264,8 @@ def decode_one_batch(
         for the format of the `batch`.
       decoding_graph:
         The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
-        only when --decoding_method is fast_beam_search.
+        only when --decoding_method is
+        fast_beam_search or fast_beam_search_nbest_oracle.
     Returns:
       Return the decoding result. See above description for the format of
      the returned dict.
@@ -264,7 +286,7 @@ def decode_one_batch(
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
-        hyp_tokens = fast_beam_search(
+        hyp_tokens = fast_beam_search_one_best(
             model=model,
             decoding_graph=decoding_graph,
             encoder_out=encoder_out,
@@ -275,6 +297,21 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
+    elif params.decoding_method == "fast_beam_search_nbest_oracle":
+        hyp_tokens = fast_beam_search_nbest_oracle(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            ref_texts=sp.encode(supervisions["text"]),
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
     elif (
         params.decoding_method == "greedy_search"
         and params.max_sym_per_frame == 1
@@ -328,6 +365,16 @@ def decode_one_batch(
                 f"max_states_{params.max_states}"
             ): hyps
         }
+    elif params.decoding_method == "fast_beam_search_nbest_oracle":
+        return {
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}_"
+                f"num_paths_{params.num_paths}_"
+                f"nbest_scale_{params.nbest_scale}"
+            ): hyps
+        }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
 
@@ -463,17 +510,30 @@ def main():
         "greedy_search",
         "beam_search",
         "fast_beam_search",
+        "fast_beam_search_nbest_oracle",
         "modified_beam_search",
     )
     params.res_dir = params.exp_dir / "giga" / params.decoding_method
 
-    params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
-    if "fast_beam_search" in params.decoding_method:
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+    if params.decoding_method == "fast_beam_search":
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+    elif params.decoding_method == "fast_beam_search_nbest_oracle":
         params.suffix += f"-beam-{params.beam}"
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
+        params.suffix += f"-num-paths-{params.num_paths}"
+        params.suffix += f"-nbest-scale-{params.nbest_scale}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-beam-{params.beam_size}"
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -490,17 +550,30 @@ def main():
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)
 
-    # <blk> is defined in local/train_bpe_model.py
+    # <blk> and <unk> is defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.unk_id()
     params.vocab_size = sp.get_piece_size()
 
     logging.info(params)
 
     logging.info("About to create model")
     model = get_transducer_model(params)
 
-    if params.avg_last_n > 0:
-        filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
+    if params.iter > 0:
+        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+            : params.avg
+        ]
+        if len(filenames) == 0:
+            raise ValueError(
+                f"No checkpoints found for"
+                f" --iter {params.iter}, --avg {params.avg}"
+            )
+        elif len(filenames) < params.avg:
+            raise ValueError(
+                f"Not enough checkpoints ({len(filenames)}) found for"
+                f" --iter {params.iter}, --avg {params.avg}"
+            )
         logging.info(f"averaging {filenames}")
         model.to(device)
         model.load_state_dict(average_checkpoints(filenames, device=device))
@@ -519,13 +592,17 @@ def main():
     model.to(device)
     model.eval()
     model.device = device
+    model.unk_id = params.unk_id
 
     # In beam_search.py, we are using model.decoder() and model.joiner(),
    # so we have to switch to the branch for the GigaSpeech dataset.
     model.decoder = model.decoder_giga
     model.joiner = model.joiner_giga
 
-    if params.decoding_method == "fast_beam_search":
+    if params.decoding_method in (
+        "fast_beam_search",
+        "fast_beam_search_nbest_oracle",
+    ):
         decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
     else:
         decoding_graph = None
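
Note on the --iter / --avg change above: when --iter is positive, the script now averages the --avg most recent checkpoint-xxx.pt files saved at or before that iteration, and raises an error if none (or fewer than --avg) are found. The sketch below illustrates just that selection step; select_checkpoints is a hypothetical stand-in written from the help text and error messages in this diff, not icefall's actual find_checkpoints.

from pathlib import Path
import re
from typing import List


def select_checkpoints(exp_dir: Path, iteration: int, avg: int) -> List[str]:
    """Hypothetical sketch: pick the `avg` newest checkpoint-xxx.pt files
    saved at or before `iteration` (newest first), mimicking how the script
    slices find_checkpoints(exp_dir, iteration=-params.iter)[: params.avg].
    """
    pattern = re.compile(r"checkpoint-(\d+)\.pt$")
    candidates = []
    for path in exp_dir.glob("checkpoint-*.pt"):
        m = pattern.search(path.name)
        if m is not None and int(m.group(1)) <= iteration:
            candidates.append((int(m.group(1)), str(path)))

    # Newest (largest batch count) first, then keep at most `avg` of them.
    candidates.sort(key=lambda t: t[0], reverse=True)
    filenames = [name for _, name in candidates[:avg]]

    if len(filenames) == 0:
        raise ValueError(
            f"No checkpoints found for --iter {iteration}, --avg {avg}"
        )
    elif len(filenames) < avg:
        raise ValueError(
            f"Not enough checkpoints ({len(filenames)}) found for"
            f" --iter {iteration}, --avg {avg}"
        )
    return filenames

In other words, --epoch/--avg and --iter/--avg are two alternative ways of naming the group of checkpoints that gets averaged before decoding; the chosen pair also ends up in params.suffix so the result files record it.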

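The second addition, fast_beam_search_nbest_oracle, is an oracle measurement rather than a regular decoding method: judging by the new arguments, it draws --num-paths paths from the fast_beam_search lattice (with lattice scores scaled by --nbest-scale) and uses ref_texts to keep, for each utterance, the sampled hypothesis closest to the reference, so the reported WER is a lower bound on what an ideal rescorer of those paths could reach. The sketch below shows only the oracle-selection idea on plain token lists; it assumes the n-best hypotheses are already extracted and is not the k2/lattice-based implementation in beam_search.py.

from typing import List, Tuple


def edit_distance(ref: List[str], hyp: List[str]) -> int:
    """Levenshtein distance between two token sequences (rolling 1-D DP)."""
    n = len(hyp)
    dp = list(range(n + 1))
    for i in range(1, len(ref) + 1):
        prev, dp[0] = dp[0], i
        for j in range(1, n + 1):
            cur = dp[j]
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            dp[j] = min(
                dp[j] + 1,      # deletion
                dp[j - 1] + 1,  # insertion
                prev + cost,    # substitution / match
            )
            prev = cur
    return dp[n]


def oracle_hypothesis(
    nbest: List[List[str]], ref: List[str]
) -> Tuple[List[str], int]:
    """Return the n-best entry with the fewest errors against the reference.

    Accumulating these per-utterance error counts over a test set gives the
    oracle WER: the best WER reachable if one always picked the right path.
    """
    best = min(nbest, key=lambda hyp: edit_distance(ref, hyp))
    return best, edit_distance(ref, best)


if __name__ == "__main__":
    ref = "the cat sat on the mat".split()
    nbest = [
        "the cat sat on a mat".split(),
        "the cat sat on the mat".split(),
        "a cat sat on the mat".split(),
    ]
    hyp, errs = oracle_hypothesis(nbest, ref)
    print(hyp, errs)  # the exact match wins with 0 errors

That is also why the return key and params.suffix now include num_paths and nbest_scale: they change the oracle number being reported, so they are recorded alongside the results.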