Skip to content

Commit 6d3901d

Browse files
Merge branch 'main' of https://github.com/NVIDIA/NeMo into lgrigoryan/streaming-beam-search
2 parents 1e86cc3 + 160a742 commit 6d3901d

4 files changed

Lines changed: 647 additions & 232 deletions

File tree

docker/Dockerfile

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,6 @@ case "$(nemo-cuda-flavor)" in
130130
esac
131131
uv pip install --index-url "${torchcodec_index}" torchcodec
132132
EOF
133-
COPY nemo /workspace/nemo
134133

135134
FROM base-image AS automodel-deps
136135
ARG GPU_TARGET=h100plus
@@ -291,7 +290,7 @@ if [ "${INSTALL_FFMPEG}" = "true" ]; then
291290
fi
292291
EOF
293292

294-
ENV NEMO_HOME="/home/TestData/nemo_home"
293+
COPY nemo /workspace/nemo
295294

296295
# NOTICES.txt file points to where the OSS source code is archived
297296
RUN echo "This distribution includes open source which is archived at the following URL: https://opensource.nvidia.com/oss/teams/nvidia/nemo/${RC_DATE}:linux-${TARGETARCH}/index.html" > NOTICES.txt && \

examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py

Lines changed: 139 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@
8787
from nemo.collections.asr.parts.utils.streaming_utils import (
8888
AudioBatch,
8989
ContextSize,
90+
DynamicLengthTensor,
9091
SimpleAudioDataset,
9192
StreamingBatchedAudioBuffer,
9293
)
@@ -158,6 +159,12 @@ class TranscriptionConfig:
158159
decoding: RNNTDecodingConfig = field(default_factory=RNNTDecodingConfig)
159160
# Per-utterance biasing with biasing config in the manifest
160161
use_per_stream_biasing: bool = False
162+
# simulated decoding (False by default) for faster experiments
163+
# and for experiments with decoding algorithms that are not yet implemented for streaming;
164+
# the encoder is evaluated chunk by chunk, its outputs are concatenated and then decoded in a single step;
165+
# expected to provide the same results if the decoding strategy supports
166+
# streaming decoding without additional heuristics (e.g., pruning between steps)
167+
simulated: bool = False
161168

162169
timestamps: bool = False # output timestamps
163170
confidence: bool = False # output word confidence
@@ -233,36 +240,47 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
233240
asr_model.to(compute_dtype)
234241

235242
use_per_stream_biasing = cfg.use_per_stream_biasing
243+
use_simulated_decoding = cfg.simulated
236244

