Skip to content

Commit d947ef7

Browse files
Add batched streaming beam search for RNN-T (mALSD+mAES) and TDT (mALSD) (#15753)
* add sreaming beam searched with tests Signed-off-by: lilithgrigoryan <lgrigoryan@nvidia.com> * clean up Signed-off-by: lilithgrigoryan <lgrigoryan@nvidia.com> * fix kenlm tests Signed-off-by: lilithgrigoryan <lgrigoryan@nvidia.com> * clean up Signed-off-by: lilithgrigoryan <lgrigoryan@nvidia.com> * clean up Signed-off-by: lilithgrigoryan <lgrigoryan@nvidia.com> * clean up refactor cudagraphs, parity with greedy Signed-off-by: lilithgrigoryan <lgrigoryan@nvidia.com> * clean up tests Signed-off-by: lilithgrigoryan <lgrigoryan@nvidia.com> * clean up Signed-off-by: lilithgrigoryan <lgrigoryan@nvidia.com> * apply black formattinh Signed-off-by: lilithgrigoryan <lgrigoryan@nvidia.com> * minor fix Signed-off-by: lilithgrigoryan <lgrigoryan@nvidia.com> * minor fix in comments Signed-off-by: lilithgrigoryan <lgrigoryan@nvidia.com> * revert contextsize changes Signed-off-by: lilithgrigoryan <lgrigoryan@nvidia.com> * rm alignments from returns Signed-off-by: lilithgrigoryan <lgrigoryan@nvidia.com> * fix circular import Signed-off-by: lilithgrigoryan <lgrigoryan@nvidia.com> * clean up Signed-off-by: lilithgrigoryan <lgrigoryan@nvidia.com> * apply isort and black formatting Signed-off-by: lilithgrigoryan <lgrigoryan@nvidia.com> * merged state management classes + add tests to ci Signed-off-by: lilithgrigoryan <lgrigoryan@nvidia.com> * clean up Signed-off-by: lilithgrigoryan <lgrigoryan@nvidia.com> * fix tdt timestamps Signed-off-by: lilithgrigoryan <lgrigoryan@nvidia.com> * apply isort and blacj Signed-off-by: lilithgrigoryan <lgrigoryan@nvidia.com> --------- Signed-off-by: lilithgrigoryan <lgrigoryan@nvidia.com>
1 parent a409dfd commit d947ef7

13 files changed

Lines changed: 1684 additions & 257 deletions

examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py

Lines changed: 66 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@
7575
from nemo.collections.asr.models import EncDecHybridRNNTCTCModel, EncDecRNNTModel
7676
from nemo.collections.asr.parts.context_biasing.biasing_multi_model import BiasingRequestItemConfig
7777
from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig
78+
from nemo.collections.asr.parts.submodules.rnnt_maes_batched_computer import ModifiedAESBatchedRNNTComputer
79+
from nemo.collections.asr.parts.submodules.rnnt_malsd_batched_computer import ModifiedALSDBatchedRNNTComputer
80+
from nemo.collections.asr.parts.submodules.tdt_malsd_batched_computer import ModifiedALSDBatchedTDTComputer
7881
from nemo.collections.asr.parts.submodules.transducer_decoding.label_looping_base import (
7982
GreedyBatchedLabelLoopingComputerBase,
8083
)
@@ -250,11 +253,10 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
250253
" non-simulated streaming decoding for now."
251254
)
252255
else:
253-
# real streaming decoding: only greedy_batch, label-looping
254256
with open_dict(cfg.decoding):
255-
if cfg.decoding.strategy != "greedy_batch" or cfg.decoding.greedy.loop_labels is not True:
257+
if cfg.decoding.strategy == "greedy_batch" and cfg.decoding.greedy.loop_labels is not True:
256258
raise NotImplementedError(
257-
"This script currently supports only `greedy_batch` strategy with Label-Looping algorithm"
259+
"This script supports `greedy_batch` strategy only with Label-Looping algorithm"
258260
)
259261
cfg.decoding.tdt_include_token_duration = cfg.timestamps
260262
cfg.decoding.greedy.preserve_alignments = False
@@ -303,7 +305,14 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
303305
asr_model.eval()
304306

