
Adding Cycler, Header #1461

Merged: 8 commits into pytorch:main on Apr 29, 2025

Conversation

keunwoochoi (Contributor):

Please read through our contribution guide prior to creating your pull request.

  • If you are adding a new node, ensure you read that section in the contribution guide, as it includes requirements for
    functionality and testing.

Following up on discussion #1452 and my previous PR #1454.

Changes

  • Adds Cycler, Header, Shuffler and their tests.

@facebook-github-bot added the CLA Signed label on Mar 7, 2025. (This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.)
@divyanshk (Contributor):

Thanks for the PR, reviewing....

Also kicking off the CI.

self._num_cycles += 1
self.source.reset(None)

# Try again - if it's empty, this will raise StopIteration
Contributor:

At this point, the source shouldn't be empty after the reset. If it were empty, it would raise at line 64.

If this makes sense, we can update the comment.

Contributor Author:

It is an edge case where the source node is blank. I think that's a fair consideration (e.g., the current Iterable node can also take a blank list without any error during instantiation).
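For context, a minimal method sketch of the next() path under discussion, assuming the Cycler fields shown in the diff (self.source, self._num_cycles, self._has_started); the second next() call is where a blank source surfaces as StopIteration:

def next(self):
    try:
        item = next(self.source)
    except StopIteration:
        # End of one pass over the source: count the cycle and rewind.
        self._num_cycles += 1
        self.source.reset(None)
        # Try again - if the source is blank even after the reset,
        # this raises StopIteration and iteration ends cleanly.
        item = next(self.source)
    self._has_started = True
    return item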

"""Get the current state of the node.

Returns:
    A dictionary containing the state of the source node and number of cycles completed.
@divyanshk (Contributor), Mar 12, 2025:

Nit: Can we make it "Dict[str, Any] - A dictionary containing the state of the source node and number of cycles completed."? Ditto in other nodes in the PR, thanks.
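For concreteness, a small sketch of the requested docstring format (the wording is taken from the diff above; the method body is omitted):

from typing import Any, Dict

def get_state(self) -> Dict[str, Any]:
    """Get the current state of the node.

    Returns:
        Dict[str, Any] - A dictionary containing the state of the source node
        and number of cycles completed.
    """
    ...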

Comment on lines 39 to 40
self._num_cycles = initial_state.get(self.NUM_CYCLES_KEY, 0)
self._has_started = initial_state.get(self.HAS_STARTED_KEY, False)
Contributor:

I wonder if we should not have default values here. If the state is set up wrongly, this can lead to unexpected behavior.

Contributor Author:

Agreed, working on it in the next commit.
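A minimal sketch of the stricter variant being agreed on here, assuming a reset(initial_state) hook like the one used elsewhere in the diff; direct indexing makes a malformed state fail loudly instead of silently falling back to defaults:

def reset(self, initial_state=None):
    if initial_state is not None:
        # No .get(...) defaults: a missing key means the saved state is
        # malformed, and a KeyError is preferable to silently starting over.
        self._num_cycles = initial_state[self.NUM_CYCLES_KEY]
        self._has_started = initial_state[self.HAS_STARTED_KEY]
    else:
        self._num_cycles = 0
        self._has_started = False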

pytorch-bot bot commented Mar 12, 2025:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/data/1461

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit f5c3b4a with merge base d349d80:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

self.RNG_STATE_KEY: self.rng.getstate(),
self.BUFFER_KEY: list(self.buffer),
self.NUM_SHUFFLED_KEY: self._num_shuffled,
self.RANDOM_STATE_KEY: self.rng.getstate(),
Contributor:

RANDOM_STATE_KEY is a duplicate of RNG_STATE_KEY; let's remove RANDOM_STATE_KEY.

Comment on lines 69 to 71
while len(self.buffer) < self.buffer_size:
    self.buffer.append(next(self.source))
return True
Contributor:

This returns True even when we do not enter the while loop. Is that expected?

Do we want _fill_buffer() to return True if there are elements in the buffer, or only if the call to the function led to elements being added?

Contributor Author:

Good catch! I've updated the method to explicitly handle the case where the buffer is already full.

The return value now clearly indicates whether the buffer has items after the call (True) or is empty (False). This matches how it's used in next(), where we only raise StopIteration if the buffer is empty and _fill_buffer() returns False.
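A sketch of what the updated _fill_buffer() might look like under that contract (True iff the buffer is non-empty after the call); the exact implementation in the commit may differ:

def _fill_buffer(self) -> bool:
    """Top up the buffer from the source and report whether it holds any items."""
    if len(self.buffer) >= self.buffer_size:
        # Buffer is already full; nothing to pull from the source.
        return True
    try:
        while len(self.buffer) < self.buffer_size:
            self.buffer.append(next(self.source))
    except StopIteration:
        # Source exhausted; whatever is already buffered can still be yielded.
        pass
    return len(self.buffer) > 0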

RNG_STATE_KEY = "rng_state"
BUFFER_KEY = "buffer"
NUM_SHUFFLED_KEY = "num_shuffled"
RANDOM_STATE_KEY = "random_state"
Contributor:

Let's include num yielded in the state as well.

return {
    self.SOURCE_KEY: self.source.state_dict(),
    self.RNG_STATE_KEY: self.rng.getstate(),
    self.BUFFER_KEY: list(self.buffer),
Contributor:

Should we avoid keeping the entire buffer as state? Since this would be a list of node objects, it might not make sense.

Contributor Author:

Makes sense. Updated in the next commit.
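Taken together, the comments above (dropping the duplicate RANDOM_STATE_KEY, adding a yield counter, and no longer storing the buffer) point toward a state dict roughly like the following; NUM_YIELDED_KEY and self._num_yielded are illustrative names, not necessarily what the commit uses:

from typing import Any, Dict

def get_state(self) -> Dict[str, Any]:
    return {
        self.SOURCE_KEY: self.source.state_dict(),
        self.RNG_STATE_KEY: self.rng.getstate(),
        self.NUM_YIELDED_KEY: self._num_yielded,
    }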


Args:
    source_node (BaseNode[T]): The source node to pull items from.
    buffer_size (int): Size of the buffer used for shuffling. Must be at least 1.
@divyanshk (Contributor), Mar 12, 2025:

Thinking from a user POV. Would they have a strong opinion on what the buffer_size argument should be? Maybe not.

I fear this would end up being set as an arbitrarily large number just to maximize shuffling capacity.

Should we see if we can make it work without buffer_size and seed as an argument? We have something like this in MultiNodeWeightedSampler (link below)

Update: updated link: https://github.com/pytorch/data/blob/main/torchdata/nodes/samplers/multi_node_weighted_sampler.py#L223

Contributor Author:

From my experience, I believe we should definitely let the user control this. The ideal shuffle buffer size depends a lot on the data source in my use cases.
For example, if a data source is not pre-shuffled globally and is only sharded (which is not ideal, but would be a good use case for Shuffler), the buffer size should be as large as the number of items in each shard.

If the buffer size is too large, yes, a lot of other issues can occur. But perhaps that's really up to the users, and they should understand how Shuffler works? Especially since I'm not sure there's a non-trivial method to set a value that works for the majority of scenarios I can imagine.

Contributor Author:

I'm not able to open the link. Can you share it from the public code repo?

Contributor Author:

Thanks for the update. I still think the buffer size is much more (indeed, 100%) up to the user, unlike the example where some nice parameter can help things work faster without affecting any of the core features.

Contributor:

Let's keep a constant as a default? I don't want most users thinking too much about this when using the Shuffler.

Contributor Author:

@divyanshk what would be a good constant? Being someone who is not sure about having a default value, I'm probably not the best person to decide it.

Contributor:

Let's do 1000.
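As a sketch of the agreed direction (only source_node, buffer_size, and seed come from the diff; DEFAULT_BUFFER_SIZE is an illustrative name, and the BaseNode import path is assumed from torchdata.nodes):

from typing import Optional, TypeVar

from torchdata.nodes import BaseNode

T = TypeVar("T")

DEFAULT_BUFFER_SIZE = 1000  # default constant agreed on in this thread


class Shuffler(BaseNode[T]):
    def __init__(
        self,
        source_node: BaseNode[T],
        buffer_size: int = DEFAULT_BUFFER_SIZE,
        seed: Optional[int] = None,
    ) -> None:
        ...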

@keunwoochoi (Contributor Author):

Hi @divyanshk, thanks for the careful review. I updated the code per the review comments. Most of the requested changes are done; I think the one exception is whether we should allow something like n=0, which IMO we should.

@divyanshk (Contributor):

@keunwoochoi Thank you for the updates, will get back soon!

@keunwoochoi (Contributor Author) commented Mar 31, 2025:

@divyanshk gently reminding you of this :)

Comment on lines 84 to 88
@parameterized.expand(itertools.product([0, 3, 7]))
def test_save_load_state(self, midpoint: int) -> None:
    # This test is now expected to fail since we don't save the buffer
    # in the state, which changes the behavior after loading state
    pass
Contributor:

Probably not a good idea to include the test if it is meant to fail.

Regarding restoring state: since we store RNG_STATE_KEY in the state, we should be able to restore where we left off, right?

We should restore state, or else it might make any pipeline that has a Shuffler hard to use.

Contributor Author:

Oops. And yes, 100% agree.

@divyanshk (Contributor) commented Mar 31, 2025:

@keunwoochoi Thanks for the reminder, and for waiting.

This is looking good. I left a minor comment on adding a default value for the shuffler buffer size and on using the state utility function in tests, and a minor-ish comment on restoring the shuffler's state using the rng state. Thanks.

@keunwoochoi (Contributor Author):

@divyanshk thanks for the review again!

Mostly done, if not all, except for adding a default value for the shuffle buffer. But I'm sitting on a flight that is about to depart; I may want to have another look, and/or you can have a look too if you have some time ;)

@keunwoochoi (Contributor Author):

Reviewed it again. Looks good to me; asking for what is hopefully the last review :)

# Save state and create a new node
state = node.state_dict()
new_source = StatefulRangeNode(n=n)
new_node = Shuffler(new_source, buffer_size=5, seed=42)
Contributor:

The state restoration would break if someone restarts with a different buffer_size, right?

It's common sense that if someone changes their data pipeline after stopping, they are bound to get different results. On the other hand, part of me thinks we should capture buffer_size as state and throw an error. Not required for this PR, maybe something for later on. LMK your thoughts?

Contributor:

I agree.
We can add a check for whether the shuffle params are the same when we try to reload.

Contributor Author:

Yeah, I'm also 50:50. In some sense, I might wish I could just resume with a different shuffle buffer size. But perhaps that wish should be realized by another, more general solution, e.g., by adding some method to ignore the data loader states.

Adding BUFFER_SIZE_KEY in the new commit.
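For reference, the kind of guard being discussed could look roughly like this (a sketch only; the key name follows the BUFFER_SIZE_KEY mentioned above, and where exactly the check lives is up to the implementation):

def reset(self, initial_state=None):
    if initial_state is not None:
        saved_buffer_size = initial_state[self.BUFFER_SIZE_KEY]
        if saved_buffer_size != self.buffer_size:
            raise ValueError(
                f"Shuffler state was saved with buffer_size={saved_buffer_size}, "
                f"but this Shuffler was created with buffer_size={self.buffer_size}"
            )
    # ... rest of the state restoration ...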

Comment on lines 121 to 124
# The combined sequence will have fewer items than expected because
# we don't preserve the buffer in the state. We expect to lose
# approximately buffer_size items.
buffer_size = 5
@divyanshk (Contributor), Apr 9, 2025:

I am missing something. Can you explain this?

We have to ensure there is no drop in items on resumption.

Contributor:

Do we need a "fast forward" mode on loading state to get us to return the right elements back?

Contributor:

Agreed.
The intended behavior should be no difference between yielding all 20 items in one run and yielding them across a run with interruptions.
A simple but inefficient solution would be to reset the source and fast-forward past the already-yielded items.
Another one (which you might have already explored) would be to store the buffer as it is.

These solutions both have their drawbacks.

  1. If we reset the source and fast-forward, we pay more cost when we restart, but we only pay that cost if we restart.
  2. On the other hand, if we save the buffer, we pay the cost every time we call .get_state() (some people might do it every step).

Maybe we give both options to the user (might create confusion) and let them choose.
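To make option 1 concrete, a rough sketch of reset-and-fast-forward (assuming the Shuffler keeps a yield counter in its state under an illustrative NUM_YIELDED_KEY, and that self.seed is the seed it was constructed with):

import random

def reset(self, initial_state=None):
    # Rewind the source and re-seed the RNG so a replay reproduces
    # exactly the same shuffle order as the original run.
    self.source.reset(None)
    self.rng = random.Random(self.seed)
    self.buffer.clear()
    self._num_yielded = 0
    if initial_state is not None:
        # Fast-forward: draw and discard the items that were yielded before
        # the checkpoint. This cost is paid only on restart, never in get_state().
        for _ in range(initial_state[self.NUM_YIELDED_KEY]):
            next(self)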

Contributor Author:

"The intended behavior should be no difference between yielding all 20 items in one run and yielding them across a run with interruptions."
+1

But overall, this is a tricky problem. I can totally imagine users appreciating a non-stateful shuffler for efficiency. I'm actually down to excluding Shuffler from this PR and finishing the other nodes first while we discuss this. What do you think?

@ramanishsingh (Contributor), Apr 17, 2025:

"I'm actually down to excluding Shuffler from this PR and finishing the other nodes first while we discuss this. What do you think?"

Yes, that will be easier and cleaner. Thanks for doing that!

@keunwoochoi changed the title from "Adding Cycler, Header, Shuffler" to "Adding Cycler, Header" on Apr 17, 2025.
@keunwoochoi (Contributor Author):

Just removed Shuffler from this PR, partly because I'm getting busier these days and I would still like to get this PR merged sooner rather than later.

@divyanshk (Contributor) left a comment:

Looks good to me! Thank you for this PR!

@keunwoochoi (Contributor Author):

Nice! I'm not familiar with the further CI tests, but please let me know if there's anything I can do.

@ramanishsingh (Contributor):

Some CI tests are failing due to an issue that is fixed in #1477. Your PR is fine. :)

@ramanishsingh merged commit cda7d1a into pytorch:main on Apr 29, 2025 (60 of 62 checks passed).
Labels: CLA Signed
4 participants