Merged
1 change: 1 addition & 0 deletions docs/source/package_reference/main_classes.mdx
@@ -176,6 +176,7 @@ The base class [`IterableDataset`] implements an iterable Dataset backed by pyth
- skip
- take
- shard
- reshard
- repeat
- to_csv
- to_pandas
16 changes: 15 additions & 1 deletion docs/source/stream.mdx
@@ -182,7 +182,21 @@ IterableDataset({
})
```

If your dataset has `dataset.num_shards==1`, you should chunk it using [`IterableDataset.skip`] and [`IterableDataset.take`] instead.
To increase the number of shards of a dataset, you can use [`IterableDataset.reshard`]:

```py
>>> dataset.reshard()
IterableDataset({
features: ['label', 'title', 'content'],
num_shards: 3600
})
```

The resharding mechanism depends on the dataset file format.
For Parquet, for example, resharding uses row groups instead of one file per shard.
See how it works for each format in [`IterableDataset.reshard`]'s documentation.

If your dataset has `dataset.num_shards==1` even after resharding, you should chunk it using [`IterableDataset.skip`] and [`IterableDataset.take`] instead.

## Interleave

3 changes: 3 additions & 0 deletions docs/source/use_with_pytorch.mdx
@@ -255,3 +255,6 @@ then the shards are evenly assigned across the nodes, which is the most optimize
Otherwise, each node keeps 1 example out of `world_size`, skipping the other examples.

This can also be combined with a `torch.utils.data.DataLoader` if you want each node to use multiple workers to load the data.

> [!WARNING]
> If you shuffle your iterable dataset in a distributed setup, make sure to set a fixed `seed` in [`IterableDataset.shuffle`] so that every node uses the same shuffled list of shards and knows which shards it should skip.
12 changes: 10 additions & 2 deletions src/datasets/builder.py
@@ -1638,7 +1638,11 @@ def _download_and_prepare(self, dl_manager, verification_mode, **prepare_splits_
)

def _get_examples_iterable_for_split(self, split_generator: SplitGenerator) -> ExamplesIterable:
return ExamplesIterable(self._generate_examples, split_generator.gen_kwargs)
return ExamplesIterable(
self._generate_examples,
split_generator.gen_kwargs,
generate_more_kwargs_fn=getattr(self, "_generate_more_gen_kwargs", None),
)


class ArrowBasedBuilder(DatasetBuilder):
@@ -1933,7 +1937,11 @@ def _prepare_split_single(
)

def _get_examples_iterable_for_split(self, split_generator: SplitGenerator) -> ExamplesIterable:
return ArrowExamplesIterable(self._generate_tables, kwargs=split_generator.gen_kwargs)
return ArrowExamplesIterable(
self._generate_tables,
kwargs=split_generator.gen_kwargs,
generate_more_kwargs_fn=getattr(self, "_generate_more_gen_kwargs", None),
)


class _CountableBuilderMixin(DatasetBuilder):
17 changes: 13 additions & 4 deletions src/datasets/combine.py
@@ -39,8 +39,10 @@ def interleave_datasets(
Note for iterable datasets:
In a distributed setup or in PyTorch DataLoader workers, the stopping strategy is applied per process.
Therefore the "first_exhausted" strategy on a sharded iterable dataset can generate fewer samples in total (up to 1 missing sample per subdataset per worker).
* The resulting dataset's `num_shards` is the minimum of each dataset's `num_shards` to ensure good parallelism.
If some of your datasets have a very low number of shards, you may use [`IterableDataset.reshard`].
* In a distributed setup or in PyTorch DataLoader workers, the stopping strategy is applied per process.
Therefore the "first_exhausted" strategy on a sharded iterable dataset can generate fewer samples in total (up to 1 missing sample per subdataset per worker).
Args:
datasets (`List[Dataset]` or `List[IterableDataset]`):
@@ -170,10 +172,17 @@ def concatenate_datasets(
axis: int = 0,
) -> DatasetType:
"""
Converts a list of [`Dataset`] with the same schema into a single [`Dataset`].
Concatenate several datasets (sources) into a single dataset.
Use axis=0 to concatenate vertically (default), or axis=1 to concatenate horizontally.
Note for iterable datasets:
* if axis=0, the resulting dataset's `num_shards` is the sum of each dataset's `num_shards`.
* if axis=1, the resulting dataset has a single shard, so that rows stay aligned across the concatenated columns.
Args:
dsets (`List[datasets.Dataset]`):
dsets (`List[datasets.Dataset]` or `List[datasets.IterableDataset]`):
List of Datasets to concatenate.
info (`DatasetInfo`, *optional*):
Dataset information, like description, citation, etc.
4 changes: 4 additions & 0 deletions src/datasets/distributed.py
Expand Up @@ -22,6 +22,10 @@ def split_dataset_by_node(dataset: DatasetType, rank: int, world_size: int) -> D
then the shards are evenly assigned across the nodes, which is the most optimized.
Otherwise, each node keeps 1 example out of `world_size`, skipping the other examples.

> [!WARNING]
> If you shuffle your iterable dataset in a distributed setup, make sure to set a fixed `seed` in [`IterableDataset.shuffle`]
so that every node uses the same shuffled list of shards and knows which shards it should skip.

Args:
dataset ([`Dataset`] or [`IterableDataset`]):
The dataset to split by node.