Skip to content

Commit c0eece3

Browse files
minor fix
Signed-off-by: lilithgrigoryan <lgrigoryan@nvidia.com>
1 parent 5c750ef commit c0eece3

1 file changed

Lines changed: 7 additions & 3 deletions

File tree

nemo/collections/asr/parts/submodules/tdt_malsd_batched_computer.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1493,8 +1493,8 @@ def _create_decoding_state(
14931493
"""Create BatchedBeamState for the next chunk."""
14941494
current_batch_size = encoder_output_length.shape[0]
14951495

1496-
# Get last labels from batched_hyps
1497-
last_labels = self.state.batched_hyps.get_last_labels(pad_id=self._SOS)
1496+
# Get last labels and slice to real batch (graph state's batch dim is sized to capture-time max).
1497+
last_labels = self.state.batched_hyps.get_last_labels(pad_id=self._SOS)[:current_batch_size]
14981498

14991499
# Snapshot per-beam pending skip (TDT ``time_jumps`` analogue) BEFORE resetting
15001500
# ``next_timestamp``; clamp at 0 for padded/never-emitted beams.
@@ -1514,7 +1514,11 @@ def _create_decoding_state(
15141514

15151515
# Handle labels - if nothing decoded this chunk, use previous labels
15161516
if prev_batched_state is not None:
1517-
last_labels = torch.where(last_labels == self._SOS, prev_batched_state.labels, last_labels)
1517+
last_labels = torch.where(
1518+
last_labels == self._SOS,
1519+
prev_batched_state.labels[:current_batch_size],
1520+
last_labels,
1521+
)
15181522

15191523
# Get fusion states if present
15201524
fusion_states_list = None

0 commit comments

Comments
 (0)