7575from nemo .collections .asr .models import EncDecHybridRNNTCTCModel , EncDecRNNTModel
7676from nemo .collections .asr .parts .context_biasing .biasing_multi_model import BiasingRequestItemConfig
7777from nemo .collections .asr .parts .submodules .rnnt_decoding import RNNTDecodingConfig
78+ from nemo .collections .asr .parts .submodules .rnnt_maes_batched_computer import ModifiedAESBatchedRNNTComputer
79+ from nemo .collections .asr .parts .submodules .rnnt_malsd_batched_computer import ModifiedALSDBatchedRNNTComputer
80+ from nemo .collections .asr .parts .submodules .tdt_malsd_batched_computer import ModifiedALSDBatchedTDTComputer
7881from nemo .collections .asr .parts .submodules .transducer_decoding .label_looping_base import (
7982 GreedyBatchedLabelLoopingComputerBase ,
8083)
8487from nemo .collections .asr .parts .utils .streaming_utils import (
8588 AudioBatch ,
8689 ContextSize ,
90+ DynamicLengthTensor ,
8791 SimpleAudioDataset ,
8892 StreamingBatchedAudioBuffer ,
8993)
@@ -155,8 +159,15 @@ class TranscriptionConfig:
155159 decoding : RNNTDecodingConfig = field (default_factory = RNNTDecodingConfig )
156160 # Per-utterance biasing with biasing config in the manifest
157161 use_per_stream_biasing : bool = False
162+ # Simulated streaming decoding (False by default): useful for faster experiments
163+ # and for experiments with decoding algorithms not yet implemented in streaming mode.
164+ # The encoder is evaluated on chunks; its output is concatenated and decoded in a single step.
165+ # Expected to provide the same results as non-simulated decoding if the decoding strategy
166+ # supports streaming decoding without additional heuristics (e.g., pruning between steps).
167+ simulated : bool = False
158168
159169 timestamps : bool = False # output timestamps
170+ confidence : bool = False # output word confidence
160171
161172 # Config for word / character error rate calculation
162173 calculate_wer : bool = True
@@ -229,19 +240,34 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
229240 asr_model .to (compute_dtype )
230241
231242 use_per_stream_biasing = cfg .use_per_stream_biasing
243+ use_simulated_decoding = cfg .simulated
232244
233245 # Change Decoding Config
234- with open_dict (cfg .decoding ):
235- if cfg .decoding .strategy != "greedy_batch" or cfg .decoding .greedy .loop_labels is not True :
236- raise NotImplementedError (
237- "This script currently supports only `greedy_batch` strategy with Label-Looping algorithm"
238- )
239- cfg .decoding .tdt_include_token_duration = cfg .timestamps
240- cfg .decoding .greedy .preserve_alignments = False
241- cfg .decoding .fused_batch_size = - 1 # temporarily stop fused batch during inference.
242- cfg .decoding .beam .return_best_hypothesis = True # return and write the best hypothsis only
243- if use_per_stream_biasing :
244- cfg .decoding .greedy .enable_per_stream_biasing = use_per_stream_biasing
246+ if use_simulated_decoding :
247+ # simulated decoding: any config allowed, do not change config
248+ with open_dict (cfg .decoding ):
249+ if cfg .decoding .strategy != "greedy_batch" or cfg .decoding .greedy .loop_labels is not True :
250+ logging .warning (
251+ f"Using { cfg .decoding .strategy } in simulated decoding."
252+ " Only greedy_batch with label-looping fully supports"
253+ " non-simulated streaming decoding for now."
254+ )
255+ else :
256+ with open_dict (cfg .decoding ):
257+ if cfg .decoding .strategy == "greedy_batch" and cfg .decoding .greedy .loop_labels is not True :
258+ raise NotImplementedError (
259+ "This script supports `greedy_batch` strategy only with Label-Looping algorithm"
260+ )
261+ cfg .decoding .tdt_include_token_duration = cfg .timestamps
262+ cfg .decoding .greedy .preserve_alignments = False
263+ cfg .decoding .fused_batch_size = - 1 # temporarily stop fused batch during inference.
264+ cfg .decoding .beam .return_best_hypothesis = True # return and write the best hypothsis only
265+ if use_per_stream_biasing :
266+ cfg .decoding .greedy .enable_per_stream_biasing = use_per_stream_biasing
267+ if cfg .confidence :
268+ cfg .decoding .greedy .preserve_frame_confidence = True
269+ cfg .decoding .confidence_cfg .preserve_frame_confidence = True
270+ cfg .decoding .confidence_cfg .preserve_word_confidence = True
245271
246272 # Setup decoding strategy
247273 if hasattr (asr_model , 'change_decoding_strategy' ):
@@ -278,7 +304,20 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
278304 asr_model .preprocessor .featurizer .pad_to = 0
279305 asr_model .eval ()
280306
281- decoding_computer : GreedyBatchedLabelLoopingComputerBase = asr_model .decoding .decoding .decoding_computer
307+ try :
308+ if cfg .decoding .strategy == "greedy_batch" :
309+ decoding_computer : GreedyBatchedLabelLoopingComputerBase = asr_model .decoding .decoding .decoding_computer
310+ elif cfg .decoding .strategy == "malsd_batch" :
311+ decoding_computer = asr_model .decoding .decoding .decoding_computer
312+ elif cfg .decoding .strategy == "maes_batch" :
313+ decoding_computer : ModifiedAESBatchedRNNTComputer = asr_model .decoding .decoding .decoding_computer
314+ else :
315+ raise ValueError (f"Unsupported decoding strategy: { cfg .decoding .strategy } " )
316+ except AttributeError :
317+ decoding_computer = None
318+
319+ if (not use_simulated_decoding ) or use_per_stream_biasing :
320+ assert decoding_computer is not None
282321
283322 audio_sample_rate = model_cfg .preprocessor ['sample_rate' ]
284323
@@ -393,6 +432,12 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
393432 device = device ,
394433 )
395434 rest_audio_lengths = audio_batch_lengths .clone ()
435+ encoder_output_aggregated : DynamicLengthTensor | None = None
436+
437+ is_beam_search = isinstance (
438+ decoding_computer ,
439+ (ModifiedALSDBatchedRNNTComputer , ModifiedAESBatchedRNNTComputer , ModifiedALSDBatchedTDTComputer ),
440+ )
396441
397442 # iterate over audio samples
398443 while left_sample < audio_batch .shape [1 ]:
@@ -423,36 +468,97 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
423468 encoder_context_batch = buffer .context_size_batch .subsample (factor = encoder_frame2audio_samples )
424469 # remove left context
425470 encoder_output = encoder_output [:, encoder_context .left :]
426-
427- # decode only chunk frames
428- chunk_batched_hyps , _ , state = decoding_computer (
429- x = encoder_output ,
430- out_len = torch .where (
431- is_last_chunk_batch ,
432- encoder_output_len - encoder_context_batch .left ,
433- encoder_context_batch .chunk ,
434- ),
435- prev_batched_state = state ,
436- multi_biasing_ids = multi_biasing_ids ,
471+ encoder_output_len_to_decode = torch .where (
472+ is_last_chunk_batch ,
473+ encoder_output_len - encoder_context_batch .left ,
474+ encoder_context_batch .chunk ,
437475 )
438- # merge hyps with previous hyps
439- if current_batched_hyps is None :
440- current_batched_hyps = chunk_batched_hyps
476+
477+ if use_simulated_decoding :
478+ # store encoder output (accumulate)
479+ if encoder_output_aggregated is None :
480+ encoder_output_aggregated = DynamicLengthTensor (
481+ batch_size = batch_size ,
482+ init_length = encoder_output .shape [1 ],
483+ dim_shape = encoder_output .shape [2 ],
484+ device = device ,
485+ dtype = compute_dtype ,
486+ )
487+ encoder_output_aggregated .append_ (data = encoder_output , lengths = encoder_output_len_to_decode )
441488 else :
442- current_batched_hyps .merge_ (chunk_batched_hyps )
489+ if not is_beam_search :
490+ # decode only chunk frames
491+ chunk_batched_hyps , state = decoding_computer (
492+ x = encoder_output ,
493+ out_len = encoder_output_len_to_decode ,
494+ prev_batched_state = state ,
495+ multi_biasing_ids = multi_biasing_ids ,
496+ )
497+
498+ # merge hyps with previous hyps
499+ if current_batched_hyps is None :
500+ current_batched_hyps = chunk_batched_hyps
501+ else :
502+ current_batched_hyps .merge_ (chunk_batched_hyps )
503+ else :
504+ chunk_batched_hyps , state = decoding_computer (
505+ x = encoder_output ,
506+ out_len = encoder_output_len_to_decode ,
507+ prev_batched_state = state ,
508+ )
509+ # flatten_() flattens the chunk's prefix tree and returns root pointers; merge_ uses them to link beams to prior chunks.
510+ chunk_root_ptrs = chunk_batched_hyps .flatten_ ()
511+ if current_batched_hyps is None :
512+ current_batched_hyps = chunk_batched_hyps
513+ else :
514+ current_batched_hyps .merge_ (
515+ chunk_batched_hyps ,
516+ is_chunk_continuation = True ,
517+ boundary_prev_ptr = chunk_root_ptrs ,
518+ )
443519
444520 # move to next sample
445521 rest_audio_lengths -= chunk_lengths_batch
446522 left_sample = right_sample
447523 right_sample = min (right_sample + context_samples .chunk , audio_batch .shape [1 ]) # add next chunk
448524
525+ if use_simulated_decoding :
526+ # decode aggregated streaming encoder output
527+ if decoding_computer is not None :
528+ if not is_beam_search :
529+ current_batched_hyps , _ = decoding_computer (
530+ x = encoder_output_aggregated .data ,
531+ out_len = encoder_output_aggregated .lengths ,
532+ prev_batched_state = state ,
533+ multi_biasing_ids = multi_biasing_ids ,
534+ )
535+ all_hyps .extend (batched_hyps_to_hypotheses (current_batched_hyps , batch_size = batch_size ))
536+ else :
537+ current_batched_hyps , _ = decoding_computer (
538+ x = encoder_output_aggregated .data ,
539+ out_len = encoder_output_aggregated .lengths ,
540+ prev_batched_state = state ,
541+ )
542+ all_hyps .extend (current_batched_hyps .to_hyps_list (score_norm = True ))
543+ else :
544+ # no decoding computer, fallback to `asr_model.decoding.decoding`
545+ (cur_hyps ,) = asr_model .decoding .decoding (
546+ encoder_output = encoder_output_aggregated .data .transpose (1 , 2 ),
547+ encoded_lengths = encoder_output_aggregated .lengths ,
548+ )
549+ all_hyps .extend (cur_hyps )
550+ else :
551+ if not is_beam_search :
552+ all_hyps .extend (batched_hyps_to_hypotheses (current_batched_hyps , batch_size = batch_size ))
553+ else :
554+ all_hyps .extend (current_batched_hyps .to_hyps_list (score_norm = True ))
555+
449556 # remove biasing requests from the decoder
450557 if use_per_stream_biasing and audio_data .biasing_requests is not None :
451558 for request in audio_data .biasing_requests :
452559 if request is not None and request .multi_model_id is not None :
453560 decoding_computer .biasing_multi_model .remove_model (request .multi_model_id )
454561 request .multi_model_id = None
455- all_hyps .extend (batched_hyps_to_hypotheses (current_batched_hyps , None , batch_size = batch_size ))
456562 timer .stop (device = map_location )
457563
458564 # convert text
@@ -466,6 +572,8 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
466572 window_stride = asr_model .cfg ['preprocessor' ]['window_stride' ],
467573 )
468574 all_hyps [i ] = hyp
575+ if cfg .confidence :
576+ all_hyps = asr_model .decoding .compute_confidence (all_hyps )
469577
470578 if cfg .sort_by_duration :
471579 # restore order for all_hyps and records (all_hyps are consistent with records)
@@ -475,7 +583,13 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
475583 records , all_hyps = map (list , zip (* order_restored ))
476584
477585 output_filename , pred_text_attr_name = write_transcription (
478- all_hyps , cfg , model_name , filepaths = filepaths , compute_langs = False , timestamps = cfg .timestamps
586+ all_hyps ,
587+ cfg ,
588+ model_name ,
589+ filepaths = filepaths ,
590+ compute_langs = False ,
591+ timestamps = cfg .timestamps ,
592+ confidence = cfg .confidence ,
479593 )
480594 logging .info (f"Finished writing predictions to { output_filename } !" )
481595
0 commit comments