
Conversation

@yurekami
Contributor

Summary

  • Fix _balance_batch function to use dp_size instead of world_size for partitioning
  • The function's purpose is to "reorder the data such that each dp rank gets similar total tokens"
  • Using world_size is incorrect when tensor/pipeline parallelism is used, as it includes non-DP dimensions

Root Cause

The _balance_batch function passed world_size to get_seqlen_balanced_partitions as the number of partitions, but world_size also counts the tensor-parallel and pipeline-parallel dimensions. With model parallelism enabled, the batch is therefore split into more partitions than there are data-parallel ranks, so load balancing is computed against the wrong grouping.
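
To make the mismatch concrete, here is a minimal sketch with made-up numbers (the TP/PP sizes below are illustrative and not taken from this PR):

world_size = 8                                 # total GPUs
tp_size, pp_size = 2, 2                        # tensor / pipeline parallel sizes (illustrative)
dp_size = world_size // (tp_size * pp_size)    # only 2 data-parallel ranks

# Partitioning the batch into world_size = 8 groups balances token counts across
# 8 hypothetical ranks, but only dp_size = 2 groups of workers actually receive
# distinct data, so the per-DP-rank totals end up unbalanced.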

Changes

  • Retrieve dp_size from the worker group's _dispatch_info instead of using world_size
  • Query the dispatch info if not already cached
  • Use dp_size consistently for all partitioning operations

Before

world_size = self.actor_rollout_wg.world_size
# ... uses world_size for k_partitions

After

# Get dp_size from dispatch info to correctly balance across data parallel ranks
if "actor" in self.actor_rollout_wg._dispatch_info:
    dp_rank_mapping = self.actor_rollout_wg._dispatch_info["actor"]
    dp_size = max(dp_rank_mapping) + 1
else:
    dp_rank_mapping = self.actor_rollout_wg._query_dispatch_info("actor")
    self.actor_rollout_wg._dispatch_info["actor"] = dp_rank_mapping
    dp_size = max(dp_rank_mapping) + 1
# ... uses dp_size for k_partitions
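
For context, dp_size then becomes the k_partitions argument of the balancing helper. A rough sketch of that step, assuming get_seqlen_balanced_partitions(seqlen_list, k_partitions, equal_size) from verl.utils.seqlen_balancing returns one index list per partition:

from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions

# batch and batch_size come from the surrounding trainer code
global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist()
partitions = get_seqlen_balanced_partitions(global_seqlen_lst, k_partitions=dp_size, equal_size=True)
# Each of the dp_size partitions holds indices whose total token counts are
# roughly equal; concatenating them gives the reordered batch order.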

Fixes #4538

Test plan

  • Test with tensor parallel + data parallel configuration
  • Verify load balancing is correct when TP > 1 (a rough sketch of such a check follows below)
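
As a rough illustration of the second item, such a check might look like the following. This is a sketch only, with made-up sizes and an arbitrary tolerance, assuming get_seqlen_balanced_partitions from verl.utils.seqlen_balancing:

import random

from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions

world_size, tp_size = 8, 2           # illustrative: 8 GPUs with TP=2
dp_size = world_size // tp_size      # 4 data-parallel ranks

seqlens = [random.randint(64, 2048) for _ in range(64)]
partitions = get_seqlen_balanced_partitions(seqlens, k_partitions=dp_size, equal_size=True)

totals = [sum(seqlens[i] for i in part) for part in partitions]
# With dp_size partitions, the per-rank token totals should be close to each other
assert max(totals) - min(totals) < 0.1 * max(totals)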

🤖 Generated with Claude Code

The _balance_batch function is meant to "reorder the data such that
each dp rank gets similar total tokens". However, it was incorrectly
using world_size for partitioning, which includes tensor/pipeline
parallel dimensions.

This fix retrieves the actual dp_size from the dispatch info, which
correctly represents the number of data parallel ranks. This ensures
proper load balancing across DP ranks when using model parallelism.

Fixes volcengine#4538

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request correctly fixes a bug in _balance_batch by using dp_size instead of world_size for partitioning data. This is crucial for correct load balancing when tensor or pipeline parallelism is used. The logic for obtaining dp_size is sound. I've suggested a minor refactoring to reduce code duplication, which will improve maintainability.

Comment on lines 1095 to 1102
if "actor" in self.actor_rollout_wg._dispatch_info:
dp_rank_mapping = self.actor_rollout_wg._dispatch_info["actor"]
dp_size = max(dp_rank_mapping) + 1
else:
# Query dispatch info if not yet available
dp_rank_mapping = self.actor_rollout_wg._query_dispatch_info("actor")
self.actor_rollout_wg._dispatch_info["actor"] = dp_rank_mapping
dp_size = max(dp_rank_mapping) + 1


high

The logic to get dp_rank_mapping and calculate dp_size is duplicated in both the if and else blocks. This can be refactored to be more concise and avoid repetition, improving code maintainability.

        if "actor" not in self.actor_rollout_wg._dispatch_info:
            # Query dispatch info if not yet available
            self.actor_rollout_wg._dispatch_info["actor"] = self.actor_rollout_wg._query_dispatch_info("actor")
        dp_rank_mapping = self.actor_rollout_wg._dispatch_info["actor"]
        dp_size = max(dp_rank_mapping) + 1

global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1) # (train_batch_size,)
workload_lst = calculate_workload(global_seqlen_lst)
world_size = self.actor_rollout_wg.world_size
# Get dp_size from dispatch info to correctly balance across data parallel ranks
Collaborator


Could you encapsulate this into a function?

@wuxibin89 wuxibin89 changed the title fix(trainer): use dp_size instead of world_size in _balance_batch [trainer] fix: use dp_size instead of world_size in _balance_batch Dec 29, 2025

@yurekami
Contributor Author

Thank you for the review feedback! I've addressed both comments in the latest commit:

  1. Code formatting: Applied ruff format and ruff check --fix per the CONTRIBUTING.md guidelines. All checks pass now.

  2. Encapsulation: Extracted the duplicate dp_size calculation logic into a new _get_dp_size() helper method that:

    • Takes a worker group and role name as parameters
    • Handles the dispatch info caching logic
    • Returns the calculated dp_size
    • Includes docstring documentation

The refactored code is cleaner and more maintainable. Please let me know if any further changes are needed!
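
For reference, a minimal sketch of what such a helper might look like, based only on the description above (the merged method body is not shown in this thread, so names and details beyond that description are assumptions):

def _get_dp_size(self, worker_group, role: str) -> int:
    """Return the number of data-parallel ranks for `role`, caching the
    dispatch info on first use. (Hypothetical sketch, not the merged code.)"""
    if role not in worker_group._dispatch_info:
        # Query and cache the rank-to-DP-rank mapping if not yet available
        worker_group._dispatch_info[role] = worker_group._query_dispatch_info(role)
    dp_rank_mapping = worker_group._dispatch_info[role]
    return max(dp_rank_mapping) + 1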

Address review feedback:
- Extract duplicate dp_size calculation into _get_dp_size() method
- Apply ruff formatting per CONTRIBUTING.md guidelines

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
@CLAassistant

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.
You have signed the CLA already but the status is still pending? Let us recheck it.

@yurekami
Contributor Author

Hi maintainers! 👋

I've pushed the refactoring changes requested in the review:

  • Extracted duplicate dp_size calculation into a _get_dp_size() helper method
  • Applied ruff format and ruff check --fix per CONTRIBUTING.md guidelines

Could you please approve the CI workflows to run? Thank you!

@vermouth1992 vermouth1992 merged commit b16e048 into volcengine:main Dec 31, 2025
67 of 74 checks passed


Development

Successfully merging this pull request may close these issues.

confused by _balance_batch calculation in RayPPOTrainer

4 participants