
Commit eda1127

erictang000 and claude authored
feat: add max_tokens_per_microbatch config for token-based micro-batching (NovaSky-AI#1477)
## Token-based micro-batching (`max_tokens_per_microbatch`) for FSDP + Megatron

### Summary

Adds **token-based micro-batching**: instead of splitting a batch into fixed *sample-count* microbatches (`micro_*_batch_size_per_gpu`), microbatches are formed by **bin-packing samples up to a token budget** (`max_tokens_per_microbatch`). This balances real token counts across microbatches, improving throughput and memory utilization when sequence lengths are highly variable. It applies to both the forward (logprobs/values) and forward_backward (training) paths, on both the **FSDP** and **Megatron** backends.

Along the way it fixes several correctness issues that surface on the Megatron/THD path, and makes the pre-training "old logprobs" forward pack identically to the training step so the PPO ratio is exact at the first inner step.

### Config

- **`trainer.max_tokens_per_microbatch: int = -1`**: `> 0` enables token-based batching; `-1` keeps existing sample-count behavior. It is a *soft* cap: sequences are never split, so a single sequence longer than the cap gets its own (over-cap) microbatch. Documented in the config with guidance to set it `>= max_prompt_length + max_generate_length`.
- **`trainer.recompute_old_logprobs_per_minibatch: bool = True`**: compute the policy old-logprobs (and critic values) per mini-batch, matching the training step's mini-batch + DP partition (see Forward/backward packing consistency below).

### How it works

`worker_utils.py` gains a small iterator hierarchy behind `get_microbatch_iterator(...)`:

- `balanced_binpacking(token_counts, max_tokens)`: greedy least-loaded bin-packing over per-sample token counts (from `attention_mask`).
- `TokenBasedBatchIterator`: packs samples into microbatches; **synchronizes the microbatch count across DP ranks** by appending loss-neutral *padding microbatches* (so every rank runs the same number of forward passes, required by Megatron's pipeline schedule), and reorders/strips them on the way out.
- `SampleBasedBatchIterator` / `BatchIterator`: the existing fixed-count path, refactored under a shared `BaseBatchIterator`.

### Correctness fixes (mostly Megatron/THD-specific)

- **Multi-modal support:** `TensorList.__getitem__` now supports tensor/list fancy indexing, so `pixel_values` / `image_grid_thw` are gathered correctly when forming token-based microbatches.
- **Padding microbatch shape/device:** padding microbatches use the real `seq_len` (uniform across microbatches, as Megatron requires) and are built on the data's device. This fixes a CPU/CUDA mismatch and an `IndexError` in `postprocess_packed_seqs`.
- **Single-token dummy rows:** rows added to equalize batch size within a microbatch use a one-hot `attention_mask` instead of all-ones. This keeps both the packed (no zero-length `cu_seqlens` segment) and dense (no fully-masked-row NaN) paths valid while skipping nearly all of the dummy's compute.
- **`_pad_microbatch_to_size`** consolidated and hoisted to the base worker so ref/critic workers (which share `_forward_logprobs`) get it too.
- **PP > 1 forward:** the token-based reorder is skipped on non-last pipeline stages (which only return a placeholder), fixing an `IndexError`.
- **Metric dilution:** fully-padding microbatches are excluded from metric aggregation, so `loss_metrics/clip_ratio`, `policy_entropy`, etc. are no longer dragged toward `0`.
- **KL/entropy normalization:** the KL/entropy term is normalized by the count of *real* microbatches (`num_real_microbatches`) rather than the padded total, so regularization strength isn't diluted by padding microbatches at DP > 1.

### Forward/backward packing consistency

The pre-training forward (old logprobs) ran over the *full batch* while training runs *per mini-batch*, so the two used different DP partitions and THD packing, making `old_logprobs` differ from the epoch-0 recomputed logprobs (PPO ratio ≠ 1) on the packed path.

`recompute_old_logprobs_per_minibatch` makes the forward use the **same mini-batch + DP partition** as training via a new `WorkerDispatch.forward_from_staged` (mirroring `forward_backward_from_staged`, reusing `stage_data`), so packing, and thus logprobs, match exactly. The trainer helper is `_execute_forward_pass`.

### Observability

New per-step metrics when token batching is on: **`num_microbatches`** and **`num_padding_microbatches`** (averaged across DP), for both backends.

### Tests

- **CPU unit tests** (`test_token_based_batching_utils.py`): `balanced_binpacking` (incl. oversized-sequence behavior), `TokenBasedBatchIterator`, multi-modal `TensorList` gathering, padding-microbatch shape / `num_padding_microbatches`.
- **GPU tests** (`test_token_based_batching.py`): FSDP/Megatron forward_backward + loss-equivalence vs sample-based (packed and dense, DP=1 and DP=2 with padding microbatches), the new metrics, and `test_megatron_per_minibatch_forward_matches_forward_backward` validating that per-mini-batch forward matches training packing while full-batch forward diverges (DP=2, packed).

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 1002394 commit eda1127

13 files changed

Lines changed: 1579 additions & 158 deletions
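
The commit message describes `balanced_binpacking(token_counts, max_tokens)` as greedy least-loaded bin-packing with a soft cap. The real implementation lives in `worker_utils.py`, which is not rendered on this page, so the following is only a minimal sketch of that idea. The function name `greedy_least_loaded_binpack`, the longest-first visiting order, and the rule for opening a new bin are assumptions made for illustration, not the repository's code.

```python
import heapq
from typing import List, Tuple


def greedy_least_loaded_binpack(token_counts: List[int], max_tokens: int) -> List[List[int]]:
    """Hypothetical sketch: group sample indices into bins under a soft token cap."""
    # Assumption: visit the longest samples first (a common bin-packing heuristic).
    order = sorted(range(len(token_counts)), key=lambda i: -token_counts[i])
    bins: List[List[int]] = []
    heap: List[Tuple[int, int]] = []  # (current token load, bin id), least-loaded on top

    for idx in order:
        n = token_counts[idx]
        # If the least-loaded bin cannot take the sample, no other bin can either.
        if heap and heap[0][0] + n <= max_tokens:
            load, bin_id = heapq.heappop(heap)
            bins[bin_id].append(idx)
            heapq.heappush(heap, (load + n, bin_id))
        else:
            # Soft cap: a sample longer than max_tokens still gets its own bin.
            bins.append([idx])
            heapq.heappush(heap, (n, len(bins) - 1))
    return bins


# Sample 4 (12 tokens) exceeds the cap of 10 and ends up alone in its bin.
example_bins = greedy_least_loaded_binpack([5, 9, 3, 7, 12], max_tokens=10)
print(example_bins)  # [[4], [1], [3], [0, 2]]
```

Because sequences are never split, the cap bounds only bins that hold more than one sample, which is why the config guidance suggests a cap of at least `max_prompt_length + max_generate_length`.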


.github/workflows/gpu_skyrl_train_megatron.yaml

Lines changed: 2 additions & 2 deletions
@@ -69,5 +69,5 @@ jobs:
         run: |
           COMMIT_SHA="${{ github.event.pull_request.head.sha || github.sha }}"
           JOB_NAME="skyrl-train-gpu-ci-megatron-${COMMIT_SHA:0:7}-${{ github.run_id }}"
-          anyscale job submit -f ci/anyscale_gpu_ci_skyrl_train_megatron.yaml --name "$JOB_NAME" --timeout 7000
-          anyscale job wait --cloud sky-anyscale-aws-us-east-1 --name "$JOB_NAME" --timeout 7000
+          anyscale job submit -f ci/anyscale_gpu_ci_skyrl_train_megatron.yaml --name "$JOB_NAME" --timeout 8000
+          anyscale job wait --cloud sky-anyscale-aws-us-east-1 --name "$JOB_NAME" --timeout 8000

skyrl/backends/skyrl_train/training_batch.py

Lines changed: 6 additions & 0 deletions
@@ -72,6 +72,12 @@ def __len__(self) -> int:
     def __getitem__(self, index):
         if isinstance(index, slice):
             return TensorList(self.tensors[index])
+        if isinstance(index, torch.Tensor):
+            if index.ndim == 0:
+                return self.tensors[int(index)]
+            return TensorList([self.tensors[int(i)] for i in index.tolist()])
+        if isinstance(index, (list, tuple)):
+            return TensorList([self.tensors[int(i)] for i in index])
         return self.tensors[index]

     def to(self, device=None, dtype=None, non_blocking=False):

skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py

Lines changed: 15 additions & 1 deletion
@@ -309,6 +309,11 @@ def loss_func(logits, data):
             rollout_action_logprobs = data["rollout_action_logprobs"]
             action_mask = data.get("action_mask")
             num_microbatches = data.get("num_microbatches")
+            # Number of microbatches carrying real samples (excludes fully-padding
+            # microbatches added by token-based batching). Used to normalize the
+            # KL/entropy terms over real microbatches only. Falls back to
+            # num_microbatches when not provided (no padding microbatches).
+            num_real_microbatches = data.get("num_real_microbatches", num_microbatches)

             dp_size = mpu.get_data_parallel_world_size(with_context_parallel=False)
             tp_grp = mpu.get_tensor_model_parallel_group()
@@ -454,7 +459,16 @@ def loss_func(logits, data):
             # NOTE: The KL and entropy loss terms are not pre-scaled,
             # so we just average them across microbatches and DP workers.
             # KL and entropy use Megatron's existing microbatch and CP schedule scaling.
-            loss = policy_loss * grad_sum_correction_factor + (kl_loss_term - entropy_loss_term)
+            # Megatron divides by num_microbatches (which includes fully-padding microbatches
+            # added by token-based batching). Those padding microbatches contribute 0 to
+            # KL/entropy, so dividing by the full count would dilute the regularization by
+            # num_real/num_total. Scale up by num_microbatches/num_real_microbatches so the
+            # terms are averaged over real microbatches only (no-op when there is no padding).
+            kl_entropy_microbatch_scale = num_microbatches / max(1, num_real_microbatches)
+            loss = (
+                policy_loss * grad_sum_correction_factor
+                + (kl_loss_term - entropy_loss_term) * kl_entropy_microbatch_scale
+            )
             unscaled_loss = loss / grad_sum_correction_factor

             # Build per-sequence loss_fn_outputs with logprobs.
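
The rescaling in the hunk above can be checked with a small numeric example. This is a toy model rather than Megatron code: it assumes the schedule simply averages per-microbatch terms over all microbatches and that padding microbatches contribute exactly 0. The helper name `averaged_regularizer` is hypothetical.

```python
from typing import List


def averaged_regularizer(terms: List[float], num_real_microbatches: int, rescale: bool) -> float:
    """Toy model (assumption): the schedule averages terms over ALL microbatches."""
    num_microbatches = len(terms)
    # Same expression as kl_entropy_microbatch_scale in the diff.
    scale = num_microbatches / max(1, num_real_microbatches) if rescale else 1.0
    return sum(term * scale for term in terms) / num_microbatches


# Two real microbatches followed by two fully-padding ones (which contribute 0).
kl_terms = [0.25, 0.75, 0.0, 0.0]

diluted = averaged_regularizer(kl_terms, num_real_microbatches=2, rescale=False)
corrected = averaged_regularizer(kl_terms, num_real_microbatches=2, rescale=True)

print(diluted)    # 0.25, half of the intended strength
print(corrected)  # 0.5, the mean over the two real microbatches
```

With no padding microbatches the scale is 1 and the two results coincide, which matches the "no-op" remark in the code comment.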

skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py

Lines changed: 240 additions & 93 deletions
Large diffs are not rendered by default.

skyrl/backends/skyrl_train/workers/worker.py

Lines changed: 81 additions & 28 deletions
@@ -51,8 +51,10 @@
 )
 from skyrl.backends.skyrl_train.utils.torch_utils import masked_mean
 from skyrl.backends.skyrl_train.workers.worker_utils import (
+    BaseBatchIterator,
     BatchIterator,
     all_reduce_metrics,
+    get_microbatch_iterator,
     reduce_metrics,
 )
 from skyrl.env_vars import (
from skyrl.env_vars import (
@@ -757,14 +759,19 @@ def forward_backward(
             :class:`WorkerOutput` with per-sample ``loss_fn_outputs`` and scalar
             ``metrics`` (all-reduced across DP).
         """
-        micro_batch_size = self.cfg.micro_train_batch_size_per_gpu
+        microbatch_iterator = get_microbatch_iterator(
+            data,
+            micro_batch_size=self.cfg.micro_train_batch_size_per_gpu,
+            max_tokens_per_microbatch=self.cfg.max_tokens_per_microbatch,
+        )
         all_metrics = defaultdict(list)
         all_loss_fn_outputs = []  # Handle separately from scalar metrics

-        for micro_batch in BatchIterator(data, micro_batch_size, drop_last=False):
-            microbatch_weight = micro_batch_size / len(data)
+        for microbatch in microbatch_iterator:
+            experience = BaseBatchIterator.batch_to_experience(microbatch)
+            microbatch_weight = len(microbatch) / len(data)
             metrics = self._forward_backward_micro(
-                micro_batch, microbatch_weight, loss_fn=loss_fn, loss_fn_config=loss_fn_config
+                experience, microbatch_weight, loss_fn=loss_fn, loss_fn_config=loss_fn_config
             )

             # Extract loss_fn_outputs before reduce_metrics (it's not a scalar metric)
@@ -782,6 +789,15 @@ def forward_backward(
         # Reduce across microbatches and all-reduce metrics across DP ranks
         # NOTE: Sum loss metrics because scaling is already applied at the advantage level
         result = reduce_metrics(all_metrics, sum_loss_metrics=sum_loss_metrics)
+
+        # Token-based batching diagnostics: total microbatches this rank ran and how many
+        # were purely-padding (added to equalize the microbatch count across DP ranks).
+        # Added before all-reduce so they are averaged across DP (num_microbatches is
+        # identical on every rank; num_padding_microbatches reports the per-rank average).
+        if self.cfg.max_tokens_per_microbatch > 0:
+            result["num_microbatches"] = float(len(microbatch_iterator))
+            result["num_padding_microbatches"] = float(getattr(microbatch_iterator, "num_padding_microbatches", 0))
+
         dp_group = self.device_mesh.get_group("dp")
         result = all_reduce_metrics(result, self.strategy, group=dp_group, sum_loss_metrics=sum_loss_metrics)
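
To make the two diagnostics concrete, here is a small sketch of how they would come out after DP averaging. It assumes that every rank is padded up to the largest per-rank microbatch count; the commit message only states that the count is synchronized, so the choice of the maximum as the target, and the helper name `synced_microbatch_metrics`, are assumptions for illustration.

```python
from typing import Dict, List


def synced_microbatch_metrics(real_counts_per_rank: List[int]) -> Dict[str, float]:
    """Hypothetical sketch: DP-averaged diagnostics after count synchronization."""
    dp_size = len(real_counts_per_rank)
    # Assumption: each rank appends padding microbatches up to the largest count.
    target = max(real_counts_per_rank)
    padding_per_rank = [target - real for real in real_counts_per_rank]
    return {
        # Every rank runs `target` microbatches, so the DP average equals `target`.
        "num_microbatches": sum(float(target) for _ in range(dp_size)) / dp_size,
        # Padding differs per rank, so the DP average can be fractional.
        "num_padding_microbatches": sum(float(p) for p in padding_per_rank) / dp_size,
    }


# Rank 0 packed its samples into 3 microbatches, rank 1 into 5.
metrics = synced_microbatch_metrics([3, 5])
print(metrics)  # {'num_microbatches': 5.0, 'num_padding_microbatches': 1.0}
```

A persistently high `num_padding_microbatches` relative to `num_microbatches` would suggest that token counts are unevenly distributed across DP ranks.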

@@ -1023,11 +1039,16 @@ def forward(
         """
         if loss_fn is None:
             # Inference forward path: run in micro batches and emit per-sample logprobs.
-            micro_batches = data.chunk(self.cfg.micro_forward_batch_size_per_gpu)
-            outputs = []
-            for micro_batch in micro_batches:
-                outputs.append(self._forward_micro_batch(micro_batch))
-            output = TrainingOutputBatch.cat(outputs)
+            # Uses token-based micro-batching when `max_tokens_per_microbatch > 0`, otherwise
+            # falls back to fixed sample-count chunking. `reorder_and_combine_batches` restores
+            # the original sample order (and strips padding) for the token-based iterator.
+            microbatch_iterator = get_microbatch_iterator(
+                data,
+                micro_batch_size=self.cfg.micro_forward_batch_size_per_gpu,
+                max_tokens_per_microbatch=self.cfg.max_tokens_per_microbatch,
+            )
+            outputs = [self._forward_micro_batch(micro_batch) for micro_batch in microbatch_iterator]
+            output = microbatch_iterator.reorder_and_combine_batches(outputs)
             if output.device is not None and output.device != torch.device("cpu"):
                 output = output.to("cpu")
             row_tensor = output["output"]
@@ -1247,7 +1268,8 @@ def forward_backward(self, data: TrainingInputBatch) -> WorkerOutput:
         """
         Perform forward and backward passes for a batch, handling micro-batching internally.

-        The batch is split into micro batches based on micro_train_batch_size_per_gpu.
+        The batch is split into micro batches based on micro_train_batch_size_per_gpu,
+        or by token count if max_tokens_per_microbatch is configured.
         Gradients accumulate across micro batches. Gradient scaling happens at optim_step.

         Args:
@@ -1257,12 +1279,26 @@ def forward_backward(self, data: TrainingInputBatch) -> WorkerOutput:
             :class:`WorkerOutput` with empty ``loss_fn_outputs`` and scalar
             ``metrics`` (all-reduced across DP).
         """
-        micro_batch_size = self.cfg.micro_train_batch_size_per_gpu
+        use_token_batching = self.cfg.max_tokens_per_microbatch > 0
+        microbatch_iterator = get_microbatch_iterator(
+            data,
+            micro_batch_size=self.cfg.micro_train_batch_size_per_gpu,
+            max_tokens_per_microbatch=self.cfg.max_tokens_per_microbatch,
+        )
         all_metrics = defaultdict(list)

-        for micro_batch in BatchIterator(data, micro_batch_size, drop_last=False):
-            metrics = self._forward_backward_micro(micro_batch)
-            self._micro_batches_accumulated += 1
+        for microbatch in microbatch_iterator:
+            experience = BaseBatchIterator.batch_to_experience(microbatch)
+
+            if use_token_batching:
+                # With token-based batching, microbatches may have different sizes.
+                # Scale loss by microbatch_weight so gradients are correctly weighted.
+                microbatch_weight = len(microbatch) / len(data)
+                metrics = self._forward_backward_micro(experience, microbatch_weight=microbatch_weight)
+            else:
+                metrics = self._forward_backward_micro(experience)
+                self._micro_batches_accumulated += 1
+
             for k, v in metrics.items():
                 all_metrics[k].append(v)

@@ -1274,14 +1310,17 @@ def forward_backward(self, data: TrainingInputBatch) -> WorkerOutput:

         return WorkerOutput(metrics=result)

-    def _forward_backward_micro(self, experience: Experience) -> Dict[str, float]:
+    def _forward_backward_micro(
+        self, experience: Experience, microbatch_weight: Optional[float] = None
+    ) -> Dict[str, float]:
         """
         Perform forward and backward pass for one micro batch.

-        Loss is NOT scaled here - gradient scaling happens at optim_step time.
-
         Args:
             experience: Experience object for one micro batch
+            microbatch_weight: If provided, scale loss by this weight before backward.
+                Used with token-based batching where microbatches have variable sizes.
+                If None, loss is unscaled (gradient scaling happens at optim_step time).

         Returns:
             All-reduced metrics dict for this micro batch
@@ -1313,7 +1352,11 @@ def _forward_backward_micro(self, experience: Experience) -> Dict[str, float]:
             config=self.cfg.algorithm,
             loss_mask=loss_mask,
         )
-        # NO loss scaling here - gradient scaling happens at optim_step
+
+        if microbatch_weight is not None:
+            # Token-based batching: scale loss by weight so gradients are properly weighted
+            loss = loss * microbatch_weight
+        # else: NO loss scaling here - gradient scaling happens at optim_step
         self.strategy.backward(loss, self.model, self.optimizer)

         status = {
@@ -1333,6 +1376,8 @@ def optim_step(self) -> float:
             The gradient norm (before scaling, after clipping)
         """
         # Scale accumulated gradients by 1/N to get correct average
+        # NOTE: When using token-based batching, loss is pre-scaled by microbatch_weight
+        # in forward_backward, so _micro_batches_accumulated stays 0 and no scaling needed.
         if self._micro_batches_accumulated > 0:
             scale = 1.0 / self._micro_batches_accumulated
             for param in self.model.parameters():
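
The switch from `1/N` scaling at `optim_step` to per-microbatch weights can be illustrated with plain numbers. The sketch below is a toy model: it assumes each microbatch loss is the mean over that microbatch's samples, which simplifies whatever reduction the real loss functions apply, and the helper names are hypothetical.

```python
from typing import List


def mean(values: List[float]) -> float:
    return sum(values) / len(values)


def weighted_accumulate(microbatches: List[List[float]]) -> float:
    """Weight each microbatch loss by len(microbatch) / len(data), as in the diff."""
    total_samples = sum(len(mb) for mb in microbatches)
    return sum(mean(mb) * len(mb) / total_samples for mb in microbatches)


def uniform_accumulate(microbatches: List[List[float]]) -> float:
    """Scale the accumulated sum by 1/N, as the sample-count path does at optim_step."""
    return sum(mean(mb) for mb in microbatches) / len(microbatches)


per_sample_losses = [1.0, 2.0, 3.0, 7.0]    # full-batch mean is 3.25
even_split = [[1.0, 2.0], [3.0, 7.0]]       # fixed sample-count microbatches
uneven_split = [[1.0], [2.0, 3.0, 7.0]]     # token-based packing can look like this

print(uniform_accumulate(even_split))       # 3.25, 1/N is correct for equal sizes
print(uniform_accumulate(uneven_split))     # 2.5, 1/N is biased for unequal sizes
print(weighted_accumulate(uneven_split))    # 3.25, weights recover the batch mean
```

This is why the token-based path leaves `_micro_batches_accumulated` at 0: the weights already sum to 1, so a second `1/N` scaling would shrink the gradient.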
@@ -1381,11 +1426,15 @@ def forward(self, data: TrainingInputBatch) -> WorkerOutput:
             per-sample dict with key ``"values"``.
         """
         # Run in micro batches and emit per-sample values.
-        micro_batches = data.chunk(self.cfg.micro_forward_batch_size_per_gpu)
-        outputs = []
-        for micro_batch in micro_batches:
-            outputs.append(self._forward_micro_batch(micro_batch))
-        output = TrainingOutputBatch.cat(outputs)
+        # Uses token-based micro-batching when `max_tokens_per_microbatch > 0`; otherwise fixed
+        # sample-count chunking. `reorder_and_combine_batches` restores original sample order.
+        microbatch_iterator = get_microbatch_iterator(
+            data,
+            micro_batch_size=self.cfg.micro_forward_batch_size_per_gpu,
+            max_tokens_per_microbatch=self.cfg.max_tokens_per_microbatch,
+        )
+        outputs = [self._forward_micro_batch(micro_batch) for micro_batch in microbatch_iterator]
+        output = microbatch_iterator.reorder_and_combine_batches(outputs)
         if output.device is not None and output.device != torch.device("cpu"):
             output = output.to("cpu")
         row_tensor = output["output"]
@@ -1408,11 +1457,15 @@ def forward(self, data: TrainingInputBatch) -> WorkerOutput:
             per-sample dict with key ``"logprobs"``.
         """
         # Run in micro batches and emit per-sample logprobs.
-        micro_batches = data.chunk(self.cfg.micro_forward_batch_size_per_gpu)
-        outputs = []
-        for micro_batch in micro_batches:
-            outputs.append(self._forward_micro_batch(micro_batch))
-        output = TrainingOutputBatch.cat(outputs)
+        # Uses token-based micro-batching when `max_tokens_per_microbatch > 0`; otherwise fixed
+        # sample-count chunking. `reorder_and_combine_batches` restores original sample order.
+        microbatch_iterator = get_microbatch_iterator(
+            data,
+            micro_batch_size=self.cfg.micro_forward_batch_size_per_gpu,
+            max_tokens_per_microbatch=self.cfg.max_tokens_per_microbatch,
+        )
+        outputs = [self._forward_micro_batch(micro_batch) for micro_batch in microbatch_iterator]
+        output = microbatch_iterator.reorder_and_combine_batches(outputs)
         if output.device is not None and output.device != torch.device("cpu"):
             output = output.to("cpu")
         row_tensor = output["output"]
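
The forward paths above rely on `reorder_and_combine_batches` to undo the bin-packing permutation and drop padding. Its implementation is in `worker_utils.py` and is not rendered here, so the following is only a list-based sketch of the idea. The name `restore_original_order`, the convention that dummy rows sit at the end of a microbatch, and the representation of a padding microbatch as an empty index list are all assumptions.

```python
from typing import List, Optional, Sequence


def restore_original_order(
    microbatch_indices: Sequence[Sequence[int]],
    microbatch_outputs: Sequence[Sequence[str]],
) -> List[Optional[str]]:
    """Hypothetical sketch: scatter per-microbatch rows back to original positions."""
    num_samples = sum(len(indices) for indices in microbatch_indices)
    restored: List[Optional[str]] = [None] * num_samples
    for indices, outputs in zip(microbatch_indices, microbatch_outputs):
        # Assumption: rows past len(indices) are dummy/padding rows and are dropped.
        for position, sample_idx in enumerate(indices):
            restored[sample_idx] = outputs[position]
    return restored


# Packing put samples 2 and 0 together, sample 1 alone (plus one dummy row),
# and added one fully-padding microbatch to match the other DP ranks.
indices = [[2, 0], [1], []]
outputs = [["out2", "out0"], ["out1", "dummy"], ["dummy", "dummy"]]
print(restore_original_order(indices, outputs))  # ['out0', 'out1', 'out2']
```

The sample-count iterator needs no such permutation, since fixed-size chunks already preserve the input order.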

skyrl/backends/skyrl_train/workers/worker_dispatch.py

Lines changed: 45 additions & 0 deletions
@@ -237,6 +237,51 @@ def forward(

         return WorkerOutput.cat(self._actor_groups[model].actor_infos, results)

+    def forward_from_staged(
+        self,
+        model: str,
+        chunk_refs: List[ObjectRef],
+        loss_fn: Optional[str] = None,
+        loss_fn_config: Optional[Dict[str, Any]] = None,
+        model_id: Optional[str] = None,
+    ) -> WorkerOutput:
+        """Run a forward pass using pre-staged per-DP chunks.
+
+        Consumes per-DP chunks already placed in the object store by :meth:`stage_data`, so
+        serialization of the per-mini-batch chunks is amortized off the dispatch critical path
+        across mini-batches (see :meth:`forward_backward_from_staged`). The chunks are produced
+        exactly as in :meth:`stage_data`, so the per-rank partition (and thus the microbatch packing)
+        matches what ``forward_backward`` sees for the same mini-batch.
+
+        Args:
+            model: Model identifier ("policy", "critic", or "ref")
+            chunk_refs: Pre-staged ObjectRefs, one per DP rank (from ``stage_data``)
+            loss_fn: Optional resolved loss function name. When set, the worker computes
+                loss + per-sample outputs without backward (no_grad).
+            loss_fn_config: Optional config overrides for the loss function.
+            model_id: Optional Tinker model_id; selects the LoRA adapter before the forward.
+
+        Returns:
+            :class:`WorkerOutput` aggregated across DP ranks.
+        """
+        self._ensure_on_gpu(model, need_optimizer=False, need_model=True)
+        self.ensure_active_adapter(model, model_id)
+
+        kwargs = {}
+        if loss_fn is not None:
+            kwargs["loss_fn"] = loss_fn
+        if loss_fn_config is not None:
+            kwargs["loss_fn_config"] = loss_fn_config
+
+        refs = MeshDispatch.dispatch_from_staged(
+            self._actor_groups[model].actor_infos,
+            "forward",
+            chunk_refs=chunk_refs,
+            **kwargs,
+        )
+        results = ray.get(refs)
+        return WorkerOutput.cat(self._actor_groups[model].actor_infos, results)
+
     def stage_data(
         self,
         model: str,