305307
try:
306-
decoding_computer: GreedyBatchedLabelLoopingComputerBase | None = asr_model.decoding.decoding.decoding_computer
308+
if cfg.decoding.strategy == "greedy_batch":
309+
decoding_computer: GreedyBatchedLabelLoopingComputerBase = asr_model.decoding.decoding.decoding_computer
310+
elif cfg.decoding.strategy == "malsd_batch":
311+
decoding_computer = asr_model.decoding.decoding.decoding_computer
312+
elif cfg.decoding.strategy == "maes_batch":
313+
decoding_computer: ModifiedAESBatchedRNNTComputer = asr_model.decoding.decoding.decoding_computer
314+
else:
315+
raise ValueError(f"Unsupported decoding strategy: {cfg.decoding.strategy}")
307316
except AttributeError:
308317
decoding_computer = None
309318

@@ -425,6 +434,11 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
425434
rest_audio_lengths = audio_batch_lengths.clone()
426435
encoder_output_aggregated: DynamicLengthTensor | None = None
427436

437+
is_beam_search = isinstance(
438+
decoding_computer,
439+
(ModifiedALSDBatchedRNNTComputer, ModifiedAESBatchedRNNTComputer, ModifiedALSDBatchedTDTComputer),
440+
)
441+
428442
# iterate over audio samples
429443
while left_sample < audio_batch.shape[1]:
430444
# add samples to buffer
@@ -472,18 +486,36 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
472486
)
473487
encoder_output_aggregated.append_(data=encoder_output, lengths=encoder_output_len_to_decode)
474488
else:
475-
# decode only chunk frames
476-
chunk_batched_hyps, state = decoding_computer(
477-
x=encoder_output,
478-
out_len=encoder_output_len_to_decode,
479-
prev_batched_state=state,
480-
multi_biasing_ids=multi_biasing_ids,
481-
)
482-
# merge hyps with previous hyps
483-
if current_batched_hyps is None:
484-
current_batched_hyps = chunk_batched_hyps
489+
if not is_beam_search:
490+
# decode only chunk frames
491+
chunk_batched_hyps, state = decoding_computer(
492+
x=encoder_output,
493+
out_len=encoder_output_len_to_decode,
494+
prev_batched_state=state,
495+
multi_biasing_ids=multi_biasing_ids,
496+
)
497+
498+
# merge hyps with previous hyps
499+
if current_batched_hyps is None:
500+
current_batched_hyps = chunk_batched_hyps
501+
else:
502+
current_batched_hyps.merge_(chunk_batched_hyps)
485503
else:
486-
current_batched_hyps.merge_(chunk_batched_hyps)
504+
chunk_batched_hyps, state = decoding_computer(
505+
x=encoder_output,
506+
out_len=encoder_output_len_to_decode,
507+
prev_batched_state=state,
508+
)
509+
# flatten_ to flatten the prefix tree and link beams to prior chunks in merge_ using root_ptrs.
510+
chunk_root_ptrs = chunk_batched_hyps.flatten_()
511+
if current_batched_hyps is None:
512+
current_batched_hyps = chunk_batched_hyps
513+
else:
514+
current_batched_hyps.merge_(
515+
chunk_batched_hyps,
516+
is_chunk_continuation=True,
517+
boundary_prev_ptr=chunk_root_ptrs,
518+
)
487519

