Merged
17 changes: 13 additions & 4 deletions verl/trainer/ppo/ray_trainer.py
```diff
@@ -1090,23 +1090,32 @@ def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen"):
         batch_size = attention_mask.shape[0]
         global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1)  # (train_batch_size,)
         workload_lst = calculate_workload(global_seqlen_lst)
-        world_size = self.actor_rollout_wg.world_size
+        # Get dp_size from dispatch info to correctly balance across data parallel ranks
```
Collaborator: Could you encapsulate this into a function?
```diff
+        # Note: world_size may include tensor/pipeline parallel dimensions, but we only want DP
+        if "actor" in self.actor_rollout_wg._dispatch_info:
+            dp_rank_mapping = self.actor_rollout_wg._dispatch_info["actor"]
+            dp_size = max(dp_rank_mapping) + 1
+        else:
+            # Query dispatch info if not yet available
+            dp_rank_mapping = self.actor_rollout_wg._query_dispatch_info("actor")
+            self.actor_rollout_wg._dispatch_info["actor"] = dp_rank_mapping
+            dp_size = max(dp_rank_mapping) + 1
```
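For intuition: dp_size is derived as one plus the largest data-parallel rank in the dispatch mapping. A toy illustration (the mapping values below are assumed for the example, not taken from verl):

```python
# Assumed example: 8 workers under TP=2, so pairs of workers share a DP rank
# and must receive the same data shard.
dp_rank_mapping = [0, 0, 1, 1, 2, 2, 3, 3]
dp_size = max(dp_rank_mapping) + 1  # 4 data-parallel ranks, not world_size == 8
```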
Contributor (severity: high): The logic to get dp_rank_mapping and calculate dp_size is duplicated in both the if and else blocks. This can be refactored to be more concise and avoid repetition, improving code maintainability.
        if "actor" not in self.actor_rollout_wg._dispatch_info:
            # Query dispatch info if not yet available
            self.actor_rollout_wg._dispatch_info["actor"] = self.actor_rollout_wg._query_dispatch_info("actor")
        dp_rank_mapping = self.actor_rollout_wg._dispatch_info["actor"]
        dp_size = max(dp_rank_mapping) + 1

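Going one step further, both comments could be addressed by pulling the lookup into a small cached helper. A minimal sketch (`_get_actor_dp_size` is a hypothetical name; it assumes the `_dispatch_info` dict and `_query_dispatch_info` method used in this diff):

```python
def _get_actor_dp_size(self) -> int:
    """Data-parallel size for the actor, caching dispatch info on first use."""
    wg = self.actor_rollout_wg
    if "actor" not in wg._dispatch_info:
        # Query dispatch info if not yet available (same lookup as in the diff above)
        wg._dispatch_info["actor"] = wg._query_dispatch_info("actor")
    # Each entry maps a worker to its data-parallel rank; ranks are 0-based
    return max(wg._dispatch_info["actor"]) + 1
```

_balance_batch would then reduce to a single line: dp_size = self._get_actor_dp_size().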
```diff
         if keep_minibatch:
             # Decouple the DP balancing and mini-batching.
             minibatch_size = self.config.actor_rollout_ref.actor.get("ppo_mini_batch_size")
             minibatch_num = len(workload_lst) // minibatch_size
-            global_partition_lst = [[] for _ in range(world_size)]
+            global_partition_lst = [[] for _ in range(dp_size)]
             for i in range(minibatch_num):
                 rearrange_minibatch_lst = get_seqlen_balanced_partitions(
                     workload_lst[i * minibatch_size : (i + 1) * minibatch_size],
-                    k_partitions=world_size,
+                    k_partitions=dp_size,
                     equal_size=True,
                 )
                 for j, part in enumerate(rearrange_minibatch_lst):
                     global_partition_lst[j].extend([x + minibatch_size * i for x in part])
         else:
             global_partition_lst = get_seqlen_balanced_partitions(
-                workload_lst, k_partitions=world_size, equal_size=True
+                workload_lst, k_partitions=dp_size, equal_size=True
             )
             # Place smaller micro-batches at both ends to reduce the bubbles in pipeline parallel.
             for idx, partition in enumerate(global_partition_lst):
```
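Why dp_size rather than world_size matters: balancing produces one partition per data consumer, and with model parallelism several workers share each consumer. A self-contained toy sketch (greedy_balanced_partitions is an assumed, simplified stand-in for verl's get_seqlen_balanced_partitions, which additionally enforces equal partition sizes when equal_size=True):

```python
# Toy illustration only -- not verl's implementation.
def greedy_balanced_partitions(workloads, k_partitions):
    """Assign each sample (heaviest first) to the currently lightest partition."""
    partitions = [[] for _ in range(k_partitions)]
    loads = [0] * k_partitions
    for idx in sorted(range(len(workloads)), key=lambda i: -workloads[i]):
        j = loads.index(min(loads))  # lightest partition so far
        partitions[j].append(idx)
        loads[j] += workloads[idx]
    return partitions

seqlens = [512, 128, 640, 256, 384, 896, 64, 320]
dp_size = 4  # from dispatch info; using world_size == 8 here would create
             # more partitions than there are data-parallel consumers
for rank, part in enumerate(greedy_balanced_partitions(seqlens, dp_size)):
    print(f"dp rank {rank}: samples {sorted(part)}, tokens {sum(seqlens[i] for i in part)}")
```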