
Add batch(by_column=...) #8172

Open
lhoestq wants to merge 1 commit into main from batch-by-column

Conversation

@lhoestq (Member) commented May 4, 2026

Will be useful for robotics datasets to batch samples by episode. cc @pkooij

Example usage:

from datasets import Dataset

ds = Dataset.from_dict({"episode": [0] * 10 + [1] * 10, "frame": list(range(10)) * 2})
# ds = ds.to_iterable_dataset()
ds = ds.batch(by_column="episode")
for x in ds:
    print(x)
# {'episode': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'frame': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}
# {'episode': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'frame': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}

It's implemented with efficient Arrow functions, which gives a substantial speed-up for Dataset and for Parquet-backed IterableDataset.
It also supports lossless state_dict() / load_state_dict().
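
For example, checkpointing mid-stream could look like this (a minimal sketch building on the example above; state_dict() / load_state_dict() are the existing IterableDataset checkpointing APIs, and the lossless resume at batch boundaries is the behavior this PR claims):

from datasets import Dataset

ds = Dataset.from_dict({"episode": [0] * 10 + [1] * 10, "frame": list(range(10)) * 2})
ds = ds.to_iterable_dataset().batch(by_column="episode")

for batch in ds:
    state = ds.state_dict()  # checkpoint right after the first batch (episode 0)
    break

ds.load_state_dict(state)   # restore: iteration resumes from the checkpoint
for batch in ds:
    print(batch["episode"][0])
# 1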

It works by accumulating Arrow data and grouping batches together using pyarrow.ListArray inside an Arrow-level map() function.
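
For intuition, here is a minimal standalone sketch of the ListArray grouping idea (not the PR's actual code), assuming each batch is a contiguous run of the key column:

import pyarrow as pa
import pyarrow.compute as pc

table = pa.table({"episode": [0, 0, 0, 1, 1], "frame": [0, 1, 2, 0, 1]})

# Offsets where the key column changes value (run boundaries)
key = table["episode"].combine_chunks()
change = pc.not_equal(key.slice(1), key.slice(0, len(key) - 1))
offsets = [0] + [i + 1 for i, c in enumerate(change) if c.as_py()] + [len(key)]

# Wrap every column into a ListArray with the same offsets:
# each run of rows becomes a single "batch" row
batched = pa.table({
    name: pa.ListArray.from_arrays(pa.array(offsets, pa.int32()), col.combine_chunks())
    for name, col in zip(table.column_names, table.columns)
})
print(batched.to_pylist())
# [{'episode': [0, 0, 0], 'frame': [0, 1, 2]}, {'episode': [1, 1], 'frame': [0, 1]}]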

Multiprocessing is not supported because it could split a batch into two or more pieces when distributing shards to processes (batches can span multiple shards). This is fine IMO, since multiprocessing only applies to Dataset.batch() and the operation is unlikely to be CPU-bound thanks to Arrow functions. It could still be useful for very large datasets and clusters, though.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
