From 4802548cbd4612832a64fca1680693bf92c676c3 Mon Sep 17 00:00:00 2001 From: Harry Yang Date: Sun, 27 Apr 2025 16:39:36 -0400 Subject: [PATCH] Fix IterableDataset state_dict shard_example_idx counting --- src/datasets/iterable_dataset.py | 32 ++++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index 423abeeb56d..5f89a2b2b40 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -317,16 +317,40 @@ def __iter__(self): def _iter_arrow(self): shard_idx_start = self._state_dict["shard_idx"] if self._state_dict else 0 - for gen_kwags in islice(_split_gen_kwargs(self.kwargs, max_num_jobs=self.num_shards), shard_idx_start, None): + kwargs_with_shuffled_shards = ( + _shuffle_gen_kwargs(self.generator, self.kwargs) if hasattr(self, "generator") else self.kwargs + ) + + for gen_kwags in islice( + _split_gen_kwargs(kwargs_with_shuffled_shards, max_num_jobs=self.num_shards), shard_idx_start, None + ): shard_example_idx_start = self._state_dict["shard_example_idx"] if self._state_dict else 0 shard_example_idx = 0 + + examples_seen_in_current_shard = 0 + for key, pa_table in self.generate_tables_fn(**gen_kwags): - shard_example_idx += len(pa_table) - if shard_example_idx <= shard_example_idx_start: + batch_size = len(pa_table) + + if shard_example_idx + batch_size <= shard_example_idx_start: + shard_example_idx += batch_size continue + + if shard_example_idx < shard_example_idx_start: + offset = shard_example_idx_start - shard_example_idx + pa_table = pa_table.slice(offset) + examples_seen_in_current_shard = offset + if self._state_dict: - self._state_dict["shard_example_idx"] += len(pa_table) + examples_in_current_batch = len(pa_table) + self._state_dict["shard_example_idx"] = ( + shard_example_idx_start + examples_seen_in_current_shard + examples_in_current_batch + ) + yield key, pa_table + shard_example_idx += batch_size + examples_seen_in_current_shard += len(pa_table) + if self._state_dict: self._state_dict["shard_idx"] += 1 self._state_dict["shard_example_idx"] = 0