Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 33 additions & 4 deletions lhotse/dataset/collation.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,13 @@ def collate_audio(
executor: Optional[Executor] = None,
fault_tolerant: bool = False,
recording_field: Optional[str] = None,
mono_downmix: Optional[bool] = None,
) -> Union[
Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, CutSet]
]:
"""
Load audio samples for all the cuts and return them as a batch in a torch tensor.
The output shape is ``(batch, time)``.
The output shape is ``(batch, time)`` or ``(batch, channels, time)``.
The cuts will be padded with silence if necessary.

:param cuts: a :class:`CutSet` used to load the audio samples.
Expand All @@ -168,6 +169,13 @@ def collate_audio(
where the third element is a CutSet for which the audio data were sucessfully read.
:param recording_field: when specified, we will try to load recordings from a custom field with this name
(i.e., ``cut.load_<recording_field>()`` instead of default ``cut.load_audio()``).
:param mono_downmix: controls channel handling.
``None`` (default): auto-detect — uses downmix semantics unless every cut in the batch
is multichannel, in which case multichannel collation is used.
``True``: multichannel audio is downmixed to mono by averaging channels; output shape
is ``(batch, time)``.
``False``: mono audio is placed in channel 0 with remaining channels zero-padded to
match the batch maximum; output shape is ``(batch, channels, time)``.
:return: a tuple of tensors ``(audio, audio_lens)``, or ``(audio, audio_lens, cuts)``.
"""
for cut in cuts:
Expand Down Expand Up @@ -204,11 +212,32 @@ def collate_audio(
filter_aux_iter=sample_counts,
)

if len(audios[0].shape) == 1:
audios = collate_vectors(audios, padding_value=0.0)
if mono_downmix is None:
# Auto-detect: use False semantics only when every audio is multichannel
mono_downmix = not all(a.dim() == 2 for a in audios)

if mono_downmix:
# Downmix multichannel audio to mono by averaging channels
processed = []
for audio in audios:
if audio.dim() == 2:
audio = audio.mean(dim=0) # (channels, time) -> (time,)
processed.append(audio)
audios = collate_vectors(processed, padding_value=0.0)
else:
# Expand mono audio to match max channels in batch, then collate as multichannel
max_channels = max(
audio.shape[0] if audio.dim() == 2 else 1 for audio in audios
)
processed = []
for audio in audios:
if audio.dim() == 1:
expanded = audio.new_zeros(max_channels, audio.shape[0])
expanded[0] = audio
audio = expanded
processed.append(audio)
audios = collate_matrices(
[a.transpose(0, 1) for a in audios], padding_value=0.0
[a.transpose(0, 1) for a in processed], padding_value=0.0
).transpose(1, 2)
audio_lens = torch.tensor(sample_counts, dtype=torch.int32)

Expand Down
7 changes: 7 additions & 0 deletions lhotse/dataset/input_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ def __init__(
fault_tolerant: bool = False,
executor_type: Type[ExecutorType] = ThreadPoolExecutor,
use_batch_loader: bool = False,
mono_downmix: Optional[bool] = None,
) -> None:
"""
AudioSamples constructor.
Expand All @@ -240,9 +241,14 @@ def __init__(
:param use_batch_loader: When ``True``, enables batch loading of audio data from AIStore.
This allows all audio samples in the batch to be fetched in a single request for increased efficiency.
Requires the input CutSet to be eager (not lazy).
:param mono_downmix: controls channel handling (passed to :func:`collate_audio`).
``None`` (default): auto-detect — downmix unless every cut is multichannel.
``True``: always downmix to mono; output shape is ``(B, T)``.
``False``: expand mono to channel 0 with zero-padded channels; output shape is ``(B, C, T)``.
"""
super().__init__(num_workers=num_workers, executor_type=executor_type)
self.fault_tolerant = fault_tolerant
self.mono_downmix = mono_downmix
self.ais_batch_loader = None
self.use_batch_loader = use_batch_loader
if self.use_batch_loader:
Expand Down Expand Up @@ -277,6 +283,7 @@ def __call__(
executor=_get_executor(self.num_workers, executor_type=self._executor_type),
fault_tolerant=self.fault_tolerant,
recording_field=recording_field,
mono_downmix=self.mono_downmix,
)

def supervision_intervals(self, cuts: CutSet) -> Dict[str, torch.Tensor]:
Expand Down
76 changes: 76 additions & 0 deletions test/dataset/test_batch_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@

import numpy as np
import pytest
import torch
import torch.testing

