Feature request
I would like to be able to concatenate multiple IterableDataset with possibly different features. I would then like to be able to stream the results in parallel, both using DDP and multiple workers in the pytorch DataLoader. I want the merged dataset to be well balanced across the different processes.
Motivation
I want to train a model on a combination of datasets, which I can convert to a single representation. This applies both to converting items from different datasets to the same Python class and to using a tokenizer on multiple modalities.
Since my original datasets are not necessarily well balanced (they may have different sizes and thus different numbers of shards), I would like the merged dataset to be distributed evenly over the multiple processes. I don't mind if it's not perfectly balanced and, as a result, some workers of the torch DataLoader do nothing, as long as DDP is handled properly and no deadlock occurs.
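To make the target workflow concrete, here is a minimal usage sketch; `mix_dataset` is the hypothetical helper proposed further below (not an existing datasets API), and the rank/world size would normally come from the DDP setup:

```python
from datasets import Dataset
from datasets.distributed import split_dataset_by_node
from torch.utils.data import DataLoader

# two toy sources with different numbers of shards
ds_a = Dataset.from_dict({"x": list(range(10))}).to_iterable_dataset(num_shards=2)
ds_b = Dataset.from_dict({"x": list(range(100, 105))}).to_iterable_dataset(num_shards=1)

mixed = mix_dataset([ds_a, ds_b])  # hypothetical helper, proposed further below
rank, world_size = 0, 2            # placeholders for the values provided by torch.distributed
per_rank = split_dataset_by_node(mixed, rank=rank, world_size=world_size)
loader = DataLoader(per_rank, num_workers=2)  # remaining shards are split across the workers
```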
What I've tried
I've tried the two functions already provided in datasets, namely interleave_datasets and concatenate_datasets.
- Interleave seems the closest to what I'm trying to do. However, it doesn't suit my purpose because, as I understand it, it either stops as soon as one of the dataset sources is exhausted or repeats the smallest source's items until the largest is exhausted. I would like something in between, similar to what roundrobin does (see the sketch after this list).
- Concatenate does not mix the data enough, and one dataset may be overrepresented in some early batches.
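To illustrate the difference, here is a small sketch comparing the two stopping strategies of interleave_datasets with more_itertools.roundrobin; the outputs in the comments are what I expect from the documented behaviour:

```python
from datasets import Dataset, interleave_datasets
from more_itertools import roundrobin

d1 = Dataset.from_dict({"x": [0, 1, 2]}).to_iterable_dataset()
d2 = Dataset.from_dict({"x": [10]}).to_iterable_dataset()

# stops once d2 is exhausted -> expected [0, 10, 1]
print([ex["x"] for ex in interleave_datasets([d1, d2], stopping_strategy="first_exhausted")])
# repeats d2 until d1 is exhausted -> expected [0, 10, 1, 10, 2, 10]
print([ex["x"] for ex in interleave_datasets([d1, d2], stopping_strategy="all_exhausted")])
# what I'm after: every example exactly once -> [0, 10, 1, 2]
print(list(roundrobin([0, 1, 2], [10])))
```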
Let's consider 3 datasets composed of different numbers of shards, as follows: [[s0_0, s0_1], [s1_0], [s2_0, s2_1, s2_3]], where s denotes the underlying shard, the first index the dataset, and the second the shard number.
If we request 3 shards in shard_data_sources, we should obtain the following:
index 0 gets s0_0 s2_0
index 1 gets s0_1 s2_1
index 2 gets s1_0 s2_3
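This assignment is just a strided slice over the flattened list of shards:

```python
from itertools import chain, islice

shards = [["s0_0", "s0_1"], ["s1_0"], ["s2_0", "s2_1", "s2_3"]]
flat = list(chain.from_iterable(shards))  # ['s0_0', 's0_1', 's1_0', 's2_0', 's2_1', 's2_3']

num_shards = 3
for index in range(num_shards):
    print(index, list(islice(flat, index, None, num_shards)))
# 0 ['s0_0', 's2_0']
# 1 ['s0_1', 's2_1']
# 2 ['s1_0', 's2_3']
```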
I started implementing the following, but I'm afraid my sharding logic is incorrect.
```python
from copy import deepcopy
from itertools import chain, islice

import datasets
import numpy as np
from datasets import IterableDataset
from datasets.iterable_dataset import _BaseExamplesIterable
from more_itertools import roundrobin


class MixMultiSourcesExampleIterable(_BaseExamplesIterable):
    def __init__(self, ex_iterables: list[_BaseExamplesIterable]):
        super().__init__()
        self.ex_iterables = ex_iterables

    def _init_state_dict(self) -> dict:
        self._state_dict = {
            "ex_iterables": [ex_iterable._init_state_dict() for ex_iterable in self.ex_iterables],
            "type": self.__class__.__name__,
        }
        return self._state_dict

    @property
    def num_shards(self) -> int:
        return sum(ex_iterable.num_shards for ex_iterable in self.ex_iterables)

    def __iter__(self):
        # alternate between the sources, skipping any source that is already exhausted
        yield from roundrobin(*self.ex_iterables)

    def shuffle_data_sources(self, generator: np.random.Generator) -> "MixMultiSourcesExampleIterable":
        """Shuffle the list of examples iterables, as well as each underlying examples iterable."""
        rng = deepcopy(generator)
        ex_iterables = list(self.ex_iterables)
        rng.shuffle(ex_iterables)
        ex_iterables = [ex_iterable.shuffle_data_sources(generator) for ex_iterable in ex_iterables]
        return MixMultiSourcesExampleIterable(ex_iterables)

    def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "MixMultiSourcesExampleIterable":
        """Shard the underlying iterables in a round-robin manner.

        Let's consider we have our iterables as [[s0_0, s0_1], [s1_0], [s2_0, s2_1, s2_3]],
        and we request 3 shards.

        index 0 gets s0_0 s2_0
        index 1 gets s0_1 s2_1
        index 2 gets s1_0 s2_3
        """
        return MixMultiSourcesExampleIterable(
            list(
                islice(
                    # flatten all underlying iterables
                    chain.from_iterable([ex_iterable.shard_data_sources(1, 0) for ex_iterable in self.ex_iterables]),
                    # offset the starting point by the index
                    index,
                    # take over the full list, so exhaust the iterators
                    None,
                    # step by the number of shards requested
                    num_shards,
                )
            )
        )


def mix_dataset(iterable_datasets: list[datasets.IterableDataset]) -> IterableDataset:
    ex_iterable = MixMultiSourcesExampleIterable([ds._ex_iterable for ds in iterable_datasets])
    return IterableDataset(
        ex_iterable, distributed=iterable_datasets[0]._distributed, formatting=iterable_datasets[0]._formatting
    )
```
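For comparison, here is a hedged sketch of how the round-robin split could operate on the underlying single-shard iterables instead of slicing individual examples, assuming shard_data_sources keeps the (num_shards, index) signature used above:

```python
# Sketch only: a possible replacement for MixMultiSourcesExampleIterable.shard_data_sources,
# dealing out single-shard iterables rather than individual examples.
def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "MixMultiSourcesExampleIterable":
    # split every source into its own single-shard pieces
    single_shards = [
        ex_iterable.shard_data_sources(ex_iterable.num_shards, shard_index)
        for ex_iterable in self.ex_iterables
        for shard_index in range(ex_iterable.num_shards)
    ]
    # index 0 gets shards 0, num_shards, 2 * num_shards, ...; index 1 gets 1, num_shards + 1, ...
    return MixMultiSourcesExampleIterable(list(islice(single_shards, index, None, num_shards)))
```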
Questions
- Am I missing something? Is there a way to use interleave_datasets or concatenate_datasets to fit my purpose?
- Would it be the right approach to spread the maximum number of underlying shards across my different processes?
Your contribution
As much as I can.