237245
# Change Decoding Config
238-
is_tdt_model = cfg.decoding.get("durations", None) not in (None, [])
239-
with open_dict(cfg.decoding):
240-
if cfg.decoding.strategy == "greedy_batch":
241-
if cfg.decoding.greedy.loop_labels is not True:
246+
if use_simulated_decoding:
247+
# simulated decoding: any decoding config is allowed, so the config is left unchanged
248+
with open_dict(cfg.decoding):
249+
if cfg.decoding.strategy != "greedy_batch" or cfg.decoding.greedy.loop_labels is not True:
250+
logging.warning(
251+
f"Using {cfg.decoding.strategy} in simulated decoding."
252+
" Only greedy_batch with label-looping fully supports"
253+
" non-simulated streaming decoding for now."
254+
)
255+
else:
256+
is_tdt_model = cfg.decoding.get("durations", None) not in (None, [])
257+
with open_dict(cfg.decoding):
258+
if cfg.decoding.strategy == "greedy_batch":
259+
if cfg.decoding.greedy.loop_labels is not True:
260+
raise NotImplementedError(
261+
"This script supports `greedy_batch` strategy only with Label-Looping algorithm"
262+
)
263+
cfg.decoding.greedy.preserve_alignments = False
264+
elif cfg.decoding.strategy == "malsd_batch":
265+
pass
266+
elif cfg.decoding.strategy == "maes_batch":
267+
if is_tdt_model:
268+
raise NotImplementedError("`maes_batch` is RNN-T only; use `malsd_batch` for TDT models.")
269+
else:
242270
raise NotImplementedError(
243-
"This script supports `greedy_batch` strategy only with Label-Looping algorithm"
271+
f"Unsupported decoding strategy `{cfg.decoding.strategy}`. "
272+
"Supported: `greedy_batch`, `malsd_batch`, `maes_batch` (RNN-T only)."
244273
)
274+
cfg.decoding.tdt_include_token_duration = cfg.timestamps
245275
cfg.decoding.greedy.preserve_alignments = False
246-
elif cfg.decoding.strategy == "malsd_batch":
247-
pass
248-
elif cfg.decoding.strategy == "maes_batch":
249-
if is_tdt_model:
250-
raise NotImplementedError("`maes_batch` is RNN-T only; use `malsd_batch` for TDT models.")
251-
else:
252-
raise NotImplementedError(
253-
f"Unsupported decoding strategy `{cfg.decoding.strategy}`. "
254-
"Supported: `greedy_batch`, `malsd_batch`, `maes_batch` (RNN-T only)."
255-
)
256-
cfg.decoding.tdt_include_token_duration = cfg.timestamps
257-
cfg.decoding.greedy.preserve_alignments = False
258-
cfg.decoding.fused_batch_size = -1 # temporarily stop fused batch during inference.
259-
cfg.decoding.beam.return_best_hypothesis = True # return and write the best hypothesis only
260-
if use_per_stream_biasing:
261-
cfg.decoding.greedy.enable_per_stream_biasing = use_per_stream_biasing
262-
if cfg.confidence:
263-
cfg.decoding.greedy.preserve_frame_confidence = True
264-
cfg.decoding.confidence_cfg.preserve_frame_confidence = True
265-
cfg.decoding.confidence_cfg.preserve_word_confidence = True
276+
cfg.decoding.fused_batch_size = -1 # temporarily stop fused batch during inference.
277+
cfg.decoding.beam.return_best_hypothesis = True # return and write the best hypothesis only
278+
if use_per_stream_biasing:
279+
cfg.decoding.greedy.enable_per_stream_biasing = use_per_stream_biasing
280+
if cfg.confidence:
281+
cfg.decoding.greedy.preserve_frame_confidence = True
282+
cfg.decoding.confidence_cfg.preserve_frame_confidence = True
283+
cfg.decoding.confidence_cfg.preserve_word_confidence = True
266284

267285
# Setup decoding strategy
268286
if hasattr(asr_model, 'change_decoding_strategy'):
@@ -299,16 +317,20 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
299317
asr_model.preprocessor.featurizer.pad_to = 0
300318
asr_model.eval()
301319

302-
# Get decoding computer based on strategy. Beam-search strategies expose the
303-
# underlying computer via the private ``_decoding_computer`` attribute.
304-
if cfg.decoding.strategy == "greedy_batch":
305-
decoding_computer: GreedyBatchedLabelLoopingComputerBase = asr_model.decoding.decoding.decoding_computer
306-
elif cfg.decoding.strategy == "malsd_batch":
307-
decoding_computer = asr_model.decoding.decoding.decoding_computer
308-
elif cfg.decoding.strategy == "maes_batch":
309-
decoding_computer: ModifiedAESBatchedRNNTComputer = asr_model.decoding.decoding.decoding_computer
310-
else:
311-
raise ValueError(f"Unsupported decoding strategy: {cfg.decoding.strategy}")
320+
try:
321+
if cfg.decoding.strategy == "greedy_batch":
322+
decoding_computer: GreedyBatchedLabelLoopingComputerBase = asr_model.decoding.decoding.decoding_computer
323+
elif cfg.decoding.strategy == "malsd_batch":
324+
decoding_computer = asr_model.decoding.decoding.decoding_computer
325+
elif cfg.decoding.strategy == "maes_batch":
326+
decoding_computer: ModifiedAESBatchedRNNTComputer = asr_model.decoding.decoding.decoding_computer
327+
else:
328+
raise ValueError(f"Unsupported decoding strategy: {cfg.decoding.strategy}")
329+
except AttributeError:
330+
decoding_computer = None
331+
332+
if (not use_simulated_decoding) or use_per_stream_biasing:
333+
assert decoding_computer is not None
312334

