 from asr_datamodule import AsrDataModule
 from beam_search import (
     beam_search,
-    fast_beam_search,
+    fast_beam_search_nbest_oracle,
+    fast_beam_search_one_best,
     greedy_search,
     greedy_search_batch,
     modified_beam_search,
@@ -100,27 +101,28 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 0.
+        You can specify --avg to use more checkpoints for model averaging.""",
     )
+
     parser.add_argument(
-        "--avg",
+        "--iter",
         type=int,
-        default=15,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
     )
 
     parser.add_argument(
-        "--avg-last-n",
+        "--avg",
         type=int,
-        default=0,
-        help="""If positive, --epoch and --avg are ignored and it
-        will use the last n checkpoints exp_dir/checkpoint-xxx.pt
-        where xxx is the number of processed batches while
-        saving that checkpoint.
-        """,
+        default=15,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
     )
 
     parser.add_argument(
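Reviewer note: the interplay of --epoch, --iter, and --avg can be summarized with a small sketch. The helper below is hypothetical (the real resolution happens in main() via find_checkpoints/average_checkpoints), but it shows which files each flag combination selects:

import typing

def selected_checkpoints(epoch: int, iter: int, avg: int) -> typing.List[str]:
    """Sketch of which files the three flags above select for averaging."""
    if iter > 0:
        # --epoch is ignored: the `avg` checkpoint-<batch>.pt files saved
        # at or before batch `iter` are averaged.
        return [f"checkpoint-{iter}.pt (and the {avg - 1} preceding ones)"]
    # Epoch mode: `avg` consecutive epoch checkpoints ending at --epoch.
    return [f"epoch-{e}.pt" for e in range(epoch - avg + 1, epoch + 1)]

print(selected_checkpoints(epoch=28, iter=0, avg=15))  # epoch-14.pt .. epoch-28.pt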
@@ -146,6 +148,7 @@ def get_parser():
           - beam_search
           - modified_beam_search
           - fast_beam_search
+          - fast_beam_search_nbest_oracle
         """,
     )
151
154
@@ -165,23 +168,24 @@ def get_parser():
         help="""A floating point value to calculate the cutoff score during beam
         search (i.e., `cutoff = max-score - beam`), which is the same as the
         `beam` in Kaldi.
-        Used only when --decoding-method is fast_beam_search""",
+        Used only when --decoding-method is
+        fast_beam_search or fast_beam_search_nbest_oracle""",
     )
 
     parser.add_argument(
         "--max-contexts",
         type=int,
         default=4,
         help="""Used only when --decoding-method is
-        fast_beam_search""",
+        fast_beam_search or fast_beam_search_nbest_oracle""",
     )
 
     parser.add_argument(
         "--max-states",
         type=int,
         default=8,
         help="""Used only when --decoding-method is
-        fast_beam_search""",
+        fast_beam_search or fast_beam_search_nbest_oracle""",
     )
 
     parser.add_argument(
@@ -199,6 +203,23 @@ def get_parser():
         Used only when --decoding_method is greedy_search""",
     )
 
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=100,
+        help="""Number of paths used to compute the nbest oracle WER
+        when the decoding method is fast_beam_search_nbest_oracle.
+        """,
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""Scale applied to lattice scores when computing nbest paths.
+        Used only when the decoding_method is fast_beam_search_nbest_oracle.
+        """,
+    )
+
     return parser
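A rough intuition for --nbest-scale, sketched outside k2 (path_scores below is a made-up stand-in; the real code scales lattice scores before sampling n-best paths): scaling log-scores by a factor below 1 flattens the path distribution, so sampling yields a more diverse n-best list.

import math

def path_probs(path_scores, nbest_scale):
    # Scale log-scores, then softmax-normalize. A smaller nbest_scale
    # flattens the distribution, so random path sampling is more diverse.
    scaled = [s * nbest_scale for s in path_scores]
    z = max(scaled)
    exps = [math.exp(s - z) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

print(path_probs([10.0, 8.0, 2.0], nbest_scale=1.0))  # heavily peaked
print(path_probs([10.0, 8.0, 2.0], nbest_scale=0.5))  # noticeably flatter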
@@ -243,7 +264,8 @@ def decode_one_batch(
         for the format of the `batch`.
       decoding_graph:
-        The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
-        only when --decoding_method is fast_beam_search.
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+        only when --decoding_method is
+        fast_beam_search or fast_beam_search_nbest_oracle.
     Returns:
       Return the decoding result. See above description for the format of
       the returned dict.
@@ -264,7 +286,7 @@ def decode_one_batch(
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
-        hyp_tokens = fast_beam_search(
+        hyp_tokens = fast_beam_search_one_best(
             model=model,
             decoding_graph=decoding_graph,
             encoder_out=encoder_out,
@@ -275,6 +297,21 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
+    elif params.decoding_method == "fast_beam_search_nbest_oracle":
+        hyp_tokens = fast_beam_search_nbest_oracle(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            ref_texts=sp.encode(supervisions["text"]),
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
     elif (
         params.decoding_method == "greedy_search"
         and params.max_sym_per_frame == 1
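What "oracle" means here, as a minimal standalone sketch (the helpers below are assumptions for illustration; the real fast_beam_search_nbest_oracle works on k2 lattices, which is why ref_texts is passed in): among the n-best hypotheses, pick the one with the smallest edit distance to the reference, which lower-bounds the WER achievable from this lattice.

from typing import List

def edit_distance(a: List[int], b: List[int]) -> int:
    # Classic single-row Levenshtein DP over token IDs.
    dp = list(range(len(b) + 1))
    for i, x in enumerate(a, 1):
        prev, dp[0] = dp[0], i
        for j, y in enumerate(b, 1):
            prev, dp[j] = dp[j], min(dp[j] + 1, dp[j - 1] + 1, prev + (x != y))
    return dp[len(b)]

def oracle_hyp(nbest: List[List[int]], ref: List[int]) -> List[int]:
    # The "oracle" pick: the candidate closest to the reference. Its WER
    # lower-bounds what any rescoring of this n-best list could achieve.
    return min(nbest, key=lambda h: edit_distance(h, ref))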
@@ -328,6 +365,16 @@ def decode_one_batch(
                 f"max_states_{params.max_states}"
             ): hyps
         }
+    elif params.decoding_method == "fast_beam_search_nbest_oracle":
+        return {
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}_"
+                f"num_paths_{params.num_paths}_"
+                f"nbest_scale_{params.nbest_scale}"
+            ): hyps
+        }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -463,17 +510,30 @@ def main():
         "greedy_search",
         "beam_search",
         "fast_beam_search",
+        "fast_beam_search_nbest_oracle",
         "modified_beam_search",
     )
     params.res_dir = params.exp_dir / "giga" / params.decoding_method
 
-    params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
-    if "fast_beam_search" in params.decoding_method:
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+    if params.decoding_method == "fast_beam_search":
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+    elif params.decoding_method == "fast_beam_search_nbest_oracle":
         params.suffix += f"-beam-{params.beam}"
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
+        params.suffix += f"-num-paths-{params.num_paths}"
+        params.suffix += f"-nbest-scale-{params.nbest_scale}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-beam-{params.beam_size}"
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
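For example (flag values assumed from the defaults shown earlier; --beam's default is outside this hunk), an nbest-oracle run would write results under a suffix along these lines:

# Illustration only; mirrors the f-strings above with assumed flag values.
suffix = (
    "epoch-28-avg-15"
    "-beam-4-max-contexts-4-max-states-8"
    "-num-paths-100-nbest-scale-0.5"
)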
@@ -490,17 +550,30 @@ def main():
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)
 
-    # <blk> is defined in local/train_bpe_model.py
+    # <blk> and <unk> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.unk_id()
     params.vocab_size = sp.get_piece_size()
 
     logging.info(params)
 
     logging.info("About to create model")
     model = get_transducer_model(params)
 
-    if params.avg_last_n > 0:
-        filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
+    if params.iter > 0:
+        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+            : params.avg
+        ]
+        if len(filenames) == 0:
+            raise ValueError(
+                f"No checkpoints found for"
+                f" --iter {params.iter}, --avg {params.avg}"
+            )
+        elif len(filenames) < params.avg:
+            raise ValueError(
+                f"Not enough checkpoints ({len(filenames)}) found for"
+                f" --iter {params.iter}, --avg {params.avg}"
+            )
         logging.info(f"averaging {filenames}")
         model.to(device)
         model.load_state_dict(average_checkpoints(filenames, device=device))
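A hedged sketch of the checkpoint lookup this branch relies on (the actual find_checkpoints lives in icefall; this merely mimics the negative-iteration convention used above, where iteration=-iter means "batch count at most iter", newest first, so [: avg] takes the `avg` checkpoints at or before that batch):

import re
from pathlib import Path
from typing import List

def find_checkpoints_sketch(exp_dir: Path, iteration: int = 0) -> List[str]:
    # Collect checkpoint-<batch>.pt files and sort them newest first.
    pattern = re.compile(r"checkpoint-(\d+)\.pt")
    pairs = []
    for f in exp_dir.glob("checkpoint-*.pt"):
        m = pattern.match(f.name)
        if m:
            pairs.append((int(m.group(1)), str(f)))
    pairs.sort(reverse=True)
    if iteration >= 0:
        # Non-negative: keep checkpoints with batch count >= iteration.
        return [name for batch, name in pairs if batch >= iteration]
    # Negative: keep checkpoints with batch count <= -iteration.
    return [name for batch, name in pairs if batch <= -iteration]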
@@ -519,13 +592,17 @@ def main():
     model.to(device)
     model.eval()
     model.device = device
+    model.unk_id = params.unk_id
 
     # In beam_search.py, we are using model.decoder() and model.joiner(),
     # so we have to switch to the branch for the GigaSpeech dataset.
     model.decoder = model.decoder_giga
     model.joiner = model.joiner_giga
 
-    if params.decoding_method == "fast_beam_search":
+    if params.decoding_method in (
+        "fast_beam_search",
+        "fast_beam_search_nbest_oracle",
+    ):
         decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
     else:
         decoding_graph = None