diff --git a/docs/source/package_reference/main_classes.mdx b/docs/source/package_reference/main_classes.mdx index 84e651f9171..94354ee2043 100644 --- a/docs/source/package_reference/main_classes.mdx +++ b/docs/source/package_reference/main_classes.mdx @@ -176,6 +176,7 @@ The base class [`IterableDataset`] implements an iterable Dataset backed by pyth - skip - take - shard + - reshard - repeat - to_csv - to_pandas diff --git a/docs/source/stream.mdx b/docs/source/stream.mdx index b721b0959c4..5545e5f8ef4 100644 --- a/docs/source/stream.mdx +++ b/docs/source/stream.mdx @@ -182,7 +182,21 @@ IterableDataset({ }) ``` -If your dataset has `dataset.num_shards==1`, you should chunk it using [`IterableDataset.skip`] and [`IterableDataset.take`] instead. +To increase the number of shards of a dataset, you can use [`IterableDataset.reshard`]: + +```py +>>> dataset.reshard() +IterableDataset({ + features: ['label', 'title', 'content'], + num_shards: 3600 +}) +``` + +The resharding mechanism depends on the dataset file format. +For example for Parquet, it reshards using row groups instead of having one file per shard. +See how it works for every format in [`IterableDataset.reshard`]'s documentation. + +If your dataset has `dataset.num_shards==1` even after resharding, you should chunk it using [`IterableDataset.skip`] and [`IterableDataset.take`] instead. ## Interleave diff --git a/docs/source/use_with_pytorch.mdx b/docs/source/use_with_pytorch.mdx index 8612f35bd27..8e77b883805 100644 --- a/docs/source/use_with_pytorch.mdx +++ b/docs/source/use_with_pytorch.mdx @@ -255,3 +255,6 @@ then the shards are evenly assigned across the nodes, which is the most optimize Otherwise, each node keeps 1 example out of `world_size`, skipping the other examples. This can also be combined with a `torch.utils.data.DataLoader` if you want each node to use multiple workers to load the data. 
+ +> [!WARNING] +> If you shuffle your iterable dataset in a distributed setup, make sure to set a fixed `seed` in [`IterableDataset.shuffle`] so the same shuffled list of shards is used on every node to know which shards the node should skip. diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 363367a0dc0..13e99d3f7ab 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -1638,7 +1638,11 @@ def _download_and_prepare(self, dl_manager, verification_mode, **prepare_splits_ ) def _get_examples_iterable_for_split(self, split_generator: SplitGenerator) -> ExamplesIterable: - return ExamplesIterable(self._generate_examples, split_generator.gen_kwargs) + return ExamplesIterable( + self._generate_examples, + split_generator.gen_kwargs, + generate_more_kwargs_fn=getattr(self, "_generate_more_gen_kwargs", None), + ) class ArrowBasedBuilder(DatasetBuilder): @@ -1933,7 +1937,11 @@ def _prepare_split_single( ) def _get_examples_iterable_for_split(self, split_generator: SplitGenerator) -> ExamplesIterable: - return ArrowExamplesIterable(self._generate_tables, kwargs=split_generator.gen_kwargs) + return ArrowExamplesIterable( + self._generate_tables, + kwargs=split_generator.gen_kwargs, + generate_more_kwargs_fn=getattr(self, "_generate_more_gen_kwargs", None), + ) class _CountableBuilderMixin(DatasetBuilder): diff --git a/src/datasets/combine.py b/src/datasets/combine.py index 91a6457c02c..7b6c64b4cc1 100644 --- a/src/datasets/combine.py +++ b/src/datasets/combine.py @@ -39,8 +39,10 @@ def interleave_datasets( Note for iterable datasets: - In a distributed setup or in PyTorch DataLoader workers, the stopping strategy is applied per process. - Therefore the "first_exhausted" strategy on an sharded iterable dataset can generate less samples in total (up to 1 missing sample per subdataset per worker). + * The resulting dataset's `num_shards` is the minimum of each dataset's `num_shards` to ensure good parallelism. 
+ If some of your datasets have a very low number of shards, you may use [`IterableDataset.reshard`]. + * In a distributed setup or in PyTorch DataLoader workers, the stopping strategy is applied per process. + Therefore the "first_exhausted" strategy on an sharded iterable dataset can generate less samples in total (up to 1 missing sample per subdataset per worker). Args: datasets (`List[Dataset]` or `List[IterableDataset]`): @@ -170,10 +172,17 @@ def concatenate_datasets( axis: int = 0, ) -> DatasetType: """ - Converts a list of [`Dataset`] with the same schema into a single [`Dataset`]. + Concatenate several datasets (sources) into a single dataset. + + Use axis=0 to concatenate vertically (default), or axis=1 to concatenate horizontally. + + Note for iterable datasets: + + * if axis=0, the resulting dataset's `num_shards` is the sum of each dataset's `num_shards`. + * if axis=1, the resulting dataset has one (1) shard to not misalign data. Args: - dsets (`List[datasets.Dataset]`): + dsets (`List[datasets.Dataset]` or `List[datasets.IterableDataset]`): List of Datasets to concatenate. info (`DatasetInfo`, *optional*): Dataset information, like description, citation, etc. diff --git a/src/datasets/distributed.py b/src/datasets/distributed.py index 4697948f342..a9ca59c72ff 100644 --- a/src/datasets/distributed.py +++ b/src/datasets/distributed.py @@ -22,6 +22,10 @@ def split_dataset_by_node(dataset: DatasetType, rank: int, world_size: int) -> D then the shards are evenly assigned across the nodes, which is the most optimized. Otherwise, each node keeps 1 example out of `world_size`, skipping the other examples. + > [!WARNING] + > If you shuffle your iterable dataset in a distributed setup, make sure to set a fixed `seed` in [`IterableDataset.shuffle`] + so the same shuffled list of shards is used on every node to know which shards the node should skip. + Args: dataset ([`Dataset`] or [`IterableDataset`]): The dataset to split by node. 
diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index 9436fb89653..34dc983a1e7 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -179,6 +179,8 @@ def set_seed_recursively(ex_iterable): ex_iterable = ex_iterable.shift_rngs(value) if hasattr(ex_iterable, "ex_iterable"): ex_iterable.ex_iterable = set_seed_recursively(ex_iterable.ex_iterable) + if hasattr(ex_iterable, "ex_iterables"): + ex_iterable.ex_iterables = [set_seed_recursively(ei) for ei in ex_iterable.ex_iterables] return ex_iterable return set_seed_recursively(ex_iterable) @@ -217,6 +219,14 @@ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "_ """Either keep only the requested shard, or propagate the request to the underlying iterable.""" raise NotImplementedError(f"{type(self)} doesn't implement shard_data_sources yet") + def reshard_data_sources(self) -> "_BaseExamplesIterable": + """ + Either reshard the shards/sources of the dataset, i.e. further split the current shards into more shards, + or propagate the resharding to the underlying iterable. + If the examples iterable can't be further resharded, then this method returns self. 
+ """ + raise NotImplementedError(f"{type(self)} doesn't implement reshard_data_sources yet") + def split_shard_indices_by_worker(self, num_shards: int, index: int, contiguous=True) -> list[int]: if contiguous: div = self.num_shards // num_shards @@ -255,11 +265,19 @@ def state_dict(self) -> dict: class ExamplesIterable(_BaseExamplesIterable): - def __init__(self, generate_examples_fn: Callable[..., Iterator[tuple[Key, dict]]], kwargs: dict): + def __init__( + self, + generate_examples_fn: Callable[..., Iterator[tuple[Key, dict]]], + kwargs: dict, + generate_more_kwargs_fn: Optional[Callable[..., Iterator[dict]]] = None, + ): super().__init__() self.generate_examples_fn = generate_examples_fn self.kwargs = kwargs + # for resharding + self.generate_more_kwargs_fn = generate_more_kwargs_fn + def _init_state_dict(self) -> dict: self._state_dict = {"shard_idx": 0, "shard_example_idx": 0, "type": self.__class__.__name__} return self._state_dict @@ -277,74 +295,52 @@ def __iter__(self): self._state_dict["shard_example_idx"] = 0 def shuffle_data_sources(self, generator: np.random.Generator) -> "ExamplesIterable": - return ShuffledDataSourcesExamplesIterable(self.generate_examples_fn, self.kwargs, generator) + return ExamplesIterable( + self.generate_examples_fn, + _shuffle_gen_kwargs(copy.deepcopy(generator), self.kwargs), + self.generate_more_kwargs_fn, + ) def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "ExamplesIterable": """Keep only the requested shard.""" gen_kwargs_list = _split_gen_kwargs(self.kwargs, max_num_jobs=self.num_shards) shard_indices = self.split_shard_indices_by_worker(num_shards, index, contiguous=contiguous) requested_gen_kwargs = _merge_gen_kwargs([gen_kwargs_list[i] for i in shard_indices]) - return ExamplesIterable(self.generate_examples_fn, requested_gen_kwargs) + return ExamplesIterable(self.generate_examples_fn, requested_gen_kwargs, self.generate_more_kwargs_fn) + + def reshard_data_sources(self) -> 
"ExamplesIterable": + """Split shards into more shards if possible.""" + if not self.generate_more_kwargs_fn: + return ExamplesIterable(self.generate_examples_fn, self.kwargs, self.generate_more_kwargs_fn) + gen_kwargs_list = _split_gen_kwargs(self.kwargs, max_num_jobs=self.num_shards) + new_gen_kwargs = _merge_gen_kwargs( + [ + new_gen_kwargs + for gen_kwargs in gen_kwargs_list + for new_gen_kwargs in self.generate_more_kwargs_fn(**gen_kwargs) + ] + ) + return ExamplesIterable(self.generate_examples_fn, new_gen_kwargs, self.generate_more_kwargs_fn) @property def num_shards(self) -> int: return _number_of_shards_in_gen_kwargs(self.kwargs) -class ShuffledDataSourcesExamplesIterable(ExamplesIterable): +class ArrowExamplesIterable(_BaseExamplesIterable): def __init__( self, - generate_examples_fn: Callable[..., Iterator[tuple[Key, dict]]], + generate_tables_fn: Callable[..., Iterator[tuple[Key, pa.Table]]], kwargs: dict, - generator: np.random.Generator, + generate_more_kwargs_fn: Optional[Callable[..., Iterator[dict]]] = None, ): - super().__init__(generate_examples_fn, kwargs) - self.generator = deepcopy(generator) - - def shift_rngs(self, value: int) -> "_BaseExamplesIterable": - new_seed = self.generator.bit_generator.state["state"]["state"] + value - return ShuffledDataSourcesExamplesIterable( - self.generate_examples_fn, - self.kwargs, - np.random.default_rng(seed=new_seed), - ) - - def _init_state_dict(self) -> dict: - self._state_dict = {"shard_idx": 0, "shard_example_idx": 0, "type": self.__class__.__name__} - return self._state_dict - - def __iter__(self): - """Shuffle the kwargs order to shuffle shards""" - rng = deepcopy(self.generator) - kwargs_with_shuffled_shards = _shuffle_gen_kwargs(rng, self.kwargs) - shard_idx_start = self._state_dict["shard_idx"] if self._state_dict else 0 - for gen_kwags in islice( - _split_gen_kwargs(kwargs_with_shuffled_shards, max_num_jobs=self.num_shards), shard_idx_start, None - ): - shard_example_idx_start = 
self._state_dict["shard_example_idx"] if self._state_dict else 0 - for key_example in islice(self.generate_examples_fn(**gen_kwags), shard_example_idx_start, None): - if self._state_dict: - self._state_dict["shard_example_idx"] += 1 - yield key_example - if self._state_dict: - self._state_dict["shard_idx"] += 1 - self._state_dict["shard_example_idx"] = 0 - - def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "ExamplesIterable": - """Keep only the requested shard.""" - rng = deepcopy(self.generator) - kwargs_with_shuffled_shards = _shuffle_gen_kwargs(rng, self.kwargs) - return ExamplesIterable(self.generate_examples_fn, kwargs_with_shuffled_shards).shard_data_sources( - num_shards, index, contiguous=contiguous - ) - - -class ArrowExamplesIterable(_BaseExamplesIterable): - def __init__(self, generate_tables_fn: Callable[..., Iterator[tuple[Key, pa.Table]]], kwargs: dict): super().__init__() self.generate_tables_fn = generate_tables_fn self.kwargs = kwargs + # for resharding + self.generate_more_kwargs_fn = generate_more_kwargs_fn + @property def iter_arrow(self): return self._iter_arrow @@ -392,7 +388,9 @@ def _iter_arrow(self): self._state_dict["shard_example_idx"] = 0 def shuffle_data_sources(self, generator: np.random.Generator) -> "ArrowExamplesIterable": - return ShuffledDataSourcesArrowExamplesIterable(self.generate_tables_fn, self.kwargs, generator) + return ArrowExamplesIterable( + self.generate_tables_fn, _shuffle_gen_kwargs(copy.deepcopy(generator), self.kwargs), self.generate_more_kwargs_fn + ) def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "ArrowExamplesIterable": """Keep only the requested shard.""" @@ -401,89 +399,25 @@ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "A requested_gen_kwargs = _merge_gen_kwargs([gen_kwargs_list[i] for i in shard_indices]) return ArrowExamplesIterable(self.generate_tables_fn, requested_gen_kwargs) + def reshard_data_sources(self) -> 
"ArrowExamplesIterable": + """Split shards into more shards if possible.""" + if not self.generate_more_kwargs_fn: + return ArrowExamplesIterable(self.generate_tables_fn, self.kwargs, self.generate_more_kwargs_fn) + gen_kwargs_list = _split_gen_kwargs(self.kwargs, max_num_jobs=self.num_shards) + new_gen_kwargs = _merge_gen_kwargs( + [ + new_gen_kwargs + for gen_kwargs in gen_kwargs_list + for new_gen_kwargs in self.generate_more_kwargs_fn(**gen_kwargs) + ] + ) + return ArrowExamplesIterable(self.generate_tables_fn, new_gen_kwargs, self.generate_more_kwargs_fn) + @property def num_shards(self) -> int: return _number_of_shards_in_gen_kwargs(self.kwargs) -class ShuffledDataSourcesArrowExamplesIterable(ArrowExamplesIterable): - def __init__( - self, - generate_tables_fn: Callable[..., Iterator[tuple[Key, pa.Table]]], - kwargs: dict, - generator: np.random.Generator, - ): - super().__init__(generate_tables_fn, kwargs) - self.generator = deepcopy(generator) - - def shift_rngs(self, value: int) -> "_BaseExamplesIterable": - new_seed = self.generator.bit_generator.state["state"]["state"] + value - return ShuffledDataSourcesArrowExamplesIterable( - self.generate_examples_fn, - self.kwargs, - np.random.default_rng(seed=new_seed), - ) - - def _init_state_dict(self) -> dict: - self._state_dict = {"shard_idx": 0, "shard_example_idx": 0, "type": self.__class__.__name__} - return self._state_dict - - def __iter__(self): - """Shuffle the kwargs order to shuffle shards""" - rng = deepcopy(self.generator) - kwargs_with_shuffled_shards = _shuffle_gen_kwargs(rng, self.kwargs) - formatter = PythonFormatter() - shard_idx_start = self._state_dict["shard_idx"] if self._state_dict else 0 - for gen_kwags in islice( - _split_gen_kwargs(kwargs_with_shuffled_shards, max_num_jobs=self.num_shards), shard_idx_start, None - ): - shard_example_idx_start = self._state_dict["shard_example_idx"] if self._state_dict else 0 - shard_example_idx = 0 - for key, pa_table in 
self.generate_tables_fn(**gen_kwags): - if shard_example_idx + len(pa_table) <= shard_example_idx_start: - shard_example_idx += len(pa_table) - continue - for pa_subtable in pa_table.to_reader(max_chunksize=config.ARROW_READER_BATCH_SIZE_IN_DATASET_ITER): - formatted_batch = formatter.format_batch(pa_subtable) - for example in _batch_to_examples(formatted_batch): - if shard_example_idx >= shard_example_idx_start: - if self._state_dict: - self._state_dict["shard_example_idx"] += 1 - yield key, example - shard_example_idx += 1 - if self._state_dict: - self._state_dict["shard_idx"] += 1 - self._state_dict["shard_example_idx"] = 0 - - def _iter_arrow(self): - rng = deepcopy(self.generator) - kwargs_with_shuffled_shards = _shuffle_gen_kwargs(rng, self.kwargs) - shard_idx_start = self._state_dict["shard_idx"] if self._state_dict else 0 - for gen_kwags in islice( - _split_gen_kwargs(kwargs_with_shuffled_shards, max_num_jobs=self.num_shards), shard_idx_start, None - ): - shard_example_idx_start = self._state_dict["shard_example_idx"] if self._state_dict else 0 - shard_example_idx = 0 - for key, pa_table in self.generate_tables_fn(**gen_kwags): - shard_example_idx += len(pa_table) - if shard_example_idx <= shard_example_idx_start: - continue - if self._state_dict: - self._state_dict["shard_example_idx"] += len(pa_table) - yield key, pa_table - if self._state_dict: - self._state_dict["shard_idx"] += 1 - self._state_dict["shard_example_idx"] = 0 - - def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "ArrowExamplesIterable": - """Keep only the requested shard.""" - rng = deepcopy(self.generator) - kwargs_with_shuffled_shards = _shuffle_gen_kwargs(rng, self.kwargs) - return ArrowExamplesIterable(self.generate_tables_fn, kwargs_with_shuffled_shards).shard_data_sources( - num_shards, index, contiguous=contiguous - ) - - class RebatchedArrowExamplesIterable(_BaseExamplesIterable): def __init__(self, ex_iterable: _BaseExamplesIterable, batch_size: 
Optional[int], drop_last_batch: bool = False): super().__init__() @@ -615,6 +549,13 @@ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "R self.drop_last_batch, ) + def reshard_data_sources(self) -> "RebatchedArrowExamplesIterable": + return RebatchedArrowExamplesIterable( + self.ex_iterable.reshard_data_sources(), + self.batch_size, + self.drop_last_batch, + ) + @property def num_shards(self) -> int: return self.ex_iterable.num_shards @@ -660,6 +601,9 @@ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "S self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), self.column_names ) + def reshard_data_sources(self) -> "SelectColumnsIterable": + return SelectColumnsIterable(self.ex_iterable.reshard_data_sources(), self.column_names) + @property def num_shards(self) -> int: return self.ex_iterable.num_shards @@ -706,6 +650,13 @@ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "S offset=self.offset, ) + def reshard_data_sources(self) -> "StepExamplesIterable": + return StepExamplesIterable( + self.ex_iterable.reshard_data_sources(), + step=self.step, + offset=self.offset, + ) + @property def num_shards(self) -> int: return self.ex_iterable.num_shards @@ -887,6 +838,12 @@ def shard_data_sources( stopping_strategy=self.stopping_strategy, ) + def reshard_data_sources(self) -> "CyclingMultiSourcesExamplesIterable": + return CyclingMultiSourcesExamplesIterable( + [iterable.reshard_data_sources() for iterable in self.ex_iterables], + stopping_strategy=self.stopping_strategy, + ) + class VerticallyConcatenatedMultiSourcesExamplesIterable(_BaseExamplesIterable): """ @@ -943,23 +900,37 @@ def _iter_arrow(self): def shuffle_data_sources( self, generator: np.random.Generator ) -> "VerticallyConcatenatedMultiSourcesExamplesIterable": - """Shuffle the list of examples iterable, as well as each underlying examples iterable.""" + """Shuffle all shards.""" rng = 
deepcopy(generator) - ex_iterables = list(self.ex_iterables) - rng.shuffle(ex_iterables) - ex_iterables = [ex_iterable.shuffle_data_sources(generator) for ex_iterable in ex_iterables] - return VerticallyConcatenatedMultiSourcesExamplesIterable(ex_iterables) + single_shard_ex_iterables = [ + ex_iterable.shard_data_sources(num_shards=ex_iterable.num_shards, index=index) + for ex_iterable in self.ex_iterables + for index in range(ex_iterable.num_shards) + ] + rng.shuffle(single_shard_ex_iterables) + return VerticallyConcatenatedMultiSourcesExamplesIterable(single_shard_ex_iterables) @property def num_shards(self) -> int: - return min(ex_iterable.num_shards for ex_iterable in self.ex_iterables) + return sum(ex_iterable.num_shards for ex_iterable in self.ex_iterables) def shard_data_sources( self, num_shards: int, index: int, contiguous=True ) -> "VerticallyConcatenatedMultiSourcesExamplesIterable": - """Either keep only the requested shard, or propagate the request to the underlying iterable.""" + """Keep only the requested shard""" + single_shard_ex_iterables = [ + ex_iterable.shard_data_sources(num_shards=ex_iterable.num_shards, index=index) + for ex_iterable in self.ex_iterables + for index in range(ex_iterable.num_shards) + ] + shard_indices = self.split_shard_indices_by_worker(num_shards, index, contiguous=contiguous) + return VerticallyConcatenatedMultiSourcesExamplesIterable( + [single_shard_ex_iterables[i] for i in shard_indices] + ) + + def reshard_data_sources(self) -> "VerticallyConcatenatedMultiSourcesExamplesIterable": return VerticallyConcatenatedMultiSourcesExamplesIterable( - [iterable.shard_data_sources(num_shards, index, contiguous=contiguous) for iterable in self.ex_iterables] + [iterable.reshard_data_sources() for iterable in self.ex_iterables] ) @@ -1045,10 +1016,12 @@ def num_shards(self) -> int: def shard_data_sources( self, num_shards: int, index: int, contiguous=True ) -> "HorizontallyConcatenatedMultiSourcesExamplesIterable": - """Either keep 
only the requested shard, or propagate the request to the underlying iterable.""" - return HorizontallyConcatenatedMultiSourcesExamplesIterable( - [iterable.shard_data_sources(num_shards, index, contiguous=contiguous) for iterable in self.ex_iterables] - ) + """Doesn't shard the wrapped examples iterable since it would break the alignment between them.""" + return self + + def reshard_data_sources(self) -> "HorizontallyConcatenatedMultiSourcesExamplesIterable": + """Doesn't reshard the wrapped examples iterable since it would break the alignment between them.""" + return self class RandomlyCyclingMultiSourcesExamplesIterable(CyclingMultiSourcesExamplesIterable): @@ -1066,7 +1039,8 @@ def __init__( self.probabilities = probabilities def shift_rngs(self, value: int) -> "_BaseExamplesIterable": - new_seed = self.generator.bit_generator.state["state"]["state"] + value + rng = deepcopy(self.generator) + new_seed = rng.integers(0, 1 << 63) - value return RandomlyCyclingMultiSourcesExamplesIterable( ex_iterables=self.ex_iterables, generator=np.random.default_rng(seed=new_seed), @@ -1164,6 +1138,15 @@ def shard_data_sources( self.stopping_strategy, ) + def reshard_data_sources(self) -> "RandomlyCyclingMultiSourcesExamplesIterable": + """Either keep only the requested shard, or propagate the request to the underlying iterable.""" + return RandomlyCyclingMultiSourcesExamplesIterable( + [iterable.reshard_data_sources() for iterable in self.ex_iterables], + self.generator, + self.probabilities, + self.stopping_strategy, + ) + def _table_output_to_arrow(output) -> pa.Table: if isinstance(output, pa.Table): @@ -1557,6 +1540,22 @@ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "M max_num_running_async_map_functions_in_parallel=self.max_num_running_async_map_functions_in_parallel, ) + def reshard_data_sources(self) -> "MappedExamplesIterable": + return MappedExamplesIterable( + self.ex_iterable.reshard_data_sources(), + function=self.function, + 
with_indices=self.with_indices, + input_columns=self.input_columns, + batched=self.batched, + batch_size=self.batch_size, + drop_last_batch=self.drop_last_batch, + remove_columns=self.remove_columns, + fn_kwargs=self.fn_kwargs, + formatting=self.formatting, + features=self.features, + max_num_running_async_map_functions_in_parallel=self.max_num_running_async_map_functions_in_parallel, + ) + @property def num_shards(self) -> int: return self.ex_iterable.num_shards @@ -1659,6 +1658,18 @@ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "F formatting=self.formatting, ) + def reshard_data_sources(self) -> "FilteredExamplesIterable": + return FilteredExamplesIterable( + self.ex_iterable.reshard_data_sources(), + function=self.mask_function, + with_indices=self.with_indices, + input_columns=self.input_columns, + batched=self.batched, + batch_size=self.batch_size, + fn_kwargs=self.fn_kwargs, + formatting=self.formatting, + ) + @property def num_shards(self) -> int: return self.ex_iterable.num_shards @@ -1672,7 +1683,8 @@ def __init__(self, ex_iterable: _BaseExamplesIterable, buffer_size: int, generat self.generator = generator def shift_rngs(self, value: int) -> "_BaseExamplesIterable": - new_seed = self.generator.bit_generator.state["state"]["state"] + value + rng = deepcopy(self.generator) + new_seed = rng.integers(0, 1 << 63) - value return BufferShuffledExamplesIterable( ex_iterable=self.ex_iterable, buffer_size=self.buffer_size, @@ -1747,7 +1759,7 @@ def _iter_arrow(self): def shuffle_data_sources(self, generator: np.random.Generator) -> "BufferShuffledExamplesIterable": """Shuffle the wrapped examples iterable as well as the shuffling buffer.""" return BufferShuffledExamplesIterable( - self.ex_iterable.shuffle_data_sources(generator), buffer_size=self.buffer_size, generator=generator + self.ex_iterable.shuffle_data_sources(generator), buffer_size=self.buffer_size, generator=self.generator ) def shard_data_sources(self, num_shards: int, 
index: int, contiguous=True) -> "BufferShuffledExamplesIterable": @@ -1758,6 +1770,13 @@ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "B generator=self.generator, ) + def reshard_data_sources(self) -> "BufferShuffledExamplesIterable": + return BufferShuffledExamplesIterable( + self.ex_iterable.reshard_data_sources(), + buffer_size=self.buffer_size, + generator=self.generator, + ) + @property def num_shards(self) -> int: return self.ex_iterable.num_shards @@ -1833,6 +1852,14 @@ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "S else: return self + def reshard_data_sources(self) -> "SkipExamplesIterable": + return SkipExamplesIterable( + self.ex_iterable.reshard_data_sources(), + n=self.n, + block_sources_order_when_shuffling=self.block_sources_order_when_shuffling, + split_when_sharding=self.split_when_sharding, + ) + @property def num_shards(self) -> int: return self.ex_iterable.num_shards @@ -1882,6 +1909,12 @@ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "R num_times=self.num_times, ) + def reshard_data_sources(self) -> "RepeatExamplesIterable": + return RepeatExamplesIterable( + self.ex_iterable.reshard_data_sources(), + num_times=self.num_times, + ) + @property def num_shards(self) -> int: return self.ex_iterable.num_shards @@ -1963,6 +1996,14 @@ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "T split_when_sharding=self.split_when_sharding, ) + def reshard_data_sources(self) -> "TakeExamplesIterable": + return TakeExamplesIterable( + self.ex_iterable.reshard_data_sources(), + n=self.n, + block_sources_order_when_shuffling=self.block_sources_order_when_shuffling, + split_when_sharding=self.split_when_sharding, + ) + @property def num_shards(self) -> int: return self.ex_iterable.num_shards @@ -2110,17 +2151,19 @@ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "F formatting=self.formatting, ) + def 
reshard_data_sources(self) -> "FormattedExamplesIterable": + return FormattedExamplesIterable( + self.ex_iterable.reshard_data_sources(), + features=self.features, + token_per_repo_id=self.token_per_repo_id, + formatting=self.formatting, + ) + @property def num_shards(self) -> int: return self.ex_iterable.num_shards -@dataclass -class ShufflingConfig: - generator: np.random.Generator - _original_seed: Optional[int] = None - - @dataclass class DistributedConfig: rank: int @@ -2190,22 +2233,14 @@ def __init__( self, info: Optional[DatasetInfo] = None, split: Optional[NamedSplit] = None, formatting: Optional[FormattingConfig] = None, - shuffling: Optional[ShufflingConfig] = None, distributed: Optional[DistributedConfig] = None, token_per_repo_id: Optional[dict[str, Union[str, bool, None]]] = None, ): - if distributed and distributed.world_size > 1 and shuffling and shuffling._original_seed is None: - raise RuntimeError( - "The dataset doesn't have a fixed random seed across nodes to shuffle and split the list of dataset shards by node. " - "Please pass e.g. `seed=42` in `.shuffle()` to make all the nodes use the same seed. 
" - ) - info = info.copy() if info is not None else DatasetInfo() DatasetInfoMixin.__init__(self, info=info, split=split) self._ex_iterable = copy.copy(ex_iterable) self._formatting = formatting - self._shuffling = shuffling self._distributed = distributed self._token_per_repo_id: dict[str, Union[str, bool, None]] = token_per_repo_id or {} self._epoch: Union[int, "torch.Tensor"] = _maybe_share_with_torch_persistent_workers(0) @@ -2371,17 +2406,6 @@ def _head(self, n=5): def epoch(self) -> int: return int(self._epoch) - def _effective_generator(self): - if self._shuffling and self.epoch == 0: - return self._shuffling.generator - elif self._shuffling: - # Create effective seed using self.epoch (we subtract in order to avoir overflow in long_scalars) - effective_seed = deepcopy(self._shuffling.generator).integers(0, 1 << 63) - self.epoch - effective_seed = (1 << 63) + effective_seed if effective_seed < 0 else effective_seed - return np.random.default_rng(effective_seed) - else: - raise ValueError("This dataset is not shuffled") - @property def num_shards(self) -> int: if self._distributed and self._ex_iterable.num_shards % self._distributed.world_size == 0: @@ -2409,7 +2433,7 @@ def _iter_pytorch(self): logger.info( f"To parallelize data loading, we give each process some shards (or data sources) to process. " f"Therefore it's unnecessary to have a number of workers greater than dataset.num_shards={ex_iterable.num_shards}. " - f"To enable more parallelism, please split the dataset in more files than {ex_iterable.num_shards}." + f"To enable more parallelism, please split the dataset in more files than {ex_iterable.num_shards} or try `dataset = dataset.reshard()` which may increase `num_shards` depending on the dataset file format." 
) # split workload _log_prefix = f"node#{self._distributed.rank} " if self._distributed else "" @@ -2475,10 +2499,9 @@ def _prepare_ex_iterable_for_iteration( ex_iterable = RebatchedArrowExamplesIterable( ex_iterable, batch_size=batch_size, drop_last_batch=drop_last_batch ) - if self._shuffling: - ex_iterable = ex_iterable.shuffle_data_sources(self._effective_generator()) - else: - ex_iterable = ex_iterable + if self.epoch: + ex_iterable = ex_iterable.shuffle_data_sources(np.random.default_rng(self.epoch)) + ex_iterable = shift_ex_examples_rngs(ex_iterable, self.epoch) if self._distributed: rank = self._distributed.rank @@ -2749,7 +2772,6 @@ def with_format( info=self._info.copy(), split=self._split, formatting=FormattingConfig(format_type=type), - shuffling=copy.deepcopy(self._shuffling), distributed=copy.deepcopy(self._distributed), token_per_repo_id=self._token_per_repo_id, ) @@ -2901,7 +2923,6 @@ def map( info=info, split=self._split, formatting=self._formatting, - shuffling=copy.deepcopy(self._shuffling), distributed=copy.deepcopy(self._distributed), token_per_repo_id=self._token_per_repo_id, ) @@ -2987,7 +3008,6 @@ def filter( info=self._info, split=self._split, formatting=self._formatting, - shuffling=copy.deepcopy(self._shuffling), distributed=copy.deepcopy(self._distributed), token_per_repo_id=self._token_per_repo_id, ) @@ -3046,19 +3066,17 @@ def shuffle( generator = np.random.default_rng(seed) else: generator = deepcopy(generator) - shuffling = ShufflingConfig(generator=generator, _original_seed=seed) return IterableDataset( BufferShuffledExamplesIterable( - RebatchedArrowExamplesIterable(self._ex_iterable, batch_size=1) + RebatchedArrowExamplesIterable(self._ex_iterable.shuffle_data_sources(generator), batch_size=1) if self._ex_iterable.iter_arrow - else self._ex_iterable, + else self._ex_iterable.shuffle_data_sources(generator), buffer_size=buffer_size, generator=generator, ), info=self._info.copy(), split=self._split, formatting=self._formatting, - 
shuffling=shuffling, distributed=copy.deepcopy(self._distributed), token_per_repo_id=self._token_per_repo_id, ) @@ -3097,7 +3115,6 @@ def skip(self, n: int) -> "IterableDataset": ex_iterable = SkipExamplesIterable( self._ex_iterable, n, - block_sources_order_when_shuffling=self._shuffling is None, split_when_sharding=self._distributed is None, ) return IterableDataset( @@ -3105,7 +3122,6 @@ def skip(self, n: int) -> "IterableDataset": info=self._info.copy(), split=self._split, formatting=self._formatting, - shuffling=copy.deepcopy(self._shuffling), distributed=copy.deepcopy(self._distributed), token_per_repo_id=self._token_per_repo_id, ) @@ -3148,7 +3164,6 @@ def repeat(self, num_times: Optional[int]) -> "IterableDataset": info=self._info, split=self._split, formatting=self._formatting, - shuffling=copy.deepcopy(self._shuffling), distributed=copy.deepcopy(self._distributed), token_per_repo_id=self._token_per_repo_id, ) @@ -3177,7 +3192,6 @@ def take(self, n: int) -> "IterableDataset": ex_iterable = TakeExamplesIterable( self._ex_iterable, n, - block_sources_order_when_shuffling=self._shuffling is None, split_when_sharding=self._distributed is None, ) return IterableDataset( @@ -3185,7 +3199,6 @@ def take(self, n: int) -> "IterableDataset": info=self._info.copy(), split=self._split, formatting=self._formatting, - shuffling=copy.deepcopy(self._shuffling), distributed=copy.deepcopy(self._distributed), token_per_repo_id=self._token_per_repo_id, ) @@ -3225,12 +3238,12 @@ def shard( >>> from datasets import load_dataset >>> ds = load_dataset("fancyzhx/amazon_polarity", split="train", streaming=True) >>> ds - Dataset({ + IterableDataset({ features: ['label', 'title', 'content'], num_shards: 4 }) >>> ds.shard(num_shards=2, index=0) - Dataset({ + IterableDataset({ features: ['label', 'title', 'content'], num_shards: 2 }) @@ -3242,7 +3255,46 @@ def shard( info=self._info.copy(), split=self._split, formatting=self._formatting, - shuffling=copy.deepcopy(self._shuffling), + 
distributed=copy.deepcopy(self._distributed), + token_per_repo_id=self._token_per_repo_id, + ) + + def reshard(self) -> "IterableDataset": + """Reshard the dataset if possible, i.e. split the current shards further into more shards. + This increases the number of shards and the resulting dataset has num_shards >= previous_num_shards. + Equality may happen if no shard can be split further. + + The resharding mechanism depends on the dataset file format: + + * Parquet: shard per row group instead of per file + * Other: not implemented yet (contributions are welcome !) + + Be sure to reshard/shard before using any randomizing operator (such as `shuffle`). + It is best if the shard operator is used early in the dataset pipeline. + + Example: + + ```py + >>> from datasets import load_dataset + >>> ds = load_dataset("fancyzhx/amazon_polarity", split="train", streaming=True) + >>> ds + IterableDataset({ + features: ['label', 'title', 'content'], + num_shards: 4 + }) + >>> ds.reshard() + IterableDataset({ + features: ['label', 'title', 'content'], + num_shards: 3600 + }) + ``` + """ + ex_iterable = self._ex_iterable.reshard_data_sources() + return IterableDataset( + ex_iterable=ex_iterable, + info=self._info.copy(), + split=self._split, + formatting=self._formatting, distributed=copy.deepcopy(self._distributed), token_per_repo_id=self._token_per_repo_id, ) @@ -3395,7 +3447,6 @@ def select_columns(self, column_names: Union[str, list[str]]) -> "IterableDatase info=info, split=self._split, formatting=self._formatting, - shuffling=self._shuffling, distributed=self._distributed, token_per_repo_id=self._token_per_repo_id, ) @@ -3442,7 +3493,6 @@ def cast_column(self, column: str, feature: FeatureType) -> "IterableDataset": info=info, split=self._split, formatting=self._formatting, - shuffling=copy.deepcopy(self._shuffling), distributed=copy.deepcopy(self._distributed), token_per_repo_id=self._token_per_repo_id, ) @@ -3489,7 +3539,6 @@ def cast( info=info, split=self._split, 
formatting=self._formatting, - shuffling=copy.deepcopy(self._shuffling), distributed=copy.deepcopy(self._distributed), token_per_repo_id=self._token_per_repo_id, ) @@ -3594,7 +3643,6 @@ def _step(self, step: int, offset: int) -> "IterableDataset": info=self._info.copy(), split=self._split, formatting=self._formatting, - shuffling=copy.deepcopy(self._shuffling), distributed=copy.deepcopy(self._distributed), token_per_repo_id=self._token_per_repo_id, ) @@ -3613,7 +3661,6 @@ def _resolve_features(self): info=info, split=self._split, formatting=self._formatting, - shuffling=copy.deepcopy(self._shuffling), distributed=copy.deepcopy(self._distributed), token_per_repo_id=self._token_per_repo_id, ) @@ -4627,6 +4674,11 @@ def _interleave_iterable_datasets( # Perform checks _check_if_features_can_be_aligned([dset.features for dset in datasets]) + for i, dset in enumerate(datasets): + if datasets[0]._distributed != dset._distributed: + raise ValueError( + f"Datasets should be identically split_by_node before interleaving, but got {datasets[0]._distributed}!={dset._distributed} at index 0 and {i}" + ) # TODO: improve this to account for a mix of ClassLabel and Value for example # right now it would keep the type of the first dataset in the list @@ -4660,7 +4712,13 @@ def _interleave_iterable_datasets( repo_id: token for dataset in datasets for repo_id, token in dataset._token_per_repo_id.items() } # Return new daset - return IterableDataset(ex_iterable=ex_iterable, info=info, split=split, token_per_repo_id=token_per_repo_id) + return IterableDataset( + ex_iterable=ex_iterable, + info=info, + split=split, + token_per_repo_id=token_per_repo_id, + distributed=datasets[0]._distributed, + ) def _split_by_node_iterable_dataset(dataset: IterableDataset, rank: int, world_size: int) -> IterableDataset: @@ -4691,7 +4749,6 @@ def _split_by_node_iterable_dataset(dataset: IterableDataset, rank: int, world_s info=dataset._info.copy(), split=dataset._split, formatting=dataset._formatting, - 
shuffling=copy.deepcopy(dataset._shuffling), distributed=distributed, token_per_repo_id=dataset._token_per_repo_id, ) diff --git a/src/datasets/packaged_modules/parquet/parquet.py b/src/datasets/packaged_modules/parquet/parquet.py index 62d76c8c64e..15f6b01008b 100644 --- a/src/datasets/packaged_modules/parquet/parquet.py +++ b/src/datasets/packaged_modules/parquet/parquet.py @@ -129,7 +129,11 @@ def _split_generators(self, dl_manager): raise ValueError( f"At least one valid data file must be specified, all the data_files are invalid: {self.config.data_files}" ) - splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files})) + splits.append( + datasets.SplitGenerator( + name=split_name, gen_kwargs={"files": files, "row_groups_list": [None] * len(files)} + ) + ) if self.config.columns is not None and set(self.config.columns) != set(self.info.features): self.info.features = datasets.Features( {col: feat for col, feat in self.info.features.items() if col in self.config.columns} @@ -143,10 +147,33 @@ def _cast_table(self, pa_table: pa.Table) -> pa.Table: pa_table = table_cast(pa_table, self.info.features.arrow_schema) return pa_table - def _generate_shards(self, files): - yield from files - - def _generate_tables(self, files): + def _generate_shards(self, files, row_groups_list): + if not row_groups_list: + yield from files + else: + for file, row_groups in zip(files, row_groups_list): + yield { + "fragment_data_file": file, + "fragment_row_groups": row_groups, + } + + def _generate_more_gen_kwargs(self, files, row_groups_list): + if not row_groups_list: + parquet_file_format = ds.ParquetFileFormat(default_fragment_scan_options=self.config.fragment_scan_options) + for file in files: + with open(file, "rb") as f: + parquet_fragment = parquet_file_format.make_fragment(f) + yield { + "files": [file] * parquet_fragment.num_row_groups, + "row_groups_list": [ + (row_group_id,) for row_group_id in range(parquet_fragment.num_row_groups) + ], + } + else: 
+ for file, row_groups in zip(files, row_groups_list): + yield {"files": [file], "row_groups_list": [row_groups]} + + def _generate_tables(self, files, row_groups_list): if self.config.features is not None and self.config.columns is not None: if sorted(field.name for field in self.info.features.arrow_schema) != sorted(self.config.columns): raise ValueError( @@ -158,10 +185,12 @@ def _generate_tables(self, files): else self.config.filters ) parquet_file_format = ds.ParquetFileFormat(default_fragment_scan_options=self.config.fragment_scan_options) - for file_idx, file in enumerate(files): + for file_idx, (file, row_groups) in enumerate(zip(files, row_groups_list)): try: with open(file, "rb") as f: parquet_fragment = parquet_file_format.make_fragment(f) + if row_groups is not None: + parquet_fragment.subset(row_group_ids=row_groups) if parquet_fragment.row_groups: batch_size = self.config.batch_size or parquet_fragment.row_groups[0].num_rows for batch_idx, record_batch in enumerate( diff --git a/tests/test_distributed.py b/tests/test_distributed.py index 65d2130f753..fbd3a90aed9 100644 --- a/tests/test_distributed.py +++ b/tests/test_distributed.py @@ -85,13 +85,9 @@ def gen(): ds_rank0 = split_dataset_by_node(full_ds, rank=0, world_size=world_size).shuffle(seed=42) assert len(list(ds_rank0)) == 1 + full_size // world_size - with pytest.raises(RuntimeError): - split_dataset_by_node(full_ds, rank=0, world_size=world_size).shuffle() ds_rank0 = split_dataset_by_node(full_ds.shuffle(seed=42), rank=0, world_size=world_size) assert len(list(ds_rank0)) == 1 + full_size // world_size - with pytest.raises(RuntimeError): - split_dataset_by_node(full_ds.shuffle(), rank=0, world_size=world_size) @pytest.mark.parametrize("streaming", [False, True]) diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index bdfa60fdc01..b275cbf3165 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -40,9 +40,6 @@ 
RebatchedArrowExamplesIterable, RepeatExamplesIterable, SelectColumnsIterable, - ShuffledDataSourcesArrowExamplesIterable, - ShuffledDataSourcesExamplesIterable, - ShufflingConfig, SkipExamplesIterable, StepExamplesIterable, TakeExamplesIterable, @@ -55,7 +52,6 @@ from .utils import ( assert_arrow_memory_doesnt_increase, - is_rng_equal, require_dill_gt_0_3_2, require_jax, require_not_windows, @@ -1301,7 +1297,6 @@ def test_horizontally_concatenated_examples_iterable(): "ex_iterable", [ ExamplesIterable(generate_examples_fn, {}), - ShuffledDataSourcesExamplesIterable(generate_examples_fn, {}, np.random.default_rng(42)), SelectColumnsIterable(ExamplesIterable(generate_examples_fn, {}), ["id"]), StepExamplesIterable(ExamplesIterable(generate_examples_fn, {}), 2, 0), CyclingMultiSourcesExamplesIterable([ExamplesIterable(generate_examples_fn, {})]), @@ -1332,7 +1327,6 @@ def test_no_iter_arrow(ex_iterable: _BaseExamplesIterable): "ex_iterable", [ ArrowExamplesIterable(generate_tables_fn, {}), - ShuffledDataSourcesArrowExamplesIterable(generate_tables_fn, {}, np.random.default_rng(42)), SelectColumnsIterable(ArrowExamplesIterable(generate_tables_fn, {}), ["id"]), # StepExamplesIterable(ArrowExamplesIterable(generate_tables_fn, {}), 2, 0), # not implemented # CyclingMultiSourcesExamplesIterable([ArrowExamplesIterable(generate_tables_fn, {})]), # not implemented @@ -1671,24 +1665,6 @@ def test_iterable_dataset_set_epoch_resuming(dataset: IterableDataset): assert len(list(dataset)) == 0 -@pytest.mark.parametrize("seed", [None, 42, 1337]) -@pytest.mark.parametrize("epoch", [None, 0, 1, 10]) -def test_iterable_dataset_set_epoch_of_shuffled_dataset(dataset: IterableDataset, seed, epoch): - buffer_size = 10 - shuffled_dataset = dataset.shuffle(seed, buffer_size=buffer_size) - base_generator = shuffled_dataset._shuffling.generator - if epoch is not None: - shuffled_dataset.set_epoch(epoch) - effective_generator = shuffled_dataset._effective_generator() - assert 
effective_generator is not None - if epoch is None or epoch == 0: - assert is_rng_equal(base_generator, shuffled_dataset._effective_generator()) - else: - assert not is_rng_equal(base_generator, shuffled_dataset._effective_generator()) - effective_seed = deepcopy(base_generator).integers(0, 1 << 63) - epoch - assert is_rng_equal(np.random.default_rng(effective_seed), shuffled_dataset._effective_generator()) - - def test_iterable_dataset_map( dataset: IterableDataset, ): @@ -1785,10 +1761,7 @@ def test_iterable_dataset_shuffle(dataset: IterableDataset, seed, epoch): dataset = deepcopy(dataset) dataset._ex_iterable.kwargs["filepaths"] = ["0.txt", "1.txt"] dataset = dataset.shuffle(seed, buffer_size=buffer_size) - assert isinstance(dataset._shuffling, ShufflingConfig) - assert isinstance(dataset._shuffling.generator, np.random.Generator) - assert is_rng_equal(dataset._shuffling.generator, np.random.default_rng(seed)) - # Effective seed is sum of seed and epoch + # Effective seed is mix of seed and epoch if epoch is None or epoch == 0: effective_seed = seed else: @@ -1802,7 +1775,9 @@ def test_iterable_dataset_shuffle(dataset: IterableDataset, seed, epoch): # It also shuffles the underlying examples iterable expected_ex_iterable = ExamplesIterable( generate_examples_fn, {"filepaths": ["0.txt", "1.txt"]} - ).shuffle_data_sources(np.random.default_rng(effective_seed)) + ).shuffle_data_sources(np.random.default_rng(seed)) + if epoch: + expected_ex_iterable = expected_ex_iterable.shuffle_data_sources(np.random.default_rng(epoch)) assert isinstance(dataset._ex_iterable.ex_iterable, ExamplesIterable) assert next(iter(dataset)) == list(islice(expected_ex_iterable, expected_first_example_index + 1))[-1][1]