
Conversation

@voegtlel (Collaborator) commented on Jul 1, 2025

Provides two strategies for synchronizing the end of the loader for repeat=False.

@voegtlel requested a review from philipp-fischer on July 2, 2025 07:30
-loader = get_loader(get_train_dataset(
+# Set repeat=False to avoid repeating the dataset.
+# Also add RedistributeLoader to synchronize the end of rank exhaustion. Only works with initialized torch distributed.
+loader = RedistributeLoader(get_loader(get_train_dataset(
Collaborator:

I think it would be better if this were transparent to the user, so get_loader should handle it internally.
And since the choice of RedistributeLoader vs. StopFirstLoader actually changes the data that is being iterated, I think this choice should be made in the metadataset rather than in the code, as a property of blend_epochized.

I.e. blend_epochized can either be a list as before (which defaults to RedistributeLoader), or a dict for more customization, like:

    blend_epochized:
      phase_out_behavior: stop_first_loader
      datasets:
        - repetitions: 5
          path: ./coco
          # ... Other parameters
        - repetitions: 2
          path: ./coyo
        - repetitions: 1
          path: ./coyo
          split_part: val

Collaborator Author:

Okay, we can make this handled in get_loader. I personally would prefer it to stay separate, though, as it's a feature on top.

Regarding moving the configuration to the metadataset:
I see your point that this slightly modifies how the data is iterated, but I'd also argue:

  1. So far we don't really rely on torch distributed, whereas this piece of code is tightly bound to it.
  2. This would also disable nesting of blend_epochized, because you cannot nest different (or unconfigured) phase_out_behavior settings.
  3. This depends on repeat=False and doesn't make sense for repeat=True, so it's tied to what the user sets in the code.
  4. At least for RedistributeLoader, it should not really change the data frequency (so far the settings in the metadataset mainly focus on data frequency / blending).
  5. If we move the boundary of the metadataset to include this, then we should also put gradient accumulation, seeds, batch size, handling of incomplete batches, etc. into the config. I wouldn't want that, tbh.

Thus I'm voting for keeping this in code, not in the metadataset config.
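To illustrate the code-side option being argued for here, a minimal sketch of how the choice could stay in the training script. The wrapper classes come from this PR, but their import path, the build_loader helper, and the phase_out argument are assumptions for illustration only; the remaining get_train_dataset arguments (batch_size, worker_config, etc.) are elided into dataset_kwargs:

    import torch.distributed as dist
    from megatron.energon import get_loader, get_train_dataset
    # RedistributeLoader / StopFirstLoader are added by this PR; their final
    # import path is not settled here, so this import is an assumption:
    from megatron.energon import RedistributeLoader, StopFirstLoader


    def build_loader(path, *, phase_out="redistribute", **dataset_kwargs):
        """Hypothetical helper: keep the phase-out strategy a code-level choice."""
        # repeat=False is what makes a phase-out strategy relevant at all.
        loader = get_loader(get_train_dataset(path, repeat=False, **dataset_kwargs))
        if dist.is_initialized():
            if phase_out == "redistribute":
                loader = RedistributeLoader(loader)
            elif phase_out == "stop_first":
                loader = StopFirstLoader(loader)
        return loader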

except StopIteration:
    # print(f"[r={rank}]: StopIteration\n", end="")
    self.exhausted_states[rank] = self_exhausted = 1
    dist.all_reduce(
Collaborator:

We should evaluate the impact of this synchronization (which happens in every iteration!) in a real-world training run to see whether training speed suffers when we have many nodes and ranks.
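For context, a minimal sketch of the kind of per-iteration collective being discussed: per-rank exhaustion flags merged with an all_reduce so every rank learns who ran out. This is illustrative only, not the PR's implementation, and it assumes a backend that supports CPU tensors (e.g. gloo); the function name and arguments are hypothetical:

    import torch
    import torch.distributed as dist


    def next_with_exhaustion_sync(iterator, exhausted_states: torch.Tensor, rank: int):
        """Illustrative only: one all_reduce per __next__ call."""
        try:
            sample = next(iterator)
        except StopIteration:
            sample = None
            exhausted_states[rank] = 1
        # Every iteration pays one collective; with many ranks this latency is
        # exactly the cost the comment above asks to evaluate.
        dist.all_reduce(exhausted_states, op=dist.ReduceOp.MAX)
        return sample, exhausted_states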

Collaborator Author:

Yes, agreed. I haven't benchmarked the impact of this yet.
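A rough sketch of how that impact could be measured, i.e. just the latency of one extra collective per iteration. This assumes a CPU-capable backend such as gloo (move the flag to the current CUDA device for NCCL) and is not part of the PR:

    import time
    import torch
    import torch.distributed as dist


    def time_allreduce(iters: int = 1000) -> float:
        """Average seconds per all_reduce of a single int32 flag."""
        flag = torch.zeros(1, dtype=torch.int32)
        dist.barrier()
        start = time.perf_counter()
        for _ in range(iters):
            dist.all_reduce(flag, op=dist.ReduceOp.MAX)
        dist.barrier()
        return (time.perf_counter() - start) / iters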

return f"RedistributeLoaderState(inner_state={self.inner_state!r}, exhausted_state={self.exhausted_state!r}, overuse_count={self.overuse_count!r})"


class RedistributeLoader(Generic[T]):
Collaborator:

I think this will break our reproducible scaling!
It seems the whole redistribution loader will not work if we stop and resume with a different number of ranks.
See #80

That's a problem we need to discuss possible solutions for.
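To make the concern concrete, a hypothetical illustration: assuming the saved state keeps one entry per rank (as exhausted_states in the snippet above suggests), such state has no well-defined mapping when resuming with a different world size. The numbers and names below are made up:

    import torch

    # Hypothetical: a checkpoint written on 8 ranks carrying per-rank state.
    saved_exhausted_states = torch.zeros(8, dtype=torch.int32)

    new_world_size = 4  # resuming on fewer ranks
    if saved_exhausted_states.numel() != new_world_size:
        raise RuntimeError(
            "per-rank state cannot be restored 1:1 when the number of ranks changes"
        )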

Collaborator Author:

Okay, I hadn't thought of this, tbh. Yes, let's discuss offline.

return f"StopFirstDataLoaderState(inner_state={self.inner_state!r}, iterating_from_start={self.iterating_from_start!r}, next_sample_restore_key={self.next_sample_restore_key!r})"


class StopFirstLoader(Generic[T]):
Collaborator:

Same issue with reproducible scaling here? I think this one could be easier to get working without breaking repro scaling.

Collaborator Author:

Likely easier, but it's also the less useful of the two methods 😅
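For readers comparing the two strategies, a minimal sketch of "stop when the first rank is exhausted" semantics, as the class name suggests. This is illustrative only, not the PR's implementation: it discards the last peeked sample for simplicity (the real loader's next_sample_restore_key hints that it keeps it) and assumes a CPU-capable backend such as gloo:

    import torch
    import torch.distributed as dist


    def stop_first_iter(inner_iterable):
        """Illustrative only: all ranks stop together once any rank runs out."""
        it = iter(inner_iterable)
        while True:
            flag = torch.zeros(1, dtype=torch.int32)
            try:
                sample = next(it)
            except StopIteration:
                sample = None
                flag.fill_(1)
            # If any rank is exhausted, every rank sees flag == 1 and stops.
            dist.all_reduce(flag, op=dist.ReduceOp.MAX)
            if int(flag.item()):
                return
            yield sample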
