Skip to content

Fix IterableDataset state_dict shard_example_idx counting #7539

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

Harry-Yang0518
Copy link

Fix IterableDataset's state_dict shard_example_idx reporting

Description

This PR fixes issue #7475 where the shard_example_idx value in IterableDataset's state_dict() always equals the number of samples in a shard, even if only a few examples have been consumed.

The issue is in the _iter_arrow method of the ArrowExamplesIterable class where it updates the shard_example_idx state by the full length of the batch (len(pa_table)) even when we're only partway through processing the examples.

Changes

Modified the _iter_arrow method of ArrowExamplesIterable to:

  1. Track the actual number of examples processed
  2. Only increment the shard_example_idx by the number of examples actually yielded
  3. Handle partial batches correctly

How to Test

I've included a simple test case that demonstrates the fix:

from datasets import Dataset

# Create a test dataset
ds = Dataset.from_dict({"a": range(6)}).to_iterable_dataset(num_shards=1)

# Iterate through part of the dataset
for idx, example in enumerate(ds):
    print(example)
    if idx == 2:  # Stop after 3 examples (0, 1, 2)
        state_dict = ds.state_dict()
        print("Checkpoint state_dict:", state_dict)
        break

# Before the fix, the output would show shard_example_idx: 6
# After the fix, it shows shard_example_idx: 3, correctly reflecting the 3 processed examples

Implementation Details

  1. Added logic to track the number of examples actually seen in the current shard
  2. Modified the state update to only count examples actually yielded
  3. Improved handling of partial batches and skipped examples

This fix ensures that checkpointing and resuming works correctly with exactly the expected number of examples, rather than skipping ahead to the end of the batch.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Member

@lhoestq lhoestq left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool ! I left some comments :)

Feel free to also update your branch to include CI fixes from main

@@ -317,16 +317,40 @@ def __iter__(self):

def _iter_arrow(self):
shard_idx_start = self._state_dict["shard_idx"] if self._state_dict else 0
for gen_kwags in islice(_split_gen_kwargs(self.kwargs, max_num_jobs=self.num_shards), shard_idx_start, None):
kwargs_with_shuffled_shards = (
_shuffle_gen_kwargs(self.generator, self.kwargs) if hasattr(self, "generator") else self.kwargs
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this expected ? I think ShuffledDataSourcesArrowExamplesIterable has its own _iter_arrow() implementation

shard_example_idx_start = self._state_dict["shard_example_idx"] if self._state_dict else 0
shard_example_idx = 0

examples_seen_in_current_shard = 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how is it different from shard_example_idx ?

Comment on lines +339 to +342
if shard_example_idx < shard_example_idx_start:
offset = shard_example_idx_start - shard_example_idx
pa_table = pa_table.slice(offset)
examples_seen_in_current_shard = offset
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this needed ? we always yield full tables, so it's unlikely we end up with a shard_example_idx that doesn't land exactly on a table boundary (except if the dataset state is manually crafted maybe)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants