Skip to content

Commit c72136c

Browse files
authored
Add AudioSamples(mono_downmix=True) to handle mixed single/multi channel batches gracefully (#1563)
* Add AudioSamples(mono_downmix=True) to handle mixed single/multi channel batches gracefully * Update defaults to be non-breaking for multi-channel audio
1 parent d289860 commit c72136c

4 files changed

Lines changed: 236 additions & 6 deletions

File tree

lhotse/dataset/collation.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,12 +151,13 @@ def collate_audio(
151151
executor: Optional[Executor] = None,
152152
fault_tolerant: bool = False,
153153
recording_field: Optional[str] = None,
154+
mono_downmix: Optional[bool] = None,
154155
) -> Union[
155156
Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, CutSet]
156157
]:
157158
"""
158159
Load audio samples for all the cuts and return them as a batch in a torch tensor.
159-
The output shape is ``(batch, time)``.
160+
The output shape is ``(batch, time)`` or ``(batch, channels, time)``.
160161
The cuts will be padded with silence if necessary.
161162
162163
:param cuts: a :class:`CutSet` used to load the audio samples.
@@ -168,6 +169,13 @@ def collate_audio(
168169
where the third element is a CutSet for which the audio data were successfully read.
169170
:param recording_field: when specified, we will try to load recordings from a custom field with this name
170171
(i.e., ``cut.load_<recording_field>()`` instead of default ``cut.load_audio()``).
172+
:param mono_downmix: controls channel handling.
173+
``None`` (default): auto-detect — uses downmix semantics unless every cut in the batch
174+
is multichannel, in which case multichannel collation is used.
175+
``True``: multichannel audio is downmixed to mono by averaging channels; output shape
176+
is ``(batch, time)``.
177+
``False``: mono audio is placed in channel 0 with remaining channels zero-padded to
178+
match the batch maximum; output shape is ``(batch, channels, time)``.
171179
:return: a tuple of tensors ``(audio, audio_lens)``, or ``(audio, audio_lens, cuts)``.
172180
"""
173181
for cut in cuts:
@@ -204,11 +212,32 @@ def collate_audio(
204212
filter_aux_iter=sample_counts,
205213
)
206214

207-
if len(audios[0].shape) == 1:
208-
audios = collate_vectors(audios, padding_value=0.0)
215+
if mono_downmix is None:
216+
# Auto-detect: use False semantics only when every audio is multichannel
217+
mono_downmix = not all(a.dim() == 2 for a in audios)
218+
219+
if mono_downmix:
220+
# Downmix multichannel audio to mono by averaging channels
221+
processed = []
222+
for audio in audios:
223+
if audio.dim() == 2:
224+
audio = audio.mean(dim=0) # (channels, time) -> (time,)
225+
processed.append(audio)
226+
audios = collate_vectors(processed, padding_value=0.0)
209227
else:
228+
# Expand mono audio to match max channels in batch, then collate as multichannel
229+
max_channels = max(
230+
audio.shape[0] if audio.dim() == 2 else 1 for audio in audios
231+
)
232+
processed = []
233+
for audio in audios:
234+
if audio.dim() == 1:
235+
expanded = audio.new_zeros(max_channels, audio.shape[0])
236+
expanded[0] = audio
237+
audio = expanded
238+
processed.append(audio)
210239
audios = collate_matrices(
211-
[a.transpose(0, 1) for a in audios], padding_value=0.0
240+
[a.transpose(0, 1) for a in processed], padding_value=0.0
212241
).transpose(1, 2)
213242
audio_lens = torch.tensor(sample_counts, dtype=torch.int32)
214243

