[trainer] fix: use dp_size instead of world_size in _balance_batch #4697
Conversation
The _balance_batch function is meant to "reorder the data such that each dp rank gets similar total tokens". However, it was incorrectly using world_size for partitioning, which includes tensor/pipeline parallel dimensions. This fix retrieves the actual dp_size from the dispatch info, which correctly represents the number of data parallel ranks. This ensures proper load balancing across DP ranks when using model parallelism.

Fixes volcengine#4538

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <[email protected]>
Code Review
This pull request correctly fixes a bug in _balance_batch by using dp_size instead of world_size for partitioning data. This is crucial for correct load balancing when tensor or pipeline parallelism is used. The logic for obtaining dp_size is sound. I've suggested a minor refactoring to reduce code duplication, which will improve maintainability.
verl/trainer/ppo/ray_trainer.py
Outdated
| if "actor" in self.actor_rollout_wg._dispatch_info: | ||
| dp_rank_mapping = self.actor_rollout_wg._dispatch_info["actor"] | ||
| dp_size = max(dp_rank_mapping) + 1 | ||
| else: | ||
| # Query dispatch info if not yet available | ||
| dp_rank_mapping = self.actor_rollout_wg._query_dispatch_info("actor") | ||
| self.actor_rollout_wg._dispatch_info["actor"] = dp_rank_mapping | ||
| dp_size = max(dp_rank_mapping) + 1 |
The logic to get dp_rank_mapping and calculate dp_size is duplicated in both the if and else blocks. This can be refactored to be more concise and avoid repetition, improving code maintainability.
if "actor" not in self.actor_rollout_wg._dispatch_info:
# Query dispatch info if not yet available
self.actor_rollout_wg._dispatch_info["actor"] = self.actor_rollout_wg._query_dispatch_info("actor")
dp_rank_mapping = self.actor_rollout_wg._dispatch_info["actor"]
dp_size = max(dp_rank_mapping) + 1| global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1) # (train_batch_size,) | ||
| workload_lst = calculate_workload(global_seqlen_lst) | ||
| world_size = self.actor_rollout_wg.world_size | ||
| # Get dp_size from dispatch info to correctly balance across data parallel ranks |
Could you encapsulate this into a function?
@yurekami Please format code according to: https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting
Thank you for the review feedback! I've addressed both comments in the latest commit:
- Extracted the duplicate dp_size calculation into a _get_dp_size() method
- Applied ruff formatting per the CONTRIBUTING.md guidelines

The refactored code is cleaner and more maintainable. Please let me know if any further changes are needed!
Address review feedback:
- Extract duplicate dp_size calculation into _get_dp_size() method
- Apply ruff formatting per CONTRIBUTING.md guidelines

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <[email protected]>
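For illustration, here is a minimal, self-contained sketch of what a `_get_dp_size()` helper along these lines might look like. The attribute names (`_dispatch_info`, `_query_dispatch_info`) mirror the diff above, but the free-function form, the stand-in worker group, and the example rank mapping are hypothetical, not the PR's actual code.

```python
def get_dp_size(worker_group, role: str = "actor") -> int:
    """Number of data-parallel ranks for `role`, caching dispatch info on first use."""
    if role not in worker_group._dispatch_info:
        # Query dispatch info if not yet available
        worker_group._dispatch_info[role] = worker_group._query_dispatch_info(role)
    dp_rank_mapping = worker_group._dispatch_info[role]
    return max(dp_rank_mapping) + 1


class _FakeWorkerGroup:
    """Stand-in used only to exercise the sketch."""

    def __init__(self):
        self._dispatch_info = {}

    def _query_dispatch_info(self, role):
        # e.g. 8 ranks mapped onto 4 DP groups (TP=2)
        return [0, 0, 1, 1, 2, 2, 3, 3]


if __name__ == "__main__":
    print(get_dp_size(_FakeWorkerGroup()))  # -> 4
```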
Hi maintainers! 👋 I've pushed the refactoring changes requested in the review.
Could you please approve the CI workflows to run? Thank you!
Summary
- Fix the `_balance_batch` function to use `dp_size` instead of `world_size` for partitioning
- `world_size` is incorrect when tensor/pipeline parallelism is used, as it includes non-DP dimensions

Root Cause
The `_balance_batch` function used `world_size` for `get_seqlen_balanced_partitions`, but `world_size` may include tensor parallel and pipeline parallel dimensions. When using model parallelism, this causes incorrect load balancing.
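As a concrete illustration of the mismatch (the specific parallelism sizes below are hypothetical, not taken from the PR or the linked issue):

```python
# Illustrative only: with model parallelism, world_size counts every rank,
# while only the data-parallel groups each receive a distinct shard of the batch.
world_size = 16            # total ranks (hypothetical)
tp_size, pp_size = 2, 2    # tensor / pipeline parallel sizes (hypothetical)
dp_size = world_size // (tp_size * pp_size)
print(dp_size)             # -> 4

# Partitioning the batch into world_size (16) chunks instead of dp_size (4)
# yields four times too many partitions, so per-DP-rank token totals end up unbalanced.
```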
Changes
- Get `dp_size` from the worker group's `_dispatch_info` instead of using `world_size`
- Use `dp_size` consistently for all partitioning operations

Before
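A minimal sketch of the pre-fix partitioning call, reconstructed from the diff fragments quoted in the review thread above; the exact surrounding code and the `get_seqlen_balanced_partitions` argument names are assumptions.

```python
# Sketch (not verbatim): partition count taken from world_size,
# which also counts tensor/pipeline-parallel ranks.
world_size = self.actor_rollout_wg.world_size
global_partition_lst = get_seqlen_balanced_partitions(
    workload_lst, k_partitions=world_size, equal_size=True  # argument names assumed
)
```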
After
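A matching sketch of the post-fix call, assuming the `_get_dp_size()` helper described in the review commit; argument names are again assumptions.

```python
# Sketch (not verbatim): partition count derived from the data-parallel size
# recorded in the worker group's dispatch info.
dp_size = self._get_dp_size()
global_partition_lst = get_seqlen_balanced_partitions(
    workload_lst, k_partitions=dp_size, equal_size=True  # argument names assumed
)
```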
Fixes #4538
Test plan
🤖 Generated with Claude Code