488520
# move to next sample
489521
rest_audio_lengths -= chunk_lengths_batch
@@ -493,13 +525,21 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
493525
if use_simulated_decoding:
494526
# decode aggregated streaming encoder output
495527
if decoding_computer is not None:
496-
current_batched_hyps, _ = decoding_computer(
497-
x=encoder_output_aggregated.data,
498-
out_len=encoder_output_aggregated.lengths,
499-
prev_batched_state=state,
500-
multi_biasing_ids=multi_biasing_ids,
501-
)
502-
all_hyps.extend(batched_hyps_to_hypotheses(current_batched_hyps, batch_size=batch_size))
528+
if not is_beam_search:
529+
current_batched_hyps, _ = decoding_computer(
530+
x=encoder_output_aggregated.data,
531+
out_len=encoder_output_aggregated.lengths,
532+
prev_batched_state=state,
533+
multi_biasing_ids=multi_biasing_ids,
534+
)
535+
all_hyps.extend(batched_hyps_to_hypotheses(current_batched_hyps, batch_size=batch_size))
536+
else:
537+
current_batched_hyps, _ = decoding_computer(
538+
x=encoder_output_aggregated.data,
539+
out_len=encoder_output_aggregated.lengths,
540+
prev_batched_state=state,
541+
)
542+
all_hyps.extend(current_batched_hyps.to_hyps_list(score_norm=True))
503543
else:
504544
# no decoding computer, fallback to `asr_model.decoding.decoding`
505545
(cur_hyps,) = asr_model.decoding.decoding(
@@ -508,7 +548,10 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
508548
)
509549
all_hyps.extend(cur_hyps)
510550
else:
511-
all_hyps.extend(batched_hyps_to_hypotheses(current_batched_hyps, batch_size=batch_size))
551+
if not is_beam_search:
552+
all_hyps.extend(batched_hyps_to_hypotheses(current_batched_hyps, batch_size=batch_size))
553+
else:
554+
all_hyps.extend(current_batched_hyps.to_hyps_list(score_norm=True))
512555

513556
# remove biasing requests from the decoder
514557
if use_per_stream_biasing and audio_data.biasing_requests is not None:

nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1583,7 +1583,6 @@ def __init__(
15831583
pruning_mode: mode for pruning hypotheses with LM
15841584
allow_cuda_graphs: whether to allow CUDA graphs
15851585
return_best_hypothesis: whether to return the best hypothesis or N-best hypotheses
1586-
tokenizer: tokenizer for the model
15871586
"""
15881587

15891588
super().__init__()
@@ -1604,7 +1603,7 @@ def __init__(
16041603
if search_type == "malsd_batch":
16051604
# Depending on availability of `blank_as_pad` support
16061605
# switch between more efficient batch decoding technique
1607-
self._decoding_computer = ModifiedALSDBatchedRNNTComputer(
1606+
self.decoding_computer = ModifiedALSDBatchedRNNTComputer(
16081607
decoder=self.decoder,
16091608
joint=self.joint,
16101609
beam_size=self.beam_size,
@@ -1618,7 +1617,7 @@ def __init__(
16181617
allow_cuda_graphs=allow_cuda_graphs,
16191618
)
16201619
elif search_type == "maes_batch":
1621-
self._decoding_computer = ModifiedAESBatchedRNNTComputer(
1620+
self.decoding_computer = ModifiedAESBatchedRNNTComputer(
16221621
decoder=self.decoder,
16231622
joint=self.joint,
16241623
beam_size=self.beam_size,
@@ -1636,14 +1635,14 @@ def __init__(
16361635

16371636
def disable_cuda_graphs(self) -> bool:
16381637
"""Disable CUDA graphs (e.g., for decoding in training)"""
1639-
if isinstance(self._decoding_computer, WithOptionalCudaGraphs):
1640-
return self._decoding_computer.disable_cuda_graphs()
1638+
if isinstance(self.decoding_computer, WithOptionalCudaGraphs):
1639+
return self.decoding_computer.disable_cuda_graphs()
16411640
return False
16421641

16431642
def maybe_enable_cuda_graphs(self) -> bool:
16441643
"""Enable CUDA graphs (if allowed)"""
1645-
if isinstance(self._decoding_computer, WithOptionalCudaGraphs):
1646-
return self._decoding_computer.maybe_enable_cuda_graphs()
1644+
if isinstance(self.decoding_computer, WithOptionalCudaGraphs):
1645+
return self.decoding_computer.maybe_enable_cuda_graphs()
16471646
return False
16481647

16491648
@property
@@ -1690,7 +1689,7 @@ def forward(
16901689
self.joint.eval()
16911690

16921691
inseq = encoder_output # [B, T, D]
1693-
batched_beam_hyps = self._decoding_computer(x=inseq, out_len=logitlen)
1692+
batched_beam_hyps, _ = self.decoding_computer(x=inseq, out_len=logitlen)
16941693

16951694
batch_size = encoder_output.shape[0]
16961695
if self.return_best_hypothesis:

0 commit comments

Comments
 (0)