Draft of example integration with Lhotse #102

Open · wants to merge 1 commit into base: main
235 changes: 235 additions & 0 deletions examples/audio/lhotse_example.py
@@ -0,0 +1,235 @@
#!/usr/bin/env python

"""
This is a working example of using a new PyTorch dataloading API called DataPipes
with Lhotse. It's not the final version, and a lot of things will probably change.

Lhotse is a speech data representation library that brings various corpora into
a single format and interfaces well with PyTorch. See: https://github.com/lhotse-speech/lhotse

Prepare the env like this first:

$ conda create -n torchdata python=3.8
$ conda activate torchdata
$ conda install numpy
$ pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
$ pip install --user "git+https://github.com/pytorch/data.git"

Optionally, if you're not on this Lhotse branch already:

$ pip install git+https://github.com/lhotse-speech/lhotse@feature/datapipes-prototyping

"""


import warnings
from collections import deque, defaultdict
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional

from lhotse import CutSet, load_manifest
from lhotse.utils import Seconds
from lhotse.recipes import download_librispeech, prepare_librispeech

import torchdata
import torch
import torchdata.datapipes.iter as dp # Only use Iter-style DataPipe to illustrate the pipeline

from torch.utils.data.datapipes.iter import Multiplexer, Demultiplexer
from torch.utils.data.communication.eventloop import SpawnProcessForDataPipeline
from torch.utils.data.communication.protocol import IterDataPipeQueueProtocolClient
from torch.utils.data.communication.iter import QueueWrapper


class CutsReader(dp.IterDataPipe):
    def __init__(self, path) -> None:
        self.cuts = CutSet.from_jsonl_lazy(path)

    def __iter__(self):
        yield from self.cuts


class DurationBatcher(dp.IterDataPipe):
    def __init__(
        self,
        datapipe,
        max_frames: int = None,
        max_samples: int = None,
        max_duration: Seconds = None,
        max_cuts: Optional[int] = None,
        drop_last: bool = False,
    ):
        from lhotse.dataset.sampling.base import SamplingDiagnostics, TimeConstraint

        self.datapipe = datapipe
        self.reuse_cuts_buffer = deque()
        self.drop_last = drop_last
        self.max_cuts = max_cuts
        self.diagnostics = SamplingDiagnostics()
        self.time_constraint = TimeConstraint(
            max_duration=max_duration, max_frames=max_frames, max_samples=max_samples
        )

    def __iter__(self):
        self.datapipe = iter(self.datapipe)
Contributor:

Please do not remove the reference to the prior/source datapipe here. Since we want DataLoaderV2 to support graph-mode execution of the data pipeline (parallel execution, automatic sharding, etc.), we need to keep the reference graph of each DataPipe instance.

Contributor:

This would also prevent you from doing a second iteration.

Author:

Thanks. DataLoaderV2 sounds interesting -- is there a description somewhere that I could read to learn more?

Contributor:

We haven't started to write all the related docs as the API and features keep changing. I can give some code pointers to the things I talked about.

        while True:
            yield self._collect_batch()
Comment on lines +76 to +77
Contributor:

We should reset the iterator to None, because a generator instance is not serializable (picklable).
IIUC, this is going to iterate forever. Is this intentional?

Contributor:

Simply speaking.

Suggested change
-        while True:
-            yield self._collect_batch()
+    def __iter__(self):
+        self.datapipe_iter = iter(self.datapipe)
+        ...
+        self.datapipe_iter = None

Author:

IIUC, this is going to iterate forever. Is this intentional?

This should stop once self._collect_batch() raises StopIteration. Maybe the sampling code in DurationBatcher can be made simpler -- as you can see there is a lot of branching. I'll give it more thought at a later point.
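
For reference, a rough sketch of how both suggestions above might be combined (the datapipe_iter attribute name is illustrative, and _collect_batch would then call next(self.datapipe_iter) instead of next(self.datapipe)):

    def __iter__(self):
        # Keep self.datapipe untouched so the DataPipe reference graph stays intact;
        # hold the generator in a separate attribute instead.
        self.datapipe_iter = iter(self.datapipe)
        try:
            while True:
                yield self._collect_batch()
        except StopIteration:
            # _collect_batch() signals exhaustion; catching it here also avoids the
            # PEP 479 RuntimeError ("generator raised StopIteration") in Python 3.7+.
            return
        finally:
            # Clear the generator so the instance stays picklable.
            self.datapipe_iter = None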


    def _collect_batch(self):
        self.time_constraint.reset()
        cuts = []
        while True:
            # Check that we have not reached the end of the dataset.
            try:
                if self.reuse_cuts_buffer:
                    next_cut = self.reuse_cuts_buffer.popleft()
                else:
                    # If this doesn't raise (typical case), it's not the end: keep processing.
                    next_cut = next(self.datapipe)
            except StopIteration:
                # No more cuts to sample from: if we have a partial batch,
                # we may output it, unless the user requested to drop it.
                # We also check if the batch is "almost there" to override drop_last.
                if cuts and (
                    not self.drop_last or self.time_constraint.close_to_exceeding()
                ):
                    # We have a partial batch and we can return it.
                    self.diagnostics.keep(cuts)
                    return CutSet.from_cuts(cuts)
                else:
                    # There is nothing more to return or it's discarded:
                    # signal the iteration code to stop.
                    self.diagnostics.discard(cuts)
                    raise StopIteration()

            # Track the duration/frames/etc. constraints.
            self.time_constraint.add(next_cut)
            next_num_cuts = len(cuts) + 1

            # Did we exceed the max_frames and max_cuts constraints?
            if not self.time_constraint.exceeded() and (
                self.max_cuts is None or next_num_cuts <= self.max_cuts
            ):
                # No - add the next cut to the batch, and keep trying.
                cuts.append(next_cut)
            else:
                # Yes. Do we have at least one cut in the batch?
                if cuts:
                    # Yes. Return the batch, but keep the currently drawn cut for later.
                    self.reuse_cuts_buffer.append(next_cut)
                    break
                else:
                    # No. We'll warn the user that the constraints might be too tight,
                    # and return the cut anyway.
                    warnings.warn(
                        "The first cut drawn in batch collection violates "
                        "the max_frames, max_cuts, or max_duration constraints - "
                        "we'll return it anyway. "
                        "Consider increasing max_frames/max_cuts/max_duration."
                    )
                    cuts.append(next_cut)

        self.diagnostics.keep(cuts)
        return CutSet.from_cuts(cuts)


class IODataPipe(dp.IterDataPipe):
    def __init__(self, datapipe):
        self.datapipe = datapipe

    def __iter__(self):
        for cut_idx, batch_idx, batch_size, cut in self.datapipe:
            yield cut_idx, batch_idx, batch_size, cut.load_audio(), cut
Comment on lines +137 to +143
Contributor:

Could you use Mapper to replace this DataPipe?
You should be able to use the functional API like this:

source_datapipe # Referring to self.datapipe
source_datapipe.map(fn=lambda x: x.load_audio(), input_col=3, output_col=-1)

The result is going to be cut_idx, batch_idx, batch_size, cut, cut.load_audio()

Author:

This looks interesting. Will the lambda work well with pickling / multiprocessing though?

Author:

Also: is it possible to work with dicts rather than tuples when using the functional API? E.g.:

yield {
  'cut_idx': cut_idx,
  'batch_idx': batch_idx,
  ...
}

Contributor:

Will the lambda work well with pickling / multiprocessing though?

We want to support lambdas as they're a common feature of functional programming. So, as long as dill is installed, multiprocessing with lambdas will work.

Also: is it possible to work with dicts rather than tuples when using the functional API?

Yes. It's supposed to work if you give input_col='cut_idx'.
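
For example, assuming the pipeline yields dicts with a 'cut' key (the key names here are made up for illustration), something like this should work:

# Sketch: apply load_audio() to the 'cut' entry and store the result under a new 'audio' key.
datapipe = datapipe.map(
    fn=lambda cut: cut.load_audio(),
    input_col="cut",
    output_col="audio",
)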



class RecombineBatchAfterIO(dp.IterDataPipe):
Contributor:

This is an interesting use case. It's better to have a buffer size IMO; otherwise it could potentially blow up memory.

Author:

How do you envision it? Let's say the buffer size is 500 items but, for some reason, I sampled a batch that has 600 items. Wouldn't that hang indefinitely?

I might not understand the new dataloading machinery yet -- my assumption was that there is some mechanism to make it prepare a prefetch_factor number of batches, and if the consumer stops fetching them (e.g., because it's training), then the whole pipeline would "stop" and wait. If it doesn't stop, I can see how the memory might blow up though.

    def __init__(self, datapipe):
        self.datapipe = datapipe

    def __iter__(self):
        # This is a buffer that will be used to re-combine batches.
        # We might possibly get cuts from the same batch out-of-order.
        batches: Dict[int, List[Cut]] = defaultdict(list)
        for cut_idx, batch_idx, batch_size, audio, cut in self.datapipe:
            batches[batch_idx].append((audio, cut))
            if len(batches[batch_idx]) == batch_size:
                audios, cuts = zip(*batches[batch_idx])
                yield audios, CutSet.from_cuts(cuts)
                # Free up the buffer dict after using the batch.
                del batches[batch_idx]
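
Regarding the buffer-size question in the thread above, a rough sketch of a capped variant (the class name and threshold are made up, and it warns rather than blocks, so an oversized batch cannot cause a hang; it reuses the imports at the top of this file):

class RecombineBatchAfterIOWithLimit(dp.IterDataPipe):
    def __init__(self, datapipe, max_buffered_items: int = 10_000):
        self.datapipe = datapipe
        self.max_buffered_items = max_buffered_items

    def __iter__(self):
        batches = defaultdict(list)
        num_buffered = 0
        for cut_idx, batch_idx, batch_size, audio, cut in self.datapipe:
            batches[batch_idx].append((audio, cut))
            num_buffered += 1
            if num_buffered > self.max_buffered_items:
                # Warn rather than block: an oversized batch can still complete,
                # but the user learns that memory is growing beyond the expected bound.
                warnings.warn(
                    f"Re-combining buffer holds more than {self.max_buffered_items} items."
                )
            if len(batches[batch_idx]) == batch_size:
                audios, cuts = zip(*batches[batch_idx])
                num_buffered -= batch_size
                del batches[batch_idx]
                yield audios, CutSet.from_cuts(cuts)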


class UnbatchForIO(dp.IterDataPipe):
    def __init__(self, datapipe):
        self.datapipe = datapipe

    def __iter__(self):
        for batch_idx, cuts_batch in enumerate(self.datapipe):
            batch_size = len(cuts_batch)
            for cut_idx, cut in enumerate(cuts_batch):
                yield cut_idx, batch_idx, batch_size, cut


def classifier_fn(value, num_jobs):
    return value[0] % num_jobs


def build_pipeline(path: str, num_jobs: int = 2):

    # Open the cuts manifest for lazy reading.
    datapipe = CutsReader(path)

    # Custom sampling.
    datapipe = DurationBatcher(datapipe, max_duration=100)

    # Un-batch the batches to run I/O.
    datapipe = UnbatchForIO(datapipe)  # Yields (cut index, batch index, batch size, cut)

    # The following multiprocessing API may change; multiprocessing should be added and handled by DataLoader.
    # demux and mux would be called by DataLoader, but I want to call them explicitly to illustrate the functionality.

    # Route the prior datapipe to ``num_jobs`` datapipes in a round-robin manner.
    datapipes = datapipe.demux(
        num_jobs, classifier_fn=partial(classifier_fn, num_jobs=num_jobs)
    )
    for i in range(len(datapipes)):
        datapipes[i] = IODataPipe(datapipes[i])
Contributor:

Thank you so much. This is a great example showing how to distribute an expensive operation across multiple processes.

    import multiprocessing

    ctx = multiprocessing.get_context("spawn")
    for i in range(num_jobs):
        process, req_queue, res_queue = SpawnProcessForDataPipeline(ctx, datapipes[i])
        process.start()
        datapipes[i] = QueueWrapper(
            IterDataPipeQueueProtocolClient(req_queue, res_queue)
Author:

I admit the QueueWrapper and IterDataPipeQueueProtocolClient took me a while to get right. I feel like I'm unnecessarily "hacking" the library and using an API that is too low-level and not intended for users. Is there a cleaner way to achieve the same outcome?

        )
        # DataLoader will also handle stopping, joining, and cleaning up subprocesses.

    datapipe = dp.Multiplexer(*datapipes)

    datapipe = RecombineBatchAfterIO(datapipe)
    # datapipe = datapipe.collate().transform(transforms=...).filter(filter_fn=...)
    return datapipe
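

Regarding the question above about a cleaner alternative to QueueWrapper / IterDataPipeQueueProtocolClient: in later torchdata releases (not available at the time of this draft), this wiring is handled by DataLoader2 together with a reading service. A rough sketch, not an exact equivalent of the demux/mux routing above (the path is illustrative):

from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService

# Build the pipeline without manual process spawning; the reading service
# spawns, manages, and cleans up the worker processes.
datapipe = CutsReader("cuts.jsonl.gz")
datapipe = DurationBatcher(datapipe, max_duration=100)
datapipe = datapipe.sharding_filter()  # each worker receives whole batches
datapipe = RecombineBatchAfterIO(IODataPipe(UnbatchForIO(datapipe)))

rs = MultiProcessingReadingService(num_workers=2)
dl = DataLoader2(datapipe, reading_service=rs)
for batch in dl:
    ...  # consume batches
dl.shutdown()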


if __name__ == "__main__":
    cuts_path = Path(
        "./workspace/LibriSpeech/manifests/mini-libri-cuts-train-clean-5.jsonl.gz"
    )

    if not cuts_path.exists():
        print("Downloading Mini LibriSpeech.")
        download_librispeech("./workspace", dataset_parts="mini_librispeech")
Contributor:

This could potentially be replaced by on_disk_cache [WIP for refactoring], which handles the download and hash check.

Author:

Sounds interesting, do you have a pointer to some code/doc?

print("Building Mini LibriSpeech manifests.")
libri = prepare_librispeech(
corpus_dir="./workspace/LibriSpeech",
output_dir="./workspace/LibriSpeech/manifests",
)
print("Storing CutSet.")
cuts = CutSet.from_manifests(**libri["train-clean-5"])
cuts.to_file(cuts_path)

pipeline = build_pipeline(cuts_path, num_jobs=2)
for idx, item in enumerate(pipeline):
print(idx, item)
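

As for the on_disk_cache pointer requested in the thread above: a minimal sketch of the download-and-cache pattern, assuming the API of later torchdata releases (the URL and target directory are illustrative, and the exact calls may differ from the WIP version the reviewer mentions):

import os
from torchdata.datapipes.iter import IterableWrapper, HttpReader

URL = "https://www.openslr.org/resources/31/train-clean-5.tar.gz"  # illustrative
root = "./workspace"

cache_dp = IterableWrapper([URL]).on_disk_cache(
    filepath_fn=lambda url: os.path.join(root, os.path.basename(url)),
)
# The download runs only when the cached file is missing; a hash_dict can be
# passed to on_disk_cache to additionally verify the file's checksum.
cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)
for path in cache_dp:
    print("cached at", path)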