Skip to content

Commit 907edfd

Browse files
authored
[Audio Codec] Lhotse data loading updates and fixes (#15742)
1. Move the random-segment-selection functionality from Lhotse to our dataset class, AudioCodecLhotseDataset. The corresponding built-in Lhotse functionality (truncate_duration) operates on the parent recording, which is not what we want. 2. Switch from batch_duration to batch_size for specifying the training batch size. In our setting, they are equivalent since the item size is fixed for all batch items, and it's clearer this way, now that segment selection is happening in the Dataset class.
1 parent 28a723f commit 907edfd

3 files changed

Lines changed: 108 additions & 39 deletions

File tree

nemo/collections/tts/data/audio_codec_dataset_lhotse.py

Lines changed: 51 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Dict, Optional
15+
import random
16+
from typing import Dict
1617

18+
import numpy as np
1719
import torch
1820
from lhotse import CutSet
19-
from lhotse.dataset.collation import collate_audio
2021

2122
from nemo.utils import logging
2223

@@ -28,33 +29,64 @@ class AudioCodecLhotseDataset(torch.utils.data.Dataset):
2829
It is a simple dataset that mostly just loads the audio samples.
2930
In addition, it performs the following operations:
3031
* Resampling to the target sample rate
32+
* Random truncation of each cut's `target_audio` to a fixed duration
3133
* Sanity checks on the audio
3234
3335
The operations below are handled directly by Lhotse according to the configuration
3436
applied in `AudioCodecModel._get_lhotse_dataloader()`:
35-
* Duration filtering
37+
* Minimum duration filtering
3638
* Any additional transformations configured in Lhotse during its construction are
37-
applied to the audio as it is loaded in `collate_audio()`.
39+
applied to the audio as it is loaded in `load_audio()`.
3840
"""
3941

4042
def __init__(
4143
self,
4244
sample_rate: int,
45+
segment_duration: float,
4346
sanity_check_audio: bool = False,
44-
min_samples_for_sanity: Optional[int] = None,
4547
):
4648
"""
4749
Args:
4850
sample_rate: The sample rate to resample the audio to.
51+
segment_duration: Length of each training segment in seconds. A random
52+
segment of this length is taken from each cut's `target_audio` field
53+
(not from the parent `recording`, which may span a much longer duration).
4954
sanity_check_audio: If True, perform sanity checks on the loaded audio.
50-
min_samples_for_sanity: cuts should have at least this many samples or an
51-
error will be raised. Only used when
52-
`sanity_check_audio` is True.
5355
"""
5456
super().__init__()
5557
self.sample_rate = sample_rate
58+
self.segment_duration = segment_duration
59+
self.segment_samples = int(segment_duration * sample_rate)
5660
self.sanity_check_audio = sanity_check_audio
57-
self.min_samples_for_sanity = min_samples_for_sanity
61+
# Error out if audio is suspiciously short (leaving some slack for resampling).
62+
self.min_samples_for_sanity = max(1, self.segment_samples - 5)
63+
64+
def _load_and_truncate_target_audio(self, cut) -> torch.Tensor:
65+
"""
66+
Load `target_audio`, resample, and return a random segment of length `segment_duration`.
67+
"""
68+
if not cut.has_custom("target_audio"):
69+
raise ValueError(f"Cut {cut.id} is missing custom field 'target_audio'")
70+
71+
target_audio_recording = cut.target_audio.resample(self.sample_rate)
72+
# Load the target audio, resampling and applying and Lhotse transformation in the process
73+
audio = target_audio_recording.load_audio()
74+
if audio.ndim > 1:
75+
audio = audio.squeeze(0)
76+
77+
num_samples = audio.shape[-1]
78+
if num_samples < self.segment_samples:
79+
raise ValueError(
80+
f"target_audio is shorter than segment_duration: "
81+
f"cut_id={cut.id}, target_audio_id={target_audio_recording.id}, "
82+
f"num_samples={num_samples}, required={self.segment_samples}, "
83+
f"segment_duration={self.segment_duration}s"
84+
)
85+
86+
# Randomly select a segment of the audio
87+
start = random.randint(0, num_samples - self.segment_samples)
88+
segment = audio[start : start + self.segment_samples]
89+
return torch.from_numpy(np.ascontiguousarray(segment, dtype=np.float32))
5890

5991
def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]:
6092
"""
@@ -65,19 +97,15 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]:
6597
Returns:
6698
A dictionary with the `audio` and `audio_lens` tensors.
6799
"""
68-
# Resample the audio to the target sample rate. We need to do this manually
69-
# because Lhotse only resamples its standard `recording` field automatically,
70-
# not custom fields like `target_audio`.
71-
for cut in cuts:
72-
cut.target_audio = cut.target_audio.resample(self.sample_rate)
73-
74-
# Load and collate the audio, applying any transformations that were
75-
# configured in Lhotse in the process.
76-
# Note: fault_tolerant=False for now to avoid masking errors until we are more
77-
# confident in the new loader.
78-
batch_audio, batch_audio_len = collate_audio(cuts, recording_field="target_audio", fault_tolerant=False)
79-
80-
# Sanity checks on the audio and its length
100+
# Load, resample and truncate the audio
101+
audio_list = [self._load_and_truncate_target_audio(cut) for cut in cuts]
102+
batch_audio = torch.stack(audio_list, dim=0)
103+
batch_audio_len = torch.full(
104+
(len(audio_list),),
105+
self.segment_samples,
106+
dtype=torch.int32,
107+
)
108+
81109
if self.sanity_check_audio:
82110
self._sanity_check_audio(batch_audio, batch_audio_len, cuts)
83111

@@ -95,7 +123,7 @@ def _sanity_check_audio(self, audio: torch.Tensor, audio_len: torch.Tensor, cuts
95123
# --- Error cases ---
96124

97125
# Audio length is unexpectedly short
98-
if self.min_samples_for_sanity is not None and audio_len.min() < self.min_samples_for_sanity:
126+
if audio_len.min() < self.min_samples_for_sanity:
99127
raise ValueError(
100128
f"Audio length is less than {self.min_samples_for_sanity} samples (min: {audio_len.min()})"
101129
)

nemo/collections/tts/models/audio_codec.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -854,23 +854,35 @@ def _get_lhotse_dataloader(self, cfg):
854854
# manually in the dataset class.
855855
loader_cfg.sample_rate = self.output_sample_rate
856856

857-
# Set up cut truncation, filtering, and random selection:
858-
# `truncate_duration` and `truncate_offset_type` are interpreted by Lhotse.
859-
# Together, they configure Lhotse to choose a random segment of this length
860-
# from each cut.
861-
if loader_cfg.truncate_duration is None:
862-
raise ValueError("`truncate_duration` must be set in the config")
863-
loader_cfg.truncate_offset_type = "random"
864-
# Also filter examples to be at least this long to avoid zero-padding
865-
loader_cfg.min_duration = loader_cfg.truncate_duration
857+
# Random segment selection is done in AudioCodecLhotseDataset on `target_audio`, not via
858+
# Lhotse's `truncate_duration` config (which operates on the parent recording).
859+
if cfg.dataloader_params.get("truncate_duration") is not None:
860+
raise ValueError(
861+
"`truncate_duration` must not be set in `train_ds.dataloader_params`; "
862+
"segment extraction is handled in `AudioCodecLhotseDataset` via `segment_duration`."
863+
)
864+
segment_duration = dataset_args.get("segment_duration")
865+
if segment_duration is None:
866+
raise ValueError("`segment_duration` must be set in `train_ds.dataset_args` ")
867+
existing_min_duration = cfg.dataloader_params.get("min_duration")
868+
if existing_min_duration is not None and existing_min_duration != -1:
869+
raise ValueError(
870+
"`min_duration` must not be set in `train_ds.dataloader_params`; "
871+
"it is set automatically from `train_ds.dataset_args.segment_duration`."
872+
)
873+
# Pre-filter to only include cuts whose parent recording is at least as long as
874+
# the training segment duration so the dataset class has enough samples to choose from.
875+
loader_cfg.min_duration = segment_duration
876+
877+
# Make sure batch_size is set
878+
if loader_cfg.batch_size is None:
879+
raise ValueError("`batch_size` must be set in `train_ds.dataloader_params`.")
866880

867881
# --- Create the dataset ---
868882

869-
# Error out if the audio is suspiciously short (half the expected length)
870-
min_samples_for_sanity = loader_cfg.truncate_duration * self.output_sample_rate // 2
871-
# Create the dataset
872883
dataset = AudioCodecLhotseDataset(
873-
sample_rate=self.output_sample_rate, min_samples_for_sanity=min_samples_for_sanity, **dataset_args
884+
sample_rate=self.output_sample_rate,
885+
**dataset_args,
874886
)
875887

876888
# Create the dataloader

tests/collections/tts/data/test_audio_codec_dataset_lhotse.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,16 +88,16 @@ def cutset(tmp_path) -> CutSet:
8888
def dataset() -> AudioCodecLhotseDataset:
8989
return AudioCodecLhotseDataset(
9090
sample_rate=TARGET_SAMPLE_RATE,
91-
min_samples_for_sanity=DEFAULT_DURATION * TARGET_SAMPLE_RATE,
91+
segment_duration=DEFAULT_DURATION,
9292
)
9393

9494

9595
class TestAudioCodecLhotseDataset:
9696
@pytest.mark.unit
9797
def test_init(self):
98-
ds = AudioCodecLhotseDataset(sample_rate=22050, min_samples_for_sanity=512)
98+
ds = AudioCodecLhotseDataset(sample_rate=22050, segment_duration=1.0)
9999
assert ds.sample_rate == 22050
100-
assert ds.min_samples_for_sanity == 512
100+
assert ds.min_samples_for_sanity == 22050 - 5
101101

102102
@pytest.mark.unit
103103
def test_getitem_returns_expected_keys_and_shapes(self, dataset, cutset):
@@ -137,3 +137,32 @@ def test_getitem_resampling_preserves_frequency(self, dataset, cutset):
137137
# FFT bin width is TARGET_SAMPLE_RATE / n; allow ~1 bin of tolerance.
138138
bin_width_hz = TARGET_SAMPLE_RATE / n
139139
assert abs(peak_freq_hz - cutset[i].target_tone_frequency) <= bin_width_hz
140+
141+
@pytest.mark.unit
142+
def test_getitem_extracts_subset_of_longer_audio(self, tmp_path, dataset, monkeypatch):
143+
# A cut longer than segment_duration should yield a segment of exactly segment_samples
144+
# that is a contiguous slice taken from inside the longer source signal.
145+
# Use the target sample rate as the source rate so no resampling is involved.
146+
cut = _make_cut(tmp_path, "long", duration=3.0, sample_rate=TARGET_SAMPLE_RATE, tone_frequency=440.0)
147+
cuts = CutSet.from_cuts([cut])
148+
149+
# Load the full target audio the same way the dataset does (no resampling needed).
150+
full = cut.target_audio.load_audio().squeeze(0)
151+
segment_samples = int(DEFAULT_DURATION * TARGET_SAMPLE_RATE)
152+
153+
# Pin the random start so we can compare against the exact source slice.
154+
fixed_start = 7000
155+
monkeypatch.setattr(
156+
"nemo.collections.tts.data.audio_codec_dataset_lhotse.random.randint",
157+
lambda low, high: fixed_start,
158+
)
159+
160+
batch = dataset[cuts]
161+
segment = batch["audio"][0].numpy()
162+
163+
assert segment.shape == (segment_samples,)
164+
assert batch["audio_lens"][0].item() == segment_samples
165+
# Exact match holds only because source rate == target rate (no resampling) and the
166+
# dataset currently applies no augmentation. Once we add augmentation it makes
167+
# sense to remove this assertion.
168+
assert np.allclose(segment, full[fixed_start : fixed_start + segment_samples])

0 commit comments

Comments
 (0)