@@ -2401,7 +2401,7 @@ def __len__(self):
24012401 return len (self .audio_filenames )
24022402
24032403
2404- class DynamicTensor :
2404+ class DynamicLengthTensor :
24052405 def __init__ (
24062406 self ,
24072407 batch_size : int ,
@@ -2441,19 +2441,23 @@ def _allocate_more(self, min_add_length: int | None = None):
24412441 self .data = torch .cat ((self .data , self .data .new_zeros (add_shape )), dim = 1 )
24422442 self ._max_length += add_len
24432443
2444- def to_device (self , device : str | torch .device ):
2444+ def to_device (self , device : str | torch .device ) -> "DynamicLengthTensor" :
2445+ """Move storage to device"""
24452446 self .device = device
24462447 self .data .to (device = device )
24472448 self .lengths .to (device = device )
2449+ return self
24482450
24492451 def append_ (self , data : torch .Tensor , lengths : torch .Tensor | None = None ):
2452+ """Append new data along length dimension"""
24502453 cur_len = self .lengths .max ().item ()
24512454 other_len = data .shape [1 ] if lengths is None else lengths .max ().item ()
24522455 if cur_len + other_len >= self ._max_length :
24532456 self ._allocate_more (min_add_length = cur_len + other_len - self ._max_length + 1 )
24542457 self .append_no_checks_ (data = data [:, :other_len ], lengths = lengths )
24552458
24562459 def append_no_checks_ (self , data : torch .Tensor , lengths : torch .Tensor | None = None ):
2460+ """Append new data along length dimension without checks"""
24572461 other_len = data .shape [1 ]
24582462 indices = torch .arange (other_len , device = self .device )
24592463 shifted_indices = self .lengths [:, None ] + indices [None , :]
@@ -2463,9 +2467,9 @@ def append_no_checks_(self, data: torch.Tensor, lengths: torch.Tensor | None = N
24632467 else :
24642468 self .lengths += lengths
24652469
2466- def clone (self ) -> "DynamicTensor " :
2470+ def clone (self ) -> "DynamicLengthTensor " :
24672471 """Return a copy of self"""
2468- new_dynamic_tensor = DynamicTensor (
2472+ new_dynamic_tensor = DynamicLengthTensor (
24692473 batch_size = self .batch_size ,
24702474 init_length = self ._max_length ,
24712475 device = self .device ,
@@ -2475,13 +2479,13 @@ def clone(self) -> "DynamicTensor":
24752479 new_dynamic_tensor .data .copy_ (self .lengths )
24762480 return new_dynamic_tensor
24772481
2478- def merge_ (self , other : "DynamicTensor " ) -> "DynamicTensor " :
2482+ def merge_ (self , other : "DynamicLengthTensor " ) -> "DynamicLengthTensor " :
24792483 """
24802484 Merge two dynamic tensors
24812485 NB: this will reallocate memory
24822486
24832487 Args:
2484- other: DynamicTensor
2488+ other: DynamicLengthTensor
24852489 """
24862490 self .append_ (data = other .data , lengths = other .lengths )
24872491 return self
0 commit comments