Skip to content

Commit 711e707

Browse files
committed
Clean up
Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
1 parent c6a6b2d commit 711e707

2 files changed

Lines changed: 13 additions & 9 deletions

File tree

examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@
8484
from nemo.collections.asr.parts.utils.streaming_utils import (
8585
AudioBatch,
8686
ContextSize,
87-
DynamicTensor,
87+
DynamicLengthTensor,
8888
SimpleAudioDataset,
8989
StreamingBatchedAudioBuffer,
9090
)
@@ -421,7 +421,7 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
421421
device=device,
422422
)
423423
rest_audio_lengths = audio_batch_lengths.clone()
424-
encoder_output_aggregated: DynamicTensor | None = None
424+
encoder_output_aggregated: DynamicLengthTensor | None = None
425425

426426
# iterate over audio samples
427427
while left_sample < audio_batch.shape[1]:
@@ -461,7 +461,7 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
461461
if use_simulated_decoding:
462462
# store encoder output
463463
if encoder_output_aggregated is None:
464-
encoder_output_aggregated = DynamicTensor(
464+
encoder_output_aggregated = DynamicLengthTensor(
465465
batch_size=batch_size,
466466
init_length=encoder_output.shape[1],
467467
dim_shape=encoder_output.shape[2],

nemo/collections/asr/parts/utils/streaming_utils.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2401,7 +2401,7 @@ def __len__(self):
24012401
return len(self.audio_filenames)
24022402

24032403

2404-
class DynamicTensor:
2404+
class DynamicLengthTensor:
24052405
def __init__(
24062406
self,
24072407
batch_size: int,
@@ -2441,19 +2441,23 @@ def _allocate_more(self, min_add_length: int | None = None):
24412441
self.data = torch.cat((self.data, self.data.new_zeros(add_shape)), dim=1)
24422442
self._max_length += add_len
24432443

2444-
def to_device(self, device: str | torch.device):
2444+
def to_device(self, device: str | torch.device) -> "DynamicLengthTensor":
2445+
"""Move storage to device"""
24452446
self.device = device
24462447
self.data.to(device=device)
24472448
self.lengths.to(device=device)
2449+
return self
24482450

24492451
def append_(self, data: torch.Tensor, lengths: torch.Tensor | None = None):
2452+
"""Append new data along length dimension"""
24502453
cur_len = self.lengths.max().item()
24512454
other_len = data.shape[1] if lengths is None else lengths.max().item()
24522455
if cur_len + other_len >= self._max_length:
24532456
self._allocate_more(min_add_length=cur_len + other_len - self._max_length + 1)
24542457
self.append_no_checks_(data=data[:, :other_len], lengths=lengths)
24552458

24562459
def append_no_checks_(self, data: torch.Tensor, lengths: torch.Tensor | None = None):
2460+
"""Append new data along length dimension without checks"""
24572461
other_len = data.shape[1]
24582462
indices = torch.arange(other_len, device=self.device)
24592463
shifted_indices = self.lengths[:, None] + indices[None, :]
@@ -2463,9 +2467,9 @@ def append_no_checks_(self, data: torch.Tensor, lengths: torch.Tensor | None = N
24632467
else:
24642468
self.lengths += lengths
24652469

2466-
def clone(self) -> "DynamicTensor":
2470+
def clone(self) -> "DynamicLengthTensor":
24672471
"""Return a copy of self"""
2468-
new_dynamic_tensor = DynamicTensor(
2472+
new_dynamic_tensor = DynamicLengthTensor(
24692473
batch_size=self.batch_size,
24702474
init_length=self._max_length,
24712475
device=self.device,
@@ -2475,13 +2479,13 @@ def clone(self) -> "DynamicTensor":
24752479
new_dynamic_tensor.data.copy_(self.lengths)
24762480
return new_dynamic_tensor
24772481

2478-
def merge_(self, other: "DynamicLengthTensor") -> "DynamicLengthTensor":
    """
    Append the contents of another dynamic tensor to this one, in place.
    NB: this will reallocate memory

    Args:
        other: DynamicLengthTensor whose ``data`` and ``lengths`` are appended to self

    Returns:
        self, after the append
    """
    incoming_data, incoming_lengths = other.data, other.lengths
    # delegate to append_, which takes care of growing the storage
    self.append_(data=incoming_data, lengths=incoming_lengths)
    return self

0 commit comments

Comments
 (0)