lhotse/dataset/input_strategies.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ def __init__(
223223
fault_tolerant: bool = False,
224224
executor_type: Type[ExecutorType] = ThreadPoolExecutor,
225225
use_batch_loader: bool = False,
226+
mono_downmix: Optional[bool] = None,
226227
) -> None:
227228
"""
228229
AudioSamples constructor.
@@ -240,9 +241,14 @@ def __init__(
240241
:param use_batch_loader: When ``True``, enables batch loading of audio data from AIStore.
241242
This allows all audio samples in the batch to be fetched in a single request for increased efficiency.
242243
Requires the input CutSet to be eager (not lazy).
244+
:param mono_downmix: controls channel handling (passed to :func:`collate_audio`).
245+
``None`` (default): auto-detect — downmix unless every cut is multichannel.
246+
``True``: always downmix to mono; output shape is ``(B, T)``.
247+
``False``: expand mono to channel 0 with zero-padded channels; output shape is ``(B, C, T)``.
243248
"""
244249
super().__init__(num_workers=num_workers, executor_type=executor_type)
245250
self.fault_tolerant = fault_tolerant
251+
self.mono_downmix = mono_downmix
246252
self.ais_batch_loader = None
247253
self.use_batch_loader = use_batch_loader
248254
if self.use_batch_loader:
@@ -277,6 +283,7 @@ def __call__(
277283
executor=_get_executor(self.num_workers, executor_type=self._executor_type),
278284
fault_tolerant=self.fault_tolerant,
279285
recording_field=recording_field,
286+
mono_downmix=self.mono_downmix,
280287
)
281288

282289
def supervision_intervals(self, cuts: CutSet) -> Dict[str, torch.Tensor]:

test/dataset/test_batch_io.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44

55
import numpy as np
66
import pytest
7+
import torch
78
import torch.testing
89

910
from lhotse import CutSet, Fbank, MonoCut
1011
from lhotse.dataset import AudioSamples, OnTheFlyFeatures, PrecomputedFeatures
12+
from lhotse.testing.dummies import dummy_cut, dummy_multi_cut
1113

1214

1315
@pytest.fixture
@@ -79,3 +81,77 @@ def test_audio_samples_equivalent_to_cut_set_load_audio(libri_cut_set):
7981
def test_cut_set_load_audio_collate_false(libri_cut_set):
8082
audio = libri_cut_set.load_audio()
8183
assert isinstance(audio, list)
84+
85+
86+
def test_audio_samples_mono_downmix_none_mono_only():
87+
# None + all-mono -> True semantics -> (B, T)
88+
cuts = CutSet(
89+
[
90+
dummy_cut(0, duration=1.0, with_data=True),
91+
dummy_cut(1, duration=1.0, with_data=True),
92+
]
93+
)
94+
audio, _ = AudioSamples(mono_downmix=None)(cuts)
95+
assert audio.shape == (2, 16000)
96+
97+
98+
def test_audio_samples_mono_downmix_none_multi_only():
99+
# None + all-multi -> False semantics -> (B, C, T)
100+
cuts = CutSet(
101+
[
102+
dummy_multi_cut(0, channel=[0, 1], with_data=True),
103+
dummy_multi_cut(1, channel=[0, 1], with_data=True),
104+
]
105+
)
106+
audio, _ = AudioSamples(mono_downmix=None)(cuts)
107+
assert audio.shape == (2, 2, 16000)
108+
109+
110+
def test_audio_samples_mono_downmix_none_mixed():
111+
# None + mixed -> True semantics -> (B, T)
112+
cuts = CutSet(
113+
[
114+
dummy_cut(0, duration=1.0, with_data=True),
115+
dummy_multi_cut(1, channel=[0, 1], with_data=True),
116+
]
117+
)
118+
audio, _ = AudioSamples(mono_downmix=None)(cuts)
119+
assert audio.shape == (2, 16000)
120+
121+
122+
def test_audio_samples_mono_downmix_true_multichannel():
123+
# Multichannel batch downmixed to mono -> (B, T)
124+
cuts = CutSet(
125+
[
126+
dummy_multi_cut(0, channel=[0, 1], with_data=True),
127+
dummy_multi_cut(1, channel=[0, 1], with_data=True),
128+
]
129+
)
130+
batchio = AudioSamples(mono_downmix=True)
131+
audio, audio_lens = batchio(cuts)
132+
assert audio.shape == (2, 16000)
133+
134+
135+
def test_audio_samples_mono_downmix_false_mono_batch():
136+
# Mono batch with mono_downmix=False -> (B, 1, T)
137+
cuts = CutSet(
138+
[
139+
dummy_cut(0, duration=1.0, with_data=True),
140+
dummy_cut(1, duration=1.0, with_data=True),
141+
]
142+
)
143+
batchio = AudioSamples(mono_downmix=False)
144+
audio, audio_lens = batchio(cuts)
145+
assert audio.shape == (2, 1, 16000)
146+
147+
148+
def test_audio_samples_mono_downmix_false_mixed_batch():
149+
# Mixed batch: mono placed in ch0 with zeros in ch1 -> (B, 2, T)
150+
cut_mono = dummy_cut(0, duration=1.0, with_data=True)
151+
cut_multi = dummy_multi_cut(1, channel=[0, 1], with_data=True)
152+
cuts = CutSet([cut_mono, cut_multi])
153+
batchio = AudioSamples(mono_downmix=False)
154+
audio, audio_lens = batchio(cuts)
155+
assert audio.shape == (2, 2, 16000)
156+
# Mono cut's channel 1 must be all zeros
157+
assert audio[0, 1, :].eq(0).all()

