verl/utils/seqlen_balancing.py (106 changes: 96 additions, 10 deletions)
@@ -24,17 +24,50 @@
from verl.utils.device import get_device_name


def calculate_workload(seqlen_list: list[int]):
"""
Calculate the workload for a dense transformer block based on sequence length.
FLOPs = 12 * hidden_size^2 * seqlen + 2 * hidden_size * seqlen^2
Hardcodes the constants by a 7B model (hidden_size=4096),
so the FLOPs are proportional to (6 * 4096 * seqlen + seqlen^2).
def calculate_workload(seqlen_list: torch.Tensor) -> torch.Tensor:
"""Calculate approximate computational workload for transformer attention.

Estimates FLOPs for dense transformer blocks based on sequence length using
the formula: FLOPs ≈ 12 * hidden_size² * seqlen + 2 * hidden_size * seqlen²

The constants are calibrated for a 7B model (hidden_size=4096), yielding:
workload ∝ 24576 * seqlen + seqlen²

Args:
seqlen_list: Sequence lengths as a tensor.

Returns:
torch.Tensor: Estimated workload values proportional to actual FLOPs.

Note:
The returned values are relative workloads, not actual FLOP counts.
Useful for balancing computation across data parallel ranks.
"""
return 24576 * seqlen_list + seqlen_list**2
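For illustration only (not part of the PR), a minimal usage sketch; it assumes torch is available in the module, as the new tensor annotation implies:

import torch

# hypothetical per-sample sequence lengths for one micro-batch
seqlens = torch.tensor([512, 2048, 128, 1024])
workloads = calculate_workload(seqlens)  # 24576 * seqlen + seqlen**2 per sample
# the values are relative workloads, not true FLOP counts, so they are only
# meaningful for comparing and balancing samples against each other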


def karmarkar_karp(seqlen_list: list[int], k_partitions: int, equal_size: bool):
def karmarkar_karp(seqlen_list: list[int], k_partitions: int, equal_size: bool) -> list[list[int]]:
"""Partition items into k groups using the Karmarkar-Karp differencing method.

Implements the Largest Differencing Method (LDM) for balanced multi-way
number partitioning. This heuristic produces near-optimal partitions by
repeatedly combining the two partial solutions whose subset sums differ the most,
so that large items end up in different partitions.

Args:
seqlen_list: Values to partition (typically sequence lengths or workloads).
k_partitions: Number of partitions to create.
equal_size: If True, each partition will have exactly len(seqlen_list) / k_partitions
items. If False, partitions may have different sizes.

Returns:
list[list[int]]: List of k partitions, each containing indices into seqlen_list.

See Also:
https://en.wikipedia.org/wiki/Largest_differencing_method

Note:
When equal_size=True, len(seqlen_list) must be divisible by k_partitions.
"""
# see: https://en.wikipedia.org/wiki/Largest_differencing_method
class Set:
def __init__(self) -> None:
@@ -138,7 +171,25 @@ def __repr__(self) -> str:
return partitions
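A hedged usage sketch with made-up sequence lengths (not taken from the PR): the returned partitions are index lists, so each data-parallel rank can gather its own subset of samples with roughly equal total work.

seqlens = [512, 2048, 128, 1024, 768, 256, 1536, 896]
parts = karmarkar_karp(seqlens, k_partitions=4, equal_size=True)
for part in parts:
    # each inner list indexes into seqlens; with equal_size=True every
    # partition holds len(seqlens) / 4 = 2 items
    print(part, sum(seqlens[i] for i in part))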


def greedy_partition(seqlen_list: list[int], k_partitions: int, equal_size: bool):
def greedy_partition(seqlen_list: list[int], k_partitions: int, equal_size: bool) -> list[list[int]]:
"""Partition items into k groups using a greedy assignment strategy.

Assigns each item to the partition with the smallest current sum, iterating
through items in order. Simpler but typically less optimal than Karmarkar-Karp.

Args:
seqlen_list: Values to partition (typically sequence lengths or workloads).
k_partitions: Number of partitions to create.
equal_size: If True, adds a bias to ensure equal partition sizes.
Requires len(seqlen_list) to be divisible by k_partitions.

Returns:
list[list[int]]: List of k partitions, each containing indices into seqlen_list.

Note:
When equal_size=True, a large bias is added to encourage equal distribution
of items before considering the actual values.
"""
bias = sum(seqlen_list) + 1 if equal_size else 0
sorted_seqlen = [(seqlen + bias, i) for i, seqlen in enumerate(seqlen_list)]
partitions = [[] for _ in range(k_partitions)]
Expand Down Expand Up @@ -250,11 +301,46 @@ def log_seqlen_unbalance(seqlen_list: list[int], partitions: list[list[int]], pr
}


def ceildiv(a, b):
def ceildiv(a: int, b: int) -> int:
"""Compute ceiling division of a by b.

Returns the smallest integer greater than or equal to a/b.
Uses the identity: ceil(a/b) = floor((a + b - 1) / b) = -(a // -b)

Args:
a: Dividend (numerator).
b: Divisor (denominator), must be non-zero.

Returns:
int: Ceiling of a divided by b.

Example:
>>> ceildiv(7, 3) # ceil(7/3) = ceil(2.33) = 3
3
>>> ceildiv(6, 3) # ceil(6/3) = ceil(2.0) = 2
2
"""
return -(a // -b)


def roundup_divisible(a, b):
def roundup_divisible(a: int, b: int) -> int:
"""Round up a to the nearest multiple of b.

Returns the smallest multiple of b that is >= a.

Args:
a: Value to round up.
b: Divisor to round to (must be positive).

Returns:
int: Smallest multiple of b that is >= a.

Example:
>>> roundup_divisible(7, 4) # nearest multiple of 4 >= 7 is 8
8
>>> roundup_divisible(8, 4) # 8 is already a multiple of 4
8
"""
return ((a + b - 1) // b) * b
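A small worked example (hypothetical values) showing how the two rounding helpers relate:

ceildiv(10, 3)            # 4: smallest integer >= 10 / 3
roundup_divisible(10, 3)  # 12: smallest multiple of 3 >= 10, i.e. ceildiv(10, 3) * 3
roundup_divisible(12, 3)  # 12: already a multiple of 3, so unchanged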