from lhotse import CutSet, Fbank, MonoCut
from lhotse.dataset import AudioSamples, OnTheFlyFeatures, PrecomputedFeatures
from lhotse.testing.dummies import dummy_cut, dummy_multi_cut


@pytest.fixture
Expand Down Expand Up @@ -79,3 +81,77 @@ def test_audio_samples_equivalent_to_cut_set_load_audio(libri_cut_set):
def test_cut_set_load_audio_collate_false(libri_cut_set):
audio = libri_cut_set.load_audio()
assert isinstance(audio, list)


def test_audio_samples_mono_downmix_none_mono_only():
# None + all-mono -> True semantics -> (B, T)
cuts = CutSet(
[
dummy_cut(0, duration=1.0, with_data=True),
dummy_cut(1, duration=1.0, with_data=True),
]
)
audio, _ = AudioSamples(mono_downmix=None)(cuts)
assert audio.shape == (2, 16000)


def test_audio_samples_mono_downmix_none_multi_only():
# None + all-multi -> False semantics -> (B, C, T)
cuts = CutSet(
[
dummy_multi_cut(0, channel=[0, 1], with_data=True),
dummy_multi_cut(1, channel=[0, 1], with_data=True),
]
)
audio, _ = AudioSamples(mono_downmix=None)(cuts)
assert audio.shape == (2, 2, 16000)


def test_audio_samples_mono_downmix_none_mixed():
# None + mixed -> True semantics -> (B, T)
cuts = CutSet(
[
dummy_cut(0, duration=1.0, with_data=True),
dummy_multi_cut(1, channel=[0, 1], with_data=True),
]
)
audio, _ = AudioSamples(mono_downmix=None)(cuts)
assert audio.shape == (2, 16000)


def test_audio_samples_mono_downmix_true_multichannel():
# Multichannel batch downmixed to mono -> (B, T)
cuts = CutSet(
[
dummy_multi_cut(0, channel=[0, 1], with_data=True),
dummy_multi_cut(1, channel=[0, 1], with_data=True),
]
)
batchio = AudioSamples(mono_downmix=True)
audio, audio_lens = batchio(cuts)
assert audio.shape == (2, 16000)


def test_audio_samples_mono_downmix_false_mono_batch():
# Mono batch with mono_downmix=False -> (B, 1, T)
cuts = CutSet(
[
dummy_cut(0, duration=1.0, with_data=True),
dummy_cut(1, duration=1.0, with_data=True),
]
)
batchio = AudioSamples(mono_downmix=False)
audio, audio_lens = batchio(cuts)
assert audio.shape == (2, 1, 16000)


def test_audio_samples_mono_downmix_false_mixed_batch():
# Mixed batch: mono placed in ch0 with zeros in ch1 -> (B, 2, T)
cut_mono = dummy_cut(0, duration=1.0, with_data=True)
cut_multi = dummy_multi_cut(1, channel=[0, 1], with_data=True)
cuts = CutSet([cut_mono, cut_multi])
batchio = AudioSamples(mono_downmix=False)
audio, audio_lens = batchio(cuts)
assert audio.shape == (2, 2, 16000)
# Mono cut's channel 1 must be all zeros
assert audio[0, 1, :].eq(0).all()
122 changes: 120 additions & 2 deletions test/dataset/test_collation.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,12 +395,12 @@ def test_collate_cut_multi_channel_recording_and_custom_recording_diff_num_chann

expected_lens = torch.tensor([16000, 32000], dtype=torch.int32)

audio, audio_lens = collate_audio(cuts)
audio, audio_lens = collate_audio(cuts, mono_downmix=False)
assert audio.shape == (2, 4, 32000) # batch x channel x time
torch.testing.assert_close(audio_lens, expected_lens)

target_audio, target_audio_lens = collate_audio(
cuts, recording_field="target_recording"
cuts, recording_field="target_recording", mono_downmix=False
)
assert target_audio.shape == (2, 2, 32000) # batch x channel x time
torch.testing.assert_close(audio_lens, expected_lens)
Expand Down Expand Up @@ -442,3 +442,121 @@ def test_collate_custom_audio_works_despite_non_unique_ids():
audio, audio_lens = collate_audio(cuts, recording_field="custom_recording")
assert audio_lens.tolist() == [32000, 16000]
assert audio.shape == (2, 32000)


# --- mono_downmix tests ---


