1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- from typing import Dict , Optional
15+ import random
16+ from typing import Dict
1617
18+ import numpy as np
1719import torch
1820from lhotse import CutSet
19- from lhotse .dataset .collation import collate_audio
2021
2122from nemo .utils import logging
2223
@@ -28,33 +29,64 @@ class AudioCodecLhotseDataset(torch.utils.data.Dataset):
2829 It is a simple dataset that mostly just loads the audio samples.
2930 In addition, it performs the following operations:
3031 * Resampling to the target sample rate
32+ * Random truncation of each cut's `target_audio` to a fixed duration
3133 * Sanity checks on the audio
3234
3335 The operations below are handled directly by Lhotse according to the configuration
3436 applied in `AudioCodecModel._get_lhotse_dataloader()`:
35- * Duration filtering
37+ * Minimum duration filtering
3638 * Any additional transformations configured in Lhotse during its construction are
37- applied to the audio as it is loaded in `collate_audio ()`.
39+ applied to the audio as it is loaded in `load_audio ()`.
3840 """
3941
4042 def __init__ (
4143 self ,
4244 sample_rate : int ,
45+ segment_duration : float ,
4346 sanity_check_audio : bool = False ,
44- min_samples_for_sanity : Optional [int ] = None ,
4547 ):
4648 """
4749 Args:
4850 sample_rate: The sample rate to resample the audio to.
51+ segment_duration: Length of each training segment in seconds. A random
52+ segment of this length is taken from each cut's `target_audio` field
53+ (not from the parent `recording`, which may span a much longer duration).
4954 sanity_check_audio: If True, perform sanity checks on the loaded audio.
50- min_samples_for_sanity: cuts should have at least this many samples or an
51- error will be raised. Only used when
52- `sanity_check_audio` is True.
5355 """
5456 super ().__init__ ()
5557 self .sample_rate = sample_rate
58+ self .segment_duration = segment_duration
59+ self .segment_samples = int (segment_duration * sample_rate )
5660 self .sanity_check_audio = sanity_check_audio
57- self .min_samples_for_sanity = min_samples_for_sanity
61+ # Error out if audio is suspiciously short (leaving some slack for resampling).
62+ self .min_samples_for_sanity = max (1 , self .segment_samples - 5 )
63+
64+ def _load_and_truncate_target_audio (self , cut ) -> torch .Tensor :
65+ """
66+ Load `target_audio`, resample, and return a random segment of length `segment_duration`.
67+ """
68+ if not cut .has_custom ("target_audio" ):
69+ raise ValueError (f"Cut { cut .id } is missing custom field 'target_audio'" )
70+
71+ target_audio_recording = cut .target_audio .resample (self .sample_rate )
72+ # Load the target audio, resampling and applying and Lhotse transformation in the process
73+ audio = target_audio_recording .load_audio ()
74+ if audio .ndim > 1 :
75+ audio = audio .squeeze (0 )
76+
77+ num_samples = audio .shape [- 1 ]
78+ if num_samples < self .segment_samples :
79+ raise ValueError (
80+ f"target_audio is shorter than segment_duration: "
81+ f"cut_id={ cut .id } , target_audio_id={ target_audio_recording .id } , "
82+ f"num_samples={ num_samples } , required={ self .segment_samples } , "
83+ f"segment_duration={ self .segment_duration } s"
84+ )
85+
86+ # Randomly select a segment of the audio
87+ start = random .randint (0 , num_samples - self .segment_samples )
88+ segment = audio [start : start + self .segment_samples ]
89+ return torch .from_numpy (np .ascontiguousarray (segment , dtype = np .float32 ))
5890
5991 def __getitem__ (self , cuts : CutSet ) -> Dict [str , torch .Tensor ]:
6092 """
@@ -65,19 +97,15 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]:
6597 Returns:
6698 A dictionary with the `audio` and `audio_lens` tensors.
6799 """
68- # Resample the audio to the target sample rate. We need to do this manually
69- # because Lhotse only resamples its standard `recording` field automatically,
70- # not custom fields like `target_audio`.
71- for cut in cuts :
72- cut .target_audio = cut .target_audio .resample (self .sample_rate )
73-
74- # Load and collate the audio, applying any transformations that were
75- # configured in Lhotse in the process.
76- # Note: fault_tolerant=False for now to avoid masking errors until we are more
77- # confident in the new loader.
78- batch_audio , batch_audio_len = collate_audio (cuts , recording_field = "target_audio" , fault_tolerant = False )
79-
80- # Sanity checks on the audio and its length
100+ # Load, resample and truncate the audio
101+ audio_list = [self ._load_and_truncate_target_audio (cut ) for cut in cuts ]
102+ batch_audio = torch .stack (audio_list , dim = 0 )
103+ batch_audio_len = torch .full (
104+ (len (audio_list ),),
105+ self .segment_samples ,
106+ dtype = torch .int32 ,
107+ )
108+
81109 if self .sanity_check_audio :
82110 self ._sanity_check_audio (batch_audio , batch_audio_len , cuts )
83111
@@ -95,7 +123,7 @@ def _sanity_check_audio(self, audio: torch.Tensor, audio_len: torch.Tensor, cuts
95123 # --- Error cases ---
96124
97125 # Audio length is unexpectedly short
98- if self . min_samples_for_sanity is not None and audio_len .min () < self .min_samples_for_sanity :
126+ if audio_len .min () < self .min_samples_for_sanity :
99127 raise ValueError (
100128 f"Audio length is less than { self .min_samples_for_sanity } samples (min: { audio_len .min ()} )"
101129 )
0 commit comments