Concatenate IterableDataset instances and distribute underlying shards in a RoundRobin manner #7792

@LTMeyer

Description

Feature request

I would like to be able to concatenate multiple IterableDataset instances, possibly with different features, and then stream the result in parallel (both with DDP and with multiple workers in the PyTorch DataLoader). I want the merged dataset to be well balanced between the different processes.

Motivation

I want to train a model on a combination of datasets that I can convert to a single representation. This covers both converting items from different datasets to the same Python class and applying a tokenizer to multiple modalities.

Assuming that my original datasets are not necessarily well balanced, as they may have different sizes and thus different numbers of shards, I would like the merged dataset to be distributed evenly over the multiple processes. I don't mind if the split is not perfectly balanced, and if, as a result, some workers of the torch DataLoader do nothing, as long as DDP is handled properly and no deadlock occurs.

What I've tried

I've tried the two functions already provided in datasets, namely interleave_datasets and concatenate_datasets.

  • Interleave is the closest to what I'm trying to do. However, it doesn't suit my purpose because, as I understand it, it either stops as soon as one of the dataset sources is exhausted, or repeats items from the smallest source until the largest is exhausted. I would like something in between, similar to what more_itertools.roundrobin does (see the sketch just below this list).
  • Concatenate does not mix the data enough, and one dataset may be overrepresented in some early batches.
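
For illustration, a minimal sketch of the behaviour I'm after, using toy lists instead of actual datasets: more_itertools.roundrobin keeps cycling over whatever sources still have items, instead of stopping at the first exhausted one or oversampling the smallest one.

from more_itertools import roundrobin

# three uneven sources, standing in for datasets of different sizes
list(roundrobin([1, 2], [10], [100, 200, 300]))
# -> [1, 10, 100, 2, 200, 300]: every source is consumed exactly once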

Let's consider we have 3 datasets composed of different numbers of shards, as follows: [[s0_0, s0_1], [s1_0], [s2_0, s2_1, s2_2]], where s denotes the underlying shard, the first index the dataset, and the second the shard number.
If we request 3 shards in shard_data_sources, we should obtain the following:
index 0 gets s0_0 s2_0
index 1 gets s0_1 s2_1
index 2 gets s1_0 s2_2
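
As a quick sanity check of the assignment above, the same strided selection can be written with plain itertools over hypothetical shard names:

from itertools import chain, islice

# toy shard names matching the example above
shards = [["s0_0", "s0_1"], ["s1_0"], ["s2_0", "s2_1", "s2_2"]]
flat = list(chain.from_iterable(shards))  # ['s0_0', 's0_1', 's1_0', 's2_0', 's2_1', 's2_2']

num_shards = 3
for index in range(num_shards):
    print(index, list(islice(flat, index, None, num_shards)))
# 0 ['s0_0', 's2_0']
# 1 ['s0_1', 's2_1']
# 2 ['s1_0', 's2_2']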

I started implementing the following, but I'm afraid my sharding logic is incorrect.

from copy import deepcopy
from itertools import chain, islice

import datasets
import numpy as np
from datasets import IterableDataset
from datasets.iterable_dataset import _BaseExamplesIterable
from more_itertools import roundrobin


class MixMultiSourcesExampleIterable(_BaseExamplesIterable):
    def __init__(self, ex_iterables: list[_BaseExamplesIterable]):
        super().__init__()
        self.ex_iterables = ex_iterables

    def _init_state_dict(self) -> dict:
        self._state_dict = {
            "ex_iterables": [ex_iterable._init_state_dict() for ex_iterable in self.ex_iterables],
            "type": self.__class__.__name__,
        }
        return self._state_dict

    @property
    def num_shards(self) -> int:
        return sum(ex_iterable.num_shards for ex_iterable in self.ex_iterables)

    def __iter__(self):
        yield from roundrobin(*self.ex_iterables)

    def shuffle_data_sources(self, generator: np.random.Generator) -> "MixMultiSourcesExampleIterable":
        """Shuffle the list of examples iterable, as well as each underlying examples iterable."""
        rng = deepcopy(generator)
        ex_iterables = list(self.ex_iterables)
        rng.shuffle(ex_iterables)
        ex_iterables = [ex_iterable.shuffle_data_sources(generator) for ex_iterable in ex_iterables]
        return MixMultiSourcesExampleIterable(ex_iterables)

    def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "MixMultiSourcesExampleIterable":
        """Shard the underlying iterables in a roundrobin manner.

        Let's consider we have our iterables as [[s0_0, s0_1], [s1_0], [s2_0, s2_1, s2_2]],
        and we request 3 shards.
        index 0 gets s0_0 s2_0
        index 1 gets s0_1 s2_1
        index 2 gets s1_0 s2_2
        """
        # split every source into its individual shards, then flatten them into a single list
        all_shards = [
            ex_iterable.shard_data_sources(ex_iterable.num_shards, shard_index)
            for ex_iterable in self.ex_iterables
            for shard_index in range(ex_iterable.num_shards)
        ]
        # offset the starting point by the index and step by the number of shards requested,
        # so the flattened shards are spread roundrobin over the requested shards
        return MixMultiSourcesExampleIterable(list(islice(all_shards, index, None, num_shards)))


def mix_dataset(iterable_datasets: list[datasets.IterableDataset]) -> IterableDataset:
    """Mix all sources into a single IterableDataset, reusing the first dataset's distributed and formatting settings."""
    ex_iterable = MixMultiSourcesExampleIterable([ds._ex_iterable for ds in iterable_datasets])
    return IterableDataset(
        ex_iterable, distributed=iterable_datasets[0]._distributed, formatting=iterable_datasets[0]._formatting
    )
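
For completeness, a hypothetical smoke test of the snippet above; the toy datasets, shard counts, and expected output are assumptions for illustration, not part of the proposal itself.

from datasets import Dataset

# two toy datasets with different features and different numbers of shards
ds_a = Dataset.from_dict({"x": list(range(4))}).to_iterable_dataset(num_shards=2)
ds_b = Dataset.from_dict({"y": list(range(6))}).to_iterable_dataset(num_shards=3)

mixed = mix_dataset([ds_a, ds_b])
print(mixed.n_shards)  # expected to report the 5 underlying shards
for example in mixed:  # examples from both sources, interleaved roundrobin
    print(example)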

Questions

  • Am I missing something? Is there a way to use interleave_datasets or concatenate_datasets to achieve this?
  • Is spreading the underlying shards across the different processes, as sketched above, the right approach?

Your contribution

As much as I can.