def test_collate_audio_mono_downmix_true_all_mono():
# Default behavior: all-mono batch stays (B, T)
cuts = CutSet(
[
dummy_cut(0, duration=2.0, with_data=True),
dummy_cut(1, duration=1.0, with_data=True),
]
)
audio, audio_lens = collate_audio(cuts, mono_downmix=True)
assert audio.shape == (2, 32000)
assert audio_lens.tolist() == [32000, 16000]


def test_collate_audio_mono_downmix_true_all_multichannel():
# Multichannel audio should be downmixed to mono -> (B, T)
cuts = CutSet(
[
dummy_multi_cut(0, channel=[0, 1, 2], with_data=True),
dummy_multi_cut(1, channel=[0, 1, 2], with_data=True),
]
)
audio, audio_lens = collate_audio(cuts, mono_downmix=True)
assert audio.shape == (2, 16000)
assert audio_lens.tolist() == [16000, 16000]


def test_collate_audio_mono_downmix_true_mixed_batch():
# Mixed batch: one mono, one multichannel -> downmix all to (B, T)
cut_mono = dummy_cut(0, duration=1.0, with_data=True)
cut_multi = dummy_multi_cut(1, channel=[0, 1], with_data=True)
cuts = CutSet([cut_mono, cut_multi])
audio, audio_lens = collate_audio(cuts, mono_downmix=True)
assert audio.shape == (2, 16000)
assert audio_lens.tolist() == [16000, 16000]


def test_collate_audio_mono_downmix_false_all_mono():
# Mono-only batch: expand to (B, 1, T)
cuts = CutSet(
[
dummy_cut(0, duration=2.0, with_data=True),
dummy_cut(1, duration=1.0, with_data=True),
]
)
audio, audio_lens = collate_audio(cuts, mono_downmix=False)
assert audio.shape == (2, 1, 32000)
assert audio_lens.tolist() == [32000, 16000]


def test_collate_audio_mono_downmix_false_all_multichannel():
# All-multichannel batch collates as (B, C, T), same as existing behavior
cuts = CutSet(
[
dummy_multi_cut(0, duration=2.0, channel=[0, 1], with_data=True),
dummy_multi_cut(1, duration=1.0, channel=[0, 1], with_data=True),
]
)
audio, audio_lens = collate_audio(cuts, mono_downmix=False)
assert audio.shape == (2, 2, 32000)
assert audio_lens.tolist() == [32000, 16000]


def test_collate_audio_mono_downmix_false_mixed_batch():
# Mixed batch: mono is expanded to match max channels -> (B, C, T)
cut_mono = dummy_cut(0, duration=1.0, with_data=True)
cut_multi = dummy_multi_cut(1, channel=[0, 1], with_data=True)
cuts = CutSet([cut_mono, cut_multi])
audio, audio_lens = collate_audio(cuts, mono_downmix=False)
assert audio.shape == (2, 2, 16000)
assert audio_lens.tolist() == [16000, 16000]


def test_collate_audio_mono_downmix_none_mono_only():
# None + all-mono -> True semantics -> (B, T)
cuts = CutSet(
[
dummy_cut(0, duration=1.0, with_data=True),
dummy_cut(1, duration=1.0, with_data=True),
]
)
audio, audio_lens = collate_audio(cuts, mono_downmix=None)
assert audio.shape == (2, 16000)


def test_collate_audio_mono_downmix_none_mixed():
# None + mixed mono+multi -> True semantics -> (B, T)
cut_mono = dummy_cut(0, duration=1.0, with_data=True)
cut_multi = dummy_multi_cut(1, channel=[0, 1], with_data=True)
cuts = CutSet([cut_mono, cut_multi])
audio, audio_lens = collate_audio(cuts, mono_downmix=None)
assert audio.shape == (2, 16000)


def test_collate_audio_mono_downmix_none_multi_only():
# None + all-multi -> False semantics -> (B, C, T)
cuts = CutSet(
[
dummy_multi_cut(0, channel=[0, 1], with_data=True),
dummy_multi_cut(1, channel=[0, 1], with_data=True),
]
)
audio, audio_lens = collate_audio(cuts, mono_downmix=None)
assert audio.shape == (2, 2, 16000)


def test_collate_audio_mono_downmix_false_mono_zero_padded_channels():
# Mono is placed in channel 0; remaining channels are zero
cut_mono = dummy_cut(0, duration=1.0, with_data=True)
cut_multi = dummy_multi_cut(1, channel=[0, 1], with_data=True)
cuts = CutSet([cut_mono, cut_multi])
audio, _ = collate_audio(cuts, mono_downmix=False)
# audio[0] is from the mono cut: channel 1 must be all zeros
assert audio[0, 1, :].eq(0).all()
Loading