test/dataset/test_collation.py

Lines changed: 120 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -395,12 +395,12 @@ def test_collate_cut_multi_channel_recording_and_custom_recording_diff_num_chann
395395

396396
expected_lens = torch.tensor([16000, 32000], dtype=torch.int32)
397397

398-
audio, audio_lens = collate_audio(cuts)
398+
audio, audio_lens = collate_audio(cuts, mono_downmix=False)
399399
assert audio.shape == (2, 4, 32000) # batch x channel x time
400400
torch.testing.assert_close(audio_lens, expected_lens)
401401

402402
target_audio, target_audio_lens = collate_audio(
403-
cuts, recording_field="target_recording"
403+
cuts, recording_field="target_recording", mono_downmix=False
404404
)
405405
assert target_audio.shape == (2, 2, 32000) # batch x channel x time
406406
torch.testing.assert_close(audio_lens, expected_lens)
@@ -442,3 +442,121 @@ def test_collate_custom_audio_works_despite_non_unique_ids():
442442
audio, audio_lens = collate_audio(cuts, recording_field="custom_recording")
443443
assert audio_lens.tolist() == [32000, 16000]
444444
assert audio.shape == (2, 32000)
445+
446+
447+
# --- mono_downmix tests ---
448+
449+
450+
def test_collate_audio_mono_downmix_true_all_mono():
451+
# mono_downmix=True on an all-mono batch: output stays (B, T)
452+
cuts = CutSet(
453+
[
454+
dummy_cut(0, duration=2.0, with_data=True),
455+
dummy_cut(1, duration=1.0, with_data=True),
456+
]
457+
)
458+
audio, audio_lens = collate_audio(cuts, mono_downmix=True)
459+
assert audio.shape == (2, 32000)
460+
assert audio_lens.tolist() == [32000, 16000]
461+
462+
463+
def test_collate_audio_mono_downmix_true_all_multichannel():
464+
# Multichannel audio should be downmixed to mono -> (B, T)
465+
cuts = CutSet(
466+
[
467+
dummy_multi_cut(0, channel=[0, 1, 2], with_data=True),
468+
dummy_multi_cut(1, channel=[0, 1, 2], with_data=True),
469+
]
470+
)
471+
audio, audio_lens = collate_audio(cuts, mono_downmix=True)
472+
assert audio.shape == (2, 16000)
473+
assert audio_lens.tolist() == [16000, 16000]
474+
475+
476+
def test_collate_audio_mono_downmix_true_mixed_batch():
477+
# Mixed batch: one mono, one multichannel -> downmix all to (B, T)
478+
cut_mono = dummy_cut(0, duration=1.0, with_data=True)
479+
cut_multi = dummy_multi_cut(1, channel=[0, 1], with_data=True)
480+
cuts = CutSet([cut_mono, cut_multi])
481+
audio, audio_lens = collate_audio(cuts, mono_downmix=True)
482+
assert audio.shape == (2, 16000)
483+
assert audio_lens.tolist() == [16000, 16000]
484+
485+
486+
def test_collate_audio_mono_downmix_false_all_mono():
487+
# Mono-only batch: expand to (B, 1, T)
488+
cuts = CutSet(
489+
[
490+
dummy_cut(0, duration=2.0, with_data=True),
491+
dummy_cut(1, duration=1.0, with_data=True),
492+
]
493+
)
494+
audio, audio_lens = collate_audio(cuts, mono_downmix=False)
495+
assert audio.shape == (2, 1, 32000)
496+
assert audio_lens.tolist() == [32000, 16000]
497+
498+
499+
def test_collate_audio_mono_downmix_false_all_multichannel():
500+
# All-multichannel batch collates as (B, C, T), same as existing behavior
501+
cuts = CutSet(
502+
[
503+
dummy_multi_cut(0, duration=2.0, channel=[0, 1], with_data=True),
504+
dummy_multi_cut(1, duration=1.0, channel=[0, 1], with_data=True),
505+
]
506+
)
507+
audio, audio_lens = collate_audio(cuts, mono_downmix=False)
508+
assert audio.shape == (2, 2, 32000)
509+
assert audio_lens.tolist() == [32000, 16000]
510+
511+
512+
def test_collate_audio_mono_downmix_false_mixed_batch():
513+
# Mixed batch: mono is expanded to match max channels -> (B, C, T)
514+
cut_mono = dummy_cut(0, duration=1.0, with_data=True)
515+
cut_multi = dummy_multi_cut(1, channel=[0, 1], with_data=True)
516+
cuts = CutSet([cut_mono, cut_multi])
517+
audio, audio_lens = collate_audio(cuts, mono_downmix=False)
518+
assert audio.shape == (2, 2, 16000)
519+
assert audio_lens.tolist() == [16000, 16000]
520+
521+
522+
def test_collate_audio_mono_downmix_none_mono_only():
523+
# None + all-mono -> True semantics -> (B, T)
524+
cuts = CutSet(
525+
[
526+
dummy_cut(0, duration=1.0, with_data=True),
527+
dummy_cut(1, duration=1.0, with_data=True),
528+
]
529+
)
530+
audio, audio_lens = collate_audio(cuts, mono_downmix=None)
531+
assert audio.shape == (2, 16000)
532+
533+
534+
def test_collate_audio_mono_downmix_none_mixed():
535+
# None + mixed mono+multi -> True semantics -> (B, T)
536+
cut_mono = dummy_cut(0, duration=1.0, with_data=True)
537+
cut_multi = dummy_multi_cut(1, channel=[0, 1], with_data=True)
538+
cuts = CutSet([cut_mono, cut_multi])
539+
audio, audio_lens = collate_audio(cuts, mono_downmix=None)
540+
assert audio.shape == (2, 16000)
541+
542+
543+
def test_collate_audio_mono_downmix_none_multi_only():
544+
# None + all-multi -> False semantics -> (B, C, T)
545+
cuts = CutSet(
546+
[
547+
dummy_multi_cut(0, channel=[0, 1], with_data=True),
548+
dummy_multi_cut(1, channel=[0, 1], with_data=True),
549+
]
550+
)
551+
audio, audio_lens = collate_audio(cuts, mono_downmix=None)
552+
assert audio.shape == (2, 2, 16000)
553+
554+
555+
def test_collate_audio_mono_downmix_false_mono_zero_padded_channels():
556+
# Mono is placed in channel 0; remaining channels are zero
557+
cut_mono = dummy_cut(0, duration=1.0, with_data=True)
558+
cut_multi = dummy_multi_cut(1, channel=[0, 1], with_data=True)
559+
cuts = CutSet([cut_mono, cut_multi])
560+
audio, _ = collate_audio(cuts, mono_downmix=False)
561+
# audio[0] is from the mono cut: channel 1 must be all zeros
562+
assert audio[0, 1, :].eq(0).all()

0 commit comments

Comments
 (0)