Skip to content

Commit 2b9b705

Browse files
feat(training): add --balance-by-flops for FLOPs-aware micro-batch partitioning (#44)
Port of THUDM/slime#2017. Adds a --balance-by-flops flag that replaces token-count Karmarkar-Karp (KK) balancing with FLOPs-weighted KK for both DP rank assignment (_split_train_data_by_dp) and micro-batch packing (get_data_iterator). Uses the existing calculate_fwd_flops(), which accounts for the full model architecture (MoE, LoRA, attention projections), rather than the simplified coeff*L + L² estimate from upstream. Requires --use-dynamic-batch-size.
1 parent ee4dc5d commit 2b9b705

4 files changed

Lines changed: 38 additions & 2 deletions

File tree

miles/backends/training_utils/data.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import torch.nn.functional as F
88

99
from miles.utils.data import get_minimum_num_micro_batch_size
10+
from miles.utils.flops_utils import calculate_workloads
1011
from miles.utils.seqlen_balancing import get_seqlen_balanced_partitions
1112
from miles.utils.types import RolloutBatch
1213

@@ -412,7 +413,11 @@ def _generate_data_iterator(rollout_data, micro_batch_size, micro_batch_indices=
412413
for i, num_mbs in enumerate(num_microbatches):
413414
start, end = i * num_local_gbs, (i + 1) * num_local_gbs
414415
samples = rollout_data["total_lengths"][start:end]
415-
partitions = get_seqlen_balanced_partitions(samples, num_mbs, equal_size=False)
416+
if getattr(args, "balance_by_flops", False):
417+
weights = calculate_workloads(samples, args)
418+
partitions = get_seqlen_balanced_partitions(weights, num_mbs, equal_size=False)
419+
else:
420+
partitions = get_seqlen_balanced_partitions(samples, num_mbs, equal_size=False)
416421
for j in range(num_mbs):
417422
for k in range(len(partitions[j])):
418423
partitions[j][k] += start

miles/ray/rollout.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from miles.rollout.inference_rollout.compatibility import call_rollout_function, load_rollout_function
2828
from miles.utils import dumper_utils, tracking_utils
2929
from miles.utils.environ import enable_experimental_rollout_refactor
30+
from miles.utils.flops_utils import calculate_workloads
3031
from miles.utils.health_monitor import RolloutHealthMonitor
3132
from miles.utils.http_utils import (
3233
_wrap_ipv6,
@@ -856,7 +857,11 @@ def _stat(xs):
856857
total_lengths = [len(t) for t in data["tokens"]]
857858
data["total_lengths"] = total_lengths
858859

859-
if self.args.balance_data:
860+
balance_by_flops = getattr(self.args, "balance_by_flops", False)
861+
if balance_by_flops:
862+
workloads = calculate_workloads(total_lengths, self.args)
863+
partitions = get_seqlen_balanced_partitions(workloads, dp_size, equal_size=True)
864+
elif self.args.balance_data:
860865
partitions = get_seqlen_balanced_partitions(total_lengths, dp_size, equal_size=True)
861866
else:
862867
partitions = [range(i, len(total_lengths), dp_size) for i in range(dp_size)]

miles/utils/arguments.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -685,6 +685,19 @@ def add_data_arguments(parser):
685685
),
686686
)
687687

688+
parser.add_argument(
689+
"--balance-by-flops",
690+
action="store_true",
691+
default=False,
692+
help=(
693+
"Use FLOPs-based workload estimation for DP rank assignment and micro-batch partitioning "
694+
"via Karmarkar-Karp instead of token-count balancing. FLOPs are computed from the full "
695+
"model config (hidden_size, ffn_hidden_size, MoE experts/topk, LoRA ranks) via "
696+
"calculate_fwd_flops, capturing the quadratic cost of attention. Produces more balanced "
697+
"micro-batches when sequence lengths vary widely. Requires --use-dynamic-batch-size."
698+
),
699+
)
700+
688701
parser.add_argument(
689702
"--use-dynamic-batch-size",
690703
action="store_true",
@@ -1956,6 +1969,9 @@ def miles_validate_args(args):
19561969
if args.log_probs_max_tokens_per_gpu is None:
19571970
args.log_probs_max_tokens_per_gpu = args.max_tokens_per_gpu
19581971

1972+
if getattr(args, "balance_by_flops", False):
1973+
assert args.use_dynamic_batch_size, "--balance-by-flops requires --use-dynamic-batch-size"
1974+
19591975
if args.eps_clip_high is None:
19601976
args.eps_clip_high = args.eps_clip
19611977

miles/utils/flops_utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,13 @@ def calculate_fwd_flops(
125125
total_flops += calculate_lm_head_flops(seqlen, hidden_size, vocab_size)
126126

127127
return total_flops
128+
129+
130+
def calculate_workloads(seqlens, args):
    """Return one forward-FLOPs workload estimate per sequence length.

    Each entry of ``seqlens`` is evaluated on its own by calling
    ``calculate_fwd_flops([sl], args)``, so the result has the same length
    and order as ``seqlens``. The values are intended to be used as the
    weights for Karmarkar-Karp balancing when ``--balance-by-flops`` is set
    (DP rank assignment and micro-batch packing), in place of raw token
    counts.

    Args:
        seqlens: Iterable of per-sequence lengths (token counts).
        args: Parsed argument namespace forwarded unchanged to
            ``calculate_fwd_flops``; presumably carries the model
            configuration (hidden sizes, MoE / LoRA settings) -- confirm
            against ``calculate_fwd_flops``.

    Returns:
        A list with one ``calculate_fwd_flops`` result per input sequence.
    """
    # Each length is wrapped in a single-element list so that every sequence
    # gets its own estimate instead of one aggregate over the whole batch.
    return [calculate_fwd_flops([sl], args) for sl in seqlens]

0 commit comments

Comments
 (0)