313335
audio_sample_rate = model_cfg.preprocessor['sample_rate']
314336

@@ -423,6 +445,7 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
423445
device=device,
424446
)
425447
rest_audio_lengths = audio_batch_lengths.clone()
448+
encoder_output_aggregated: DynamicLengthTensor | None = None
426449

427450
is_beam_search = isinstance(
428451
decoding_computer,
@@ -458,64 +481,102 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
458481
encoder_context_batch = buffer.context_size_batch.subsample(factor=encoder_frame2audio_samples)
459482
# remove left context
460483
encoder_output = encoder_output[:, encoder_context.left :]
461-
462-
# decode only chunk frames
463-
out_len = torch.where(
484+
encoder_output_len_to_decode = torch.where(
464485
is_last_chunk_batch,
465486
encoder_output_len - encoder_context_batch.left,
466487
encoder_context_batch.chunk,
467488
)
468-
if is_beam_search:
469-
# Beam-search computers don't accept ``multi_biasing_ids`` yet.
470-
chunk_batched_hyps, state = decoding_computer(
471-
x=encoder_output, out_len=out_len, prev_batched_state=state
472-
)
473-
else:
474-
chunk_batched_hyps, _, state = decoding_computer(
475-
x=encoder_output,
476-
out_len=out_len,
477-
prev_batched_state=state,
478-
multi_biasing_ids=multi_biasing_ids,
479-
)
480489

481-
# Accumulate hypotheses across chunks.
482-
if is_beam_search:
483-
# Flatten this chunk's prefix tree and thread the cross-chunk beam
484-
# permutation (``root_ptrs``) into the accumulator so the final
485-
# ``flatten_sort_`` walks back through the right beam history.
486-
# ``chunk_batched_hyps`` is the per-chunk BatchedBeamHyps (the
487-
# cross-chunk per-beam scalars live on ``state.beam_state`` now).
488-
chunk_root_ptrs = chunk_batched_hyps.flatten_()
489-
if current_batched_hyps is None:
490-
current_batched_hyps = chunk_batched_hyps
491-
else:
492-
current_batched_hyps.merge_(
493-
chunk_batched_hyps,
494-
is_chunk_continuation=True,
495-
boundary_prev_ptr=chunk_root_ptrs,
490+
if use_simulated_decoding:
491+
# store encoder output (accumulate)
492+
if encoder_output_aggregated is None:
493+
encoder_output_aggregated = DynamicLengthTensor(
494+
batch_size=batch_size,
495+
init_length=encoder_output.shape[1],
496+
dim_shape=encoder_output.shape[2],
497+
device=device,
498+
dtype=compute_dtype,
496499
)
500+
encoder_output_aggregated.append_(data=encoder_output, lengths=encoder_output_len_to_decode)
497501
else:
498-
if current_batched_hyps is None:
499-
current_batched_hyps = chunk_batched_hyps
502+
if not is_beam_search:
503+
# decode only chunk frames
504+
chunk_batched_hyps, state = decoding_computer(
505+
x=encoder_output,
506+
out_len=encoder_output_len_to_decode,
507+
prev_batched_state=state,
508+
multi_biasing_ids=multi_biasing_ids,
509+
)
510+
511+
# merge hyps with previous hyps
512+
if current_batched_hyps is None:
513+
current_batched_hyps = chunk_batched_hyps
514+
else:
515+
current_batched_hyps.merge_(chunk_batched_hyps)
500516
else:
501-
current_batched_hyps.merge_(chunk_batched_hyps)
517+
chunk_batched_hyps, state = decoding_computer(
518+
x=encoder_output,
519+
out_len=encoder_output_len_to_decode,
520+
prev_batched_state=state,
521+
)
522+
# Flatten this chunk's prefix tree and thread the cross-chunk beam
523+
# permutation (``root_ptrs``) into the accumulator so the final
524+
# ``flatten_sort_`` walks back through the right beam history.
525+
# ``chunk_batched_hyps`` is the per-chunk BatchedBeamHyps (the
526+
# cross-chunk per-beam scalars live on ``state.beam_state`` now).
527+
chunk_root_ptrs = chunk_batched_hyps.flatten_()
528+
if current_batched_hyps is None:
529+
current_batched_hyps = chunk_batched_hyps
530+
else:
531+
current_batched_hyps.merge_(
532+
chunk_batched_hyps,
533+
is_chunk_continuation=True,
534+
boundary_prev_ptr=chunk_root_ptrs,
535+
)
536+
502537

503538
# move to next sample
504539
rest_audio_lengths -= chunk_lengths_batch
505540
left_sample = right_sample
506541
right_sample = min(right_sample + context_samples.chunk, audio_batch.shape[1]) # add next chunk
507542

508-
# Convert batched hypotheses to list
509-
if is_beam_search:
510-
all_hyps.extend(current_batched_hyps.to_hyps_list(score_norm=True))
543+
if use_simulated_decoding:
544+
# decode aggregated streaming encoder output
545+
if decoding_computer is not None:
546+
if not is_beam_search:
547+
current_batched_hyps, _ = decoding_computer(
548+
x=encoder_output_aggregated.data,
549+
out_len=encoder_output_aggregated.lengths,
550+
prev_batched_state=state,
551+
multi_biasing_ids=multi_biasing_ids,
552+
)
553+
all_hyps.extend(batched_hyps_to_hypotheses(current_batched_hyps, batch_size=batch_size))
554+
else:
555+
current_batched_hyps, _ = decoding_computer(
556+
x=encoder_output_aggregated.data,
557+
out_len=encoder_output_aggregated.lengths,
558+
prev_batched_state=state,
559+
)
560+
all_hyps.extend(current_batched_hyps.to_hyps_list(score_norm=True))
561+
else:
562+
# no decoding computer, fallback to `asr_model.decoding.decoding`
563+
(cur_hyps,) = asr_model.decoding.decoding(
564+
encoder_output=encoder_output_aggregated.data.transpose(1, 2),
565+
encoded_lengths=encoder_output_aggregated.lengths,
566+
)
567+
all_hyps.extend(cur_hyps)
511568
else:
512-
# remove biasing requests from the decoder
513-
if use_per_stream_biasing and audio_data.biasing_requests is not None:
514-
for request in audio_data.biasing_requests:
515-
if request is not None and request.multi_model_id is not None:
516-
decoding_computer.biasing_multi_model.remove_model(request.multi_model_id)
517-
request.multi_model_id = None
518-
all_hyps.extend(batched_hyps_to_hypotheses(current_batched_hyps, batch_size=batch_size))
569+
if not is_beam_search:
570+
all_hyps.extend(batched_hyps_to_hypotheses(current_batched_hyps, batch_size=batch_size))
571+
else:
572+
all_hyps.extend(current_batched_hyps.to_hyps_list(score_norm=True))
573+
574+
# remove biasing requests from the decoder
575+
if use_per_stream_biasing and audio_data.biasing_requests is not None:
576+
for request in audio_data.biasing_requests:
577+
if request is not None and request.multi_model_id is not None:
578+
decoding_computer.biasing_multi_model.remove_model(request.multi_model_id)
579+
request.multi_model_id = None
519580
timer.stop(device=map_location)
520581

521582
# convert text

0 commit comments

Comments
 (0)