Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 33 additions & 17 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@
InMemoryTable,
MemoryMappedTable,
Table,
_batch_accumulate_arrow_table_by_columns,
_batch_arrow_table,
_memory_mapped_record_batch_reader_from_file,
cast_array_to_feature,
Expand Down Expand Up @@ -4063,17 +4064,31 @@ def iter_outputs(shard_iterable):
@fingerprint_transform(inplace=False)
def batch(
self,
batch_size: int,
batch_size: Optional[int] = None,
by_column: Optional[Union[str, list[str]]] = None,
drop_last_batch: bool = False,
num_proc: Optional[int] = None,
new_fingerprint: Optional[str] = None,
) -> "Dataset":
"""
Group samples from the dataset into batches.

Specify either `batch_size` to make batches of `batch_size` examples.

Or use `by_column=column_name` to make batches of successive examples that share
the same value in column `column_name`.

Args:
batch_size (`int`):
batch_size (`int`, optional):
The number of samples in each batch.
by_column (`Union[str, list[str]]`, *optional*):
The column used to batch examples together.
Successive examples with the same value for that column are grouped in the same batch.
This can also be a list of columns if you want to batch by multiple columns.
If batching by column, the batch_size is only used to control the size of internal batches
during accumulation.

<Added version="4.9.0"/>
drop_last_batch (`bool`, defaults to `False`):
Whether to drop the last incomplete batch.
num_proc (`int`, *optional*, defaults to `None`):
Expand All @@ -4096,29 +4111,30 @@ def batch(
'label': [1, 1, 1, 1]}
```
"""

def batch_fn(example):
return {k: [v] for k, v in example.items()}

