|
87 | 87 | from nemo.collections.asr.parts.utils.streaming_utils import ( |
88 | 88 | AudioBatch, |
89 | 89 | ContextSize, |
| 90 | + DynamicLengthTensor, |
90 | 91 | SimpleAudioDataset, |
91 | 92 | StreamingBatchedAudioBuffer, |
92 | 93 | ) |
@@ -158,6 +159,12 @@ class TranscriptionConfig: |
158 | 159 | decoding: RNNTDecodingConfig = field(default_factory=RNNTDecodingConfig) |
159 | 160 | # Per-utterance biasing with biasing config in the manifest |
160 | 161 | use_per_stream_biasing: bool = False |
| 162 | + # Simulated streaming decoding (disabled by default), intended for faster experiments |
| 163 | + # and for trying decoding algorithms that are not yet implemented in streaming mode. |
| 164 | + # The encoder is still evaluated chunk by chunk; its outputs are concatenated and decoded in a single step. |
| 165 | + # Results are expected to match true streaming decoding when the decoding strategy supports |
| 166 | + # streaming decoding without additional heuristics (e.g., pruning between steps). |
| 167 | + simulated: bool = False |
161 | 168 |
|
162 | 169 | timestamps: bool = False # output timestamps |
163 | 170 | confidence: bool = False # output word confidence |
@@ -233,36 +240,47 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: |
233 | 240 | asr_model.to(compute_dtype) |
234 | 241 |
|
235 | 242 | use_per_stream_biasing = cfg.use_per_stream_biasing |
| 243 | + use_simulated_decoding = cfg.simulated |
236 | 244 |
|
237 | 245 | # Change Decoding Config |
238 | | - is_tdt_model = cfg.decoding.get("durations", None) not in (None, []) |
239 | | - with open_dict(cfg.decoding): |
240 | | - if cfg.decoding.strategy == "greedy_batch": |
241 | | - if cfg.decoding.greedy.loop_labels is not True: |
| 246 | + if use_simulated_decoding: |
| 247 | + # simulated decoding: any config allowed, do not change config |
| 248 | + with open_dict(cfg.decoding): |
| 249 | + if cfg.decoding.strategy != "greedy_batch" or cfg.decoding.greedy.loop_labels is not True: |
| 250 | + logging.warning( |
| 251 | + f"Using {cfg.decoding.strategy} in simulated decoding." |
| 252 | + " Only greedy_batch with label-looping fully supports" |
| 253 | + " non-simulated streaming decoding for now." |
| 254 | + ) |
| 255 | + else: |
| 256 | + is_tdt_model = cfg.decoding.get("durations", None) not in (None, []) |
| 257 | + with open_dict(cfg.decoding): |
| 258 | + if cfg.decoding.strategy == "greedy_batch": |
| 259 | + if cfg.decoding.greedy.loop_labels is not True: |
| 260 | + raise NotImplementedError( |
| 261 | + "This script supports `greedy_batch` strategy only with Label-Looping algorithm" |
| 262 | + ) |
| 263 | + cfg.decoding.greedy.preserve_alignments = False |
| 264 | + elif cfg.decoding.strategy == "malsd_batch": |
| 265 | + pass |
| 266 | + elif cfg.decoding.strategy == "maes_batch": |
| 267 | + if is_tdt_model: |
| 268 | + raise NotImplementedError("`maes_batch` is RNN-T only; use `malsd_batch` for TDT models.") |
| 269 | + else: |
242 | 270 | raise NotImplementedError( |
243 | | - "This script supports `greedy_batch` strategy only with Label-Looping algorithm" |
| 271 | + f"Unsupported decoding strategy `{cfg.decoding.strategy}`. " |
| 272 | + "Supported: `greedy_batch`, `malsd_batch`, `maes_batch` (RNN-T only)." |
244 | 273 | ) |
| 274 | + cfg.decoding.tdt_include_token_duration = cfg.timestamps |
245 | 275 | cfg.decoding.greedy.preserve_alignments = False |
246 | | - elif cfg.decoding.strategy == "malsd_batch": |
247 | | - pass |
248 | | - elif cfg.decoding.strategy == "maes_batch": |
249 | | - if is_tdt_model: |
250 | | - raise NotImplementedError("`maes_batch` is RNN-T only; use `malsd_batch` for TDT models.") |
251 | | - else: |
252 | | - raise NotImplementedError( |
253 | | - f"Unsupported decoding strategy `{cfg.decoding.strategy}`. " |
254 | | - "Supported: `greedy_batch`, `malsd_batch`, `maes_batch` (RNN-T only)." |
255 | | - ) |
256 | | - cfg.decoding.tdt_include_token_duration = cfg.timestamps |
257 | | - cfg.decoding.greedy.preserve_alignments = False |
258 | | - cfg.decoding.fused_batch_size = -1 # temporarily stop fused batch during inference. |
259 | | - cfg.decoding.beam.return_best_hypothesis = True # return and write the best hypothsis only |
260 | | - if use_per_stream_biasing: |
261 | | - cfg.decoding.greedy.enable_per_stream_biasing = use_per_stream_biasing |
262 | | - if cfg.confidence: |
263 | | - cfg.decoding.greedy.preserve_frame_confidence = True |
264 | | - cfg.decoding.confidence_cfg.preserve_frame_confidence = True |
265 | | - cfg.decoding.confidence_cfg.preserve_word_confidence = True |
| 276 | + cfg.decoding.fused_batch_size = -1 # temporarily stop fused batch during inference. |
| 277 | + cfg.decoding.beam.return_best_hypothesis = True # return and write the best hypothesis only |
| 278 | + if use_per_stream_biasing: |
| 279 | + cfg.decoding.greedy.enable_per_stream_biasing = use_per_stream_biasing |
| 280 | + if cfg.confidence: |
| 281 | + cfg.decoding.greedy.preserve_frame_confidence = True |
| 282 | + cfg.decoding.confidence_cfg.preserve_frame_confidence = True |
| 283 | + cfg.decoding.confidence_cfg.preserve_word_confidence = True |
266 | 284 |
|
267 | 285 | # Setup decoding strategy |
268 | 286 | if hasattr(asr_model, 'change_decoding_strategy'): |
@@ -299,16 +317,20 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: |
299 | 317 | asr_model.preprocessor.featurizer.pad_to = 0 |
300 | 318 | asr_model.eval() |
301 | 319 |
|
302 | | - # Get decoding computer based on strategy. Beam-search strategies expose the |
303 | | - # underlying computer via the private ``_decoding_computer`` attribute. |
304 | | - if cfg.decoding.strategy == "greedy_batch": |
305 | | - decoding_computer: GreedyBatchedLabelLoopingComputerBase = asr_model.decoding.decoding.decoding_computer |
306 | | - elif cfg.decoding.strategy == "malsd_batch": |
307 | | - decoding_computer = asr_model.decoding.decoding.decoding_computer |
308 | | - elif cfg.decoding.strategy == "maes_batch": |
309 | | - decoding_computer: ModifiedAESBatchedRNNTComputer = asr_model.decoding.decoding.decoding_computer |
310 | | - else: |
311 | | - raise ValueError(f"Unsupported decoding strategy: {cfg.decoding.strategy}") |
| 320 | + try: |
| 321 | + if cfg.decoding.strategy == "greedy_batch": |
| 322 | + decoding_computer: GreedyBatchedLabelLoopingComputerBase = asr_model.decoding.decoding.decoding_computer |
| 323 | + elif cfg.decoding.strategy == "malsd_batch": |
| 324 | + decoding_computer = asr_model.decoding.decoding.decoding_computer |
| 325 | + elif cfg.decoding.strategy == "maes_batch": |
| 326 | + decoding_computer: ModifiedAESBatchedRNNTComputer = asr_model.decoding.decoding.decoding_computer |
| 327 | + else: |
| 328 | + raise ValueError(f"Unsupported decoding strategy: {cfg.decoding.strategy}") |
| 329 | + except AttributeError: |
| 330 | + decoding_computer = None |
| 331 | + |
| 332 | + if (not use_simulated_decoding) or use_per_stream_biasing: |
| 333 | + assert decoding_computer is not None |
312 | 334 |
|
313 | 335 | audio_sample_rate = model_cfg.preprocessor['sample_rate'] |
314 | 336 |
|
@@ -423,6 +445,7 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: |
423 | 445 | device=device, |
424 | 446 | ) |
425 | 447 | rest_audio_lengths = audio_batch_lengths.clone() |
| 448 | + encoder_output_aggregated: DynamicLengthTensor | None = None |
426 | 449 |
|
427 | 450 | is_beam_search = isinstance( |
428 | 451 | decoding_computer, |
@@ -458,64 +481,102 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: |
458 | 481 | encoder_context_batch = buffer.context_size_batch.subsample(factor=encoder_frame2audio_samples) |
459 | 482 | # remove left context |
460 | 483 | encoder_output = encoder_output[:, encoder_context.left :] |
461 | | - |
462 | | - # decode only chunk frames |
463 | | - out_len = torch.where( |
| 484 | + encoder_output_len_to_decode = torch.where( |
464 | 485 | is_last_chunk_batch, |
465 | 486 | encoder_output_len - encoder_context_batch.left, |
466 | 487 | encoder_context_batch.chunk, |
467 | 488 | ) |
468 | | - if is_beam_search: |
469 | | - # Beam-search computers don't accept ``multi_biasing_ids`` yet. |
470 | | - chunk_batched_hyps, state = decoding_computer( |
471 | | - x=encoder_output, out_len=out_len, prev_batched_state=state |
472 | | - ) |
473 | | - else: |
474 | | - chunk_batched_hyps, _, state = decoding_computer( |
475 | | - x=encoder_output, |
476 | | - out_len=out_len, |
477 | | - prev_batched_state=state, |
478 | | - multi_biasing_ids=multi_biasing_ids, |
479 | | - ) |
480 | 489 |
|
481 | | - # Accumulate hypotheses across chunks. |
482 | | - if is_beam_search: |
483 | | - # Flatten this chunk's prefix tree and thread the cross-chunk beam |
484 | | - # permutation (``root_ptrs``) into the accumulator so the final |
485 | | - # ``flatten_sort_`` walks back through the right beam history. |
486 | | - # ``chunk_batched_hyps`` is the per-chunk BatchedBeamHyps (the |
487 | | - # cross-chunk per-beam scalars live on ``state.beam_state`` now). |
488 | | - chunk_root_ptrs = chunk_batched_hyps.flatten_() |
489 | | - if current_batched_hyps is None: |
490 | | - current_batched_hyps = chunk_batched_hyps |
491 | | - else: |
492 | | - current_batched_hyps.merge_( |
493 | | - chunk_batched_hyps, |
494 | | - is_chunk_continuation=True, |
495 | | - boundary_prev_ptr=chunk_root_ptrs, |
| 490 | + if use_simulated_decoding: |
| 491 | + # store encoder output (accumulate) |
| 492 | + if encoder_output_aggregated is None: |
| 493 | + encoder_output_aggregated = DynamicLengthTensor( |
| 494 | + batch_size=batch_size, |
| 495 | + init_length=encoder_output.shape[1], |
| 496 | + dim_shape=encoder_output.shape[2], |
| 497 | + device=device, |
| 498 | + dtype=compute_dtype, |
496 | 499 | ) |
| 500 | + encoder_output_aggregated.append_(data=encoder_output, lengths=encoder_output_len_to_decode) |
497 | 501 | else: |
498 | | - if current_batched_hyps is None: |
499 | | - current_batched_hyps = chunk_batched_hyps |
| 502 | + if not is_beam_search: |
| 503 | + # decode only chunk frames |
| 504 | + chunk_batched_hyps, state = decoding_computer( |
| 505 | + x=encoder_output, |
| 506 | + out_len=encoder_output_len_to_decode, |
| 507 | + prev_batched_state=state, |
| 508 | + multi_biasing_ids=multi_biasing_ids, |
| 509 | + ) |
| 510 | + |
| 511 | + # merge hyps with previous hyps |
| 512 | + if current_batched_hyps is None: |
| 513 | + current_batched_hyps = chunk_batched_hyps |
| 514 | + else: |
| 515 | + current_batched_hyps.merge_(chunk_batched_hyps) |
500 | 516 | else: |
501 | | - current_batched_hyps.merge_(chunk_batched_hyps) |
| 517 | + chunk_batched_hyps, state = decoding_computer( |
| 518 | + x=encoder_output, |
| 519 | + out_len=encoder_output_len_to_decode, |
| 520 | + prev_batched_state=state, |
| 521 | + ) |
| 522 | + # Flatten this chunk's prefix tree and thread the cross-chunk beam |
| 523 | + # permutation (``root_ptrs``) into the accumulator so the final |
| 524 | + # ``flatten_sort_`` walks back through the right beam history. |
| 525 | + # ``chunk_batched_hyps`` is the per-chunk BatchedBeamHyps (the |
| 526 | + # cross-chunk per-beam scalars live on ``state.beam_state`` now). |
| 527 | + chunk_root_ptrs = chunk_batched_hyps.flatten_() |
| 528 | + if current_batched_hyps is None: |
| 529 | + current_batched_hyps = chunk_batched_hyps |
| 530 | + else: |
| 531 | + current_batched_hyps.merge_( |
| 532 | + chunk_batched_hyps, |
| 533 | + is_chunk_continuation=True, |
| 534 | + boundary_prev_ptr=chunk_root_ptrs, |
| 535 | + ) |
| 536 | + |
502 | 537 |
|
503 | 538 | # move to next sample |
504 | 539 | rest_audio_lengths -= chunk_lengths_batch |
505 | 540 | left_sample = right_sample |
506 | 541 | right_sample = min(right_sample + context_samples.chunk, audio_batch.shape[1]) # add next chunk |
507 | 542 |
|
508 | | - # Convert batched hypotheses to list |
509 | | - if is_beam_search: |
510 | | - all_hyps.extend(current_batched_hyps.to_hyps_list(score_norm=True)) |
| 543 | + if use_simulated_decoding: |
| 544 | + # decode aggregated streaming encoder output |
| 545 | + if decoding_computer is not None: |
| 546 | + if not is_beam_search: |
| 547 | + current_batched_hyps, _ = decoding_computer( |
| 548 | + x=encoder_output_aggregated.data, |
| 549 | + out_len=encoder_output_aggregated.lengths, |
| 550 | + prev_batched_state=state, |
| 551 | + multi_biasing_ids=multi_biasing_ids, |
| 552 | + ) |
| 553 | + all_hyps.extend(batched_hyps_to_hypotheses(current_batched_hyps, batch_size=batch_size)) |
| 554 | + else: |
| 555 | + current_batched_hyps, _ = decoding_computer( |
| 556 | + x=encoder_output_aggregated.data, |
| 557 | + out_len=encoder_output_aggregated.lengths, |
| 558 | + prev_batched_state=state, |
| 559 | + ) |
| 560 | + all_hyps.extend(current_batched_hyps.to_hyps_list(score_norm=True)) |
| 561 | + else: |
| 562 | + # no decoding computer, fallback to `asr_model.decoding.decoding` |
| 563 | + (cur_hyps,) = asr_model.decoding.decoding( |
| 564 | + encoder_output=encoder_output_aggregated.data.transpose(1, 2), |
| 565 | + encoded_lengths=encoder_output_aggregated.lengths, |
| 566 | + ) |
| 567 | + all_hyps.extend(cur_hyps) |
511 | 568 | else: |
512 | | - # remove biasing requests from the decoder |
513 | | - if use_per_stream_biasing and audio_data.biasing_requests is not None: |
514 | | - for request in audio_data.biasing_requests: |
515 | | - if request is not None and request.multi_model_id is not None: |
516 | | - decoding_computer.biasing_multi_model.remove_model(request.multi_model_id) |
517 | | - request.multi_model_id = None |
518 | | - all_hyps.extend(batched_hyps_to_hypotheses(current_batched_hyps, batch_size=batch_size)) |
| 569 | + if not is_beam_search: |
| 570 | + all_hyps.extend(batched_hyps_to_hypotheses(current_batched_hyps, batch_size=batch_size)) |
| 571 | + else: |
| 572 | + all_hyps.extend(current_batched_hyps.to_hyps_list(score_norm=True)) |
| 573 | + |
| 574 | + # remove biasing requests from the decoder |
| 575 | + if use_per_stream_biasing and audio_data.biasing_requests is not None: |
| 576 | + for request in audio_data.biasing_requests: |
| 577 | + if request is not None and request.multi_model_id is not None: |
| 578 | + decoding_computer.biasing_multi_model.remove_model(request.multi_model_id) |
| 579 | + request.multi_model_id = None |
519 | 580 | timer.stop(device=map_location) |
520 | 581 |
|
521 | 582 | # convert text |
|
0 commit comments