if self._format_type in ("arrow", "pandas", "polars"):
return self.with_format("arrow").map(
_batch_arrow_table,
batched=True,
batch_size=batch_size,
drop_last_batch=drop_last_batch,
num_proc=num_proc,
new_fingerprint=new_fingerprint,
desc="Batching examples",
if batch_size is None and by_column is None:
raise ValueError("IterableDataset.batch() misses `batch_size` or `by_column` argument.")
if by_column is not None:
if num_proc:
raise NotImplementedError("Multiprocessed batching by column is not implemented yet")
columns = [by_column] if isinstance(by_column, str) else by_column
_batch_fn = partial(
_batch_accumulate_arrow_table_by_columns, columns=columns, tables_accumulator=[], length=len(self)
)
with_indices = True
else:
_batch_fn = _batch_arrow_table
with_indices = False

return self.map(
batch_fn,
return self.with_format("arrow").map(
_batch_fn,
with_indices=with_indices,
batched=True,
batch_size=batch_size,
drop_last_batch=drop_last_batch,
num_proc=num_proc,
new_fingerprint=new_fingerprint,
desc="Batching examples",
features=Features({col: List(feature) for col, feature in self.features.items()}),
)

@transmit_format
Expand Down
100 changes: 93 additions & 7 deletions src/datasets/iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,14 @@
from .info import DatasetInfo
from .naming import _split_re
from .splits import NamedSplit, Split, SplitInfo
from .table import _batch_arrow_table, cast_table_to_features, embed_table_storage, read_schema_from_file, table_cast
from .table import (
_batch_accumulate_arrow_table_by_columns,
_batch_arrow_table,
cast_table_to_features,
embed_table_storage,
read_schema_from_file,
table_cast,
)
from .utils import tqdm as hf_tqdm
from .utils.logging import get_logger
from .utils.py_utils import (
Expand Down Expand Up @@ -1251,6 +1258,7 @@ def __init__(
formatting: Optional["FormattingConfig"] = None,
features: Optional[Features] = None,
max_num_running_async_map_functions_in_parallel: Optional[int] = None,
is_batch_accumulate_arrow_table_function: bool = False,
):
super().__init__()
self.ex_iterable = ex_iterable
Expand All @@ -1267,6 +1275,7 @@ def __init__(
self.max_num_running_async_map_functions_in_parallel = (
max_num_running_async_map_functions_in_parallel or config.MAX_NUM_RUNNING_ASYNC_MAP_FUNCTIONS_IN_PARALLEL
)
self.is_batch_accumulate_arrow_table_function = is_batch_accumulate_arrow_table_function
# sanity checks
if formatting and formatting.is_table:
# batch_size should match for iter_arrow
Expand Down Expand Up @@ -1534,6 +1543,8 @@ def _iter_arrow(self, max_chunksize: Optional[int] = None) -> Iterator[tuple[Key
self._state_dict["previous_state"] = self.ex_iterable.state_dict()
self._state_dict["num_examples_since_previous_state"] = 0
current_idx = self._state_dict["previous_state_example_idx"] if self._state_dict else 0
tables_accumulator: list[pa.Table] = []
length: Optional[int] = None
for key, pa_table in iterator:
if (
self.batched
Expand All @@ -1554,7 +1565,9 @@ def _iter_arrow(self, max_chunksize: Optional[int] = None) -> Iterator[tuple[Key
else:
function_args.append(current_idx)
# then apply the transform
output = self.function(*function_args, **self.fn_kwargs)
output = self.function(
*function_args, **self.fn_kwargs, tables_accumulator=tables_accumulator, length=length
)
output_table = _table_output_to_arrow(output)
if not isinstance(output_table, pa.Table):
raise TypeError(
Expand Down Expand Up @@ -1582,10 +1595,25 @@ def _iter_arrow(self, max_chunksize: Optional[int] = None) -> Iterator[tuple[Key
num_examples_to_skip -= 1
continue
yield f"{key}_{i}", pa_subtable
if self._state_dict:
if self._state_dict:
if self.is_batch_accumulate_arrow_table_function:
output_batch_size = len(output_table)
if output_batch_size > 0: # skip checkpoint if function is accumulating
self._state_dict["previous_state"] = self.ex_iterable.state_dict()
self._state_dict["num_examples_since_previous_state"] = 0
self._state_dict["previous_state_example_idx"] = current_idx
else:
self._state_dict["previous_state"] = self.ex_iterable.state_dict()
self._state_dict["num_examples_since_previous_state"] = 0
self._state_dict["previous_state_example_idx"] += len(pa_table)
if tables_accumulator:
pa_table = tables_accumulator.pop(-1)
indices = [current_idx + i for i in range(len(pa_table))]
function_args = (pa_table, indices)
output_table = self.function(
*function_args, **self.fn_kwargs, tables_accumulator=tables_accumulator, length=indices[-1] + 1
)
yield "last_batch_from_tables_accumulator", output_table

def shuffle_data_sources(self, generator: np.random.Generator) -> "MappedExamplesIterable":
"""Shuffle the wrapped examples iterable."""
Expand Down Expand Up @@ -3394,6 +3422,31 @@ def map(
{'label': 1, 'text': 'Review: effective but too-tepid biopic'}]
```
"""
return self._map(
function=function,
with_indices=with_indices,
input_columns=input_columns,
batched=batched,
batch_size=batch_size,
drop_last_batch=drop_last_batch,
remove_columns=remove_columns,
features=features,
fn_kwargs=fn_kwargs,
)

def _map(
self,
function: Optional[Callable] = None,
with_indices: bool = False,
input_columns: Optional[Union[str, list[str]]] = None,
batched: bool = False,
batch_size: Optional[int] = 1000,
drop_last_batch: bool = False,
remove_columns: Optional[Union[str, list[str]]] = None,
features: Optional[Features] = None,
fn_kwargs: Optional[dict] = None,
is_batch_accumulate_arrow_table_function: bool = False,
) -> "IterableDataset":
if isinstance(input_columns, str):
input_columns = [input_columns]
if isinstance(remove_columns, str):
Expand Down Expand Up @@ -3455,6 +3508,7 @@ def map(
fn_kwargs=fn_kwargs,
formatting=self._formatting,
features=features,
is_batch_accumulate_arrow_table_function=is_batch_accumulate_arrow_table_function,
)
info = self.info.copy()
info.features = features
Expand Down Expand Up @@ -4205,25 +4259,57 @@ def _resolve_features(self):
token_per_repo_id=self._token_per_repo_id,
)

def batch(self, batch_size: int, drop_last_batch: bool = False) -> "IterableDataset":
def batch(
self,
batch_size: Optional[int] = None,
by_column: Optional[Union[str, list[str]]] = None,
drop_last_batch: bool = False,
) -> "IterableDataset":
"""
Group samples from the dataset into batches.

Args:
batch_size (`int`): The number of samples in each batch.
drop_last_batch (`bool`, defaults to `False`): Whether to drop the last incomplete batch.
batch_size (`int`, optional):
The number of samples in each batch.
by_column (`Union[str, list[str]]`, *optional*):
The column used to batch examples together.
Successive examples with the same value for that column are grouped in the same batch.
This can also be a list of columns if you want to batch by multiple columns.
If batching by column, the batch_size is only used to control the size of internal batches
during accumulation.

<Added version="4.9.0"/>
drop_last_batch (`bool`, defaults to `False`):
Whether to drop the last incomplete batch.

Example:
```py
>>> ds = load_dataset("some_dataset", streaming=True)
>>> batched_ds = ds.batch(batch_size=32)
```
"""

if batch_size is None and by_column is None:
raise ValueError("IterableDataset.batch() misses `batch_size` or `by_column` argument.")
if self.features:
features = Features({col: List(feature) for col, feature in self.features.items()})
else:
features = None
if by_column is not None:
columns = [by_column] if isinstance(by_column, str) else by_column
ds = (
self.with_format("arrow")
._map(
partial(_batch_accumulate_arrow_table_by_columns, columns=columns),
with_indices=True,
batched=True,
batch_size=batch_size,
drop_last_batch=drop_last_batch,
features=features,
is_batch_accumulate_arrow_table_function=True,
)
.with_format(self._formatting.format_type if self._formatting else None)
)
return ds
if self._formatting and self._formatting.is_table:
return (
self.with_format("arrow")
Expand Down
47 changes: 47 additions & 0 deletions src/datasets/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,53 @@ def _batch_arrow_table(table: pa.Table) -> pa.Table:
return pa.Table.from_arrays(batched_columns, names=table.column_names)


def _batch_accumulate_arrow_table_by_columns(
    table: pa.Table,
    indices: list[int],
    columns: list[str],
    tables_accumulator: list[pa.Table],
    length: Optional[int],
) -> pa.Table:
    """Group successive rows that share the same value(s) in ``columns`` into list-typed batches.

    Meant to be used as a batched ``map`` function (with indices) on an Arrow-formatted dataset.
    Rows belonging to a group that spans several input batches are carried over between calls in
    ``tables_accumulator`` (mutated in place), so that each group ends up in exactly one output row.

    Args:
        table (`pa.Table`): the incoming batch of examples.
        indices (`list[int]`): dataset indices of the rows of ``table``; ``indices[-1] + 1 == length``
            signals the final batch.
        columns (`list[str]`): the column name(s) whose values define the groups.
        tables_accumulator (`list[pa.Table]`): carries the (possibly incomplete) last group between
            calls; mutated in place.
        length (`Optional[int]`): total number of examples, or ``None`` if unknown — in that case the
            caller is responsible for flushing the accumulator once iteration ends.

    Returns:
        `pa.Table`: the completed groups, one row per group, every column wrapped in a list type.
            May have zero rows while a group is still being accumulated.
    """
    # On the final batch (when the total length is known) we must flush everything,
    # including the trailing group; otherwise the trailing group stays in the accumulator.
    accumulate_last_batch = length is None or indices[-1] + 1 < length
    # keep accumulating if key is the same, otherwise include the accumulated tables
    if tables_accumulator and accumulate_last_batch:
        for column in columns:
            if any(pc.not_equal(table[column], tables_accumulator[0][column][0]).to_pylist()):
                break
        else:
            # Every row of this batch belongs to the group currently being accumulated:
            # keep its rows (BUGFIX: they were previously dropped, losing data whenever a
            # group spanned more than two internal batches) and emit an empty batched table.
            tables_accumulator.append(table)
            empty_batched_table = pa.Table.from_arrays(
                [
                    # offsets == [0] -> a 0-row ListArray; the values array is only used for its type
                    pa.ListArray.from_arrays(pa.array([0], type=pa.int32()), table[column].chunk(0))
                    for column in table.column_names
                ],
                names=table.column_names,
            )
            return empty_batched_table
    # merge the pending rows (if any) with the new batch and reset the accumulator
    table = pa.concat_tables(tables_accumulator + [table])
    tables_accumulator[:] = []
    if len(table) == 0:
        return table
    # cut the table per key, i.e. at the positions where the group columns change value
    if len(table) > 1:
        cut_array = pc.not_equal(table[columns[0]][1:], table[columns[0]][:-1])
        for column in columns[1:]:
            cut_array = pc.or_(cut_array, pc.not_equal(table[column][1:], table[column][:-1]))
    else:
        cut_array = pa.array([], type=pa.uint64())
    offsets = pc.indices_nonzero(cut_array)
    # make the batched table: offsets delimit the groups ([0, cut_1, cut_2, ...])
    offsets = pc.add(1, offsets)
    offsets = pa.concat_arrays([pa.array([0], type=pa.int32()), offsets.cast(pa.int32())])
    if not accumulate_last_batch:
        # flushing: close the trailing group explicitly
        offsets = pa.concat_arrays([offsets, pa.array([len(table)], type=pa.int32())])
    batched_columns = []
    for column_name in table.column_names:
        column = table[column_name].combine_chunks()
        batched_columns.append(pa.ListArray.from_arrays(offsets, column))
    batched_table = pa.Table.from_arrays(batched_columns, names=table.column_names)
    # keep the trailing group in the accumulator since it might be continued by the next batch
    if accumulate_last_batch:
        last_offset = offsets[-1].as_py()
        tables_accumulator.append(table.slice(last_offset, len(table) - last_offset))
    return batched_table


def _memory_mapped_arrow_table_from_file(filename: str) -> pa.Table:
opened_stream = _memory_mapped_record_batch_reader_from_file(filename)
pa_table = opened_stream.read_all()
Expand Down
Loading