Commit eda1127
feat: add max_tokens_per_microbatch config for token-based micro-batching (NovaSky-AI#1477)
## Token-based micro-batching (`max_tokens_per_microbatch`) for FSDP + Megatron
### Summary
Adds **token-based micro-batching**: instead of splitting a batch into
fixed *sample-count* microbatches (`micro_*_batch_size_per_gpu`),
microbatches are formed by **bin-packing samples up to a token budget**
(`max_tokens_per_microbatch`). This balances real token counts across
microbatches, improving throughput and memory utilization when sequence
lengths are highly variable. It applies to both the forward
(logprobs/values) and forward_backward (training) paths, on both the
**FSDP** and **Megatron** backends.
Along the way it fixes several correctness issues that surface on the
Megatron/THD path, and makes the "old logprobs" forward pass that runs
before training pack identically to the training step, so the PPO ratio
is exactly 1 at the first inner step.
### Config
- **`trainer.max_tokens_per_microbatch: int = -1`** — `> 0` enables
token-based batching; `-1` keeps existing sample-count behavior. It is a
*soft* cap: sequences are never split, so a single sequence longer than
the cap gets its own (over-cap) microbatch. Documented in the config
with guidance to set it `>= max_prompt_length + max_generate_length`.
- **`trainer.recompute_old_logprobs_per_minibatch: bool = True`** —
compute the policy old-logprobs (and critic values) per mini-batch,
matching the training step's mini-batch + DP partition (see
Forward/backward packing consistency below).
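
The sizing guidance for `max_tokens_per_microbatch` can be read as a small pre-flight check. The sketch below is illustrative only: `check_token_budget` is a hypothetical helper that is not part of this commit, and treating every non-positive value like `-1` is an assumption.

```python
def check_token_budget(
    max_tokens_per_microbatch: int,
    max_prompt_length: int,
    max_generate_length: int,
) -> str:
    """Classify a token budget (hypothetical helper, not part of the commit)."""
    if max_tokens_per_microbatch <= 0:
        # -1 keeps sample-count batching; treating 0 the same way is an assumption.
        return "sample-based"
    longest_sequence = max_prompt_length + max_generate_length
    if max_tokens_per_microbatch < longest_sequence:
        # The cap is soft, so this still runs, but the longest sequences
        # will each land in an over-cap microbatch of their own.
        return "token-based (over-cap microbatches possible)"
    return "token-based"


print(check_token_budget(-1, 512, 1024))
print(check_token_budget(1024, 512, 1024))
print(check_token_budget(4096, 512, 1024))
```

A budget below `max_prompt_length + max_generate_length` is not an error under the soft-cap rule; it only means the cap no longer bounds the largest microbatch.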
### How it works
`worker_utils.py` gains a small iterator hierarchy behind
`get_microbatch_iterator(...)`:
- `balanced_binpacking(token_counts, max_tokens)` — greedy least-loaded
bin-packing over per-sample token counts (from `attention_mask`).
- `TokenBasedBatchIterator` — packs samples into microbatches;
**synchronizes the microbatch count across DP ranks** by appending
loss-neutral *padding microbatches* (so every rank runs the same number
of forward passes, required by Megatron's pipeline schedule), and
reorders/strips them on the way out.
- `SampleBasedBatchIterator` / `BatchIterator` — the existing
fixed-count path, refactored under a shared `BaseBatchIterator`.
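
To make the packing step concrete, here is a minimal sketch of greedy least-loaded bin-packing under a soft token cap. It is not the code in `worker_utils.py`: the name `balanced_binpacking_sketch`, the longest-first ordering, and the rule for opening a new bin are all assumptions chosen for illustration.

```python
import heapq


def balanced_binpacking_sketch(token_counts, max_tokens):
    """Greedy least-loaded packing; returns one list of sample indices per microbatch.

    Illustrative only; the real `balanced_binpacking` may differ.
    """
    # Assumption: place the longest sequences first.
    order = sorted(range(len(token_counts)), key=lambda i: -token_counts[i])
    bins = []  # bins[b] holds the sample indices of microbatch b
    heap = []  # (current token load, bin id); heap[0] is the least-loaded bin
    for i in order:
        n = token_counts[i]
        if heap and heap[0][0] + n <= max_tokens:
            # The sample fits in the least-loaded bin, so put it there.
            load, b = heapq.heappop(heap)
            bins[b].append(i)
            heapq.heappush(heap, (load + n, b))
        else:
            # It fits nowhere (if the least-loaded bin is too full, all are),
            # so open a new bin. An oversized sequence ends up alone here.
            bins.append([i])
            heapq.heappush(heap, (n, len(bins) - 1))
    return bins


print(balanced_binpacking_sketch([900, 100, 400, 500, 300], max_tokens=1000))
print(balanced_binpacking_sketch([1500, 200, 300], max_tokens=1000))
```

In the second call the 1500-token sequence exceeds the cap and gets a microbatch to itself, which matches the soft-cap behavior described in the Config section.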
### Correctness fixes (mostly Megatron/THD-specific)
- **Multi-modal support:** `TensorList.__getitem__` now supports
tensor/list fancy indexing, so `pixel_values` / `image_grid_thw` are
gathered correctly when forming token-based microbatches.
- **Padding microbatch shape/device:** padding microbatches use the real
`seq_len` (uniform across microbatches, as Megatron requires) and are
built on the data's device — fixes a CPU/CUDA mismatch and an
`IndexError` in `postprocess_packed_seqs`.
- **Single-token dummy rows:** rows added to equalize batch size within
a microbatch use a one-hot `attention_mask` instead of all-ones — keeps
both the packed (no zero-length `cu_seqlens` segment) and dense (no
fully-masked-row NaN) paths valid while skipping nearly all of the
dummy's compute.
- **`_pad_microbatch_to_size`** consolidated and hoisted to the base
worker so ref/critic workers (which share `_forward_logprobs`) get it
too.
- **PP > 1 forward:** the token-based reorder is skipped on non-last
pipeline stages (which only return a placeholder), fixing an
`IndexError`.
- **Metric dilution:** fully-padding microbatches are excluded from
metric aggregation, so `loss_metrics/clip_ratio`, `policy_entropy`, etc.
are no longer dragged toward `0`.
- **KL/entropy normalization:** the KL/entropy term is normalized by the
count of *real* microbatches (`num_real_microbatches`) rather than the
padded total, so regularization strength isn't diluted by padding
microbatches at DP > 1.
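
The last two fixes share one idea: padding microbatches must not count toward any denominator. The sketch below illustrates it with plain Python lists. The function names and the `is_padding` flags are hypothetical, and it assumes the KL/entropy term is accumulated as a sum of per-microbatch contributions, each scaled by a common factor.

```python
def aggregate_metric(values, is_padding):
    """Mean of a per-microbatch metric over real microbatches only."""
    real = [v for v, pad in zip(values, is_padding) if not pad]
    return sum(real) / len(real) if real else 0.0


def kl_scale(is_padding):
    """Per-microbatch weight for the KL/entropy term: 1 / num_real_microbatches."""
    num_real_microbatches = sum(1 for pad in is_padding if not pad)
    return 1.0 / max(num_real_microbatches, 1)


clip_ratio = [0.2, 0.4, 0.0]       # the third entry comes from a padding microbatch
is_padding = [False, False, True]

diluted = sum(clip_ratio) / len(clip_ratio)   # padded total in the denominator
fixed = aggregate_metric(clip_ratio, is_padding)
print(diluted, fixed, kl_scale(is_padding))
```

With two real microbatches and one padding microbatch, the diluted mean is pulled toward `0` while the filtered mean is not, and the KL weight is `1/2` rather than `1/3`.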
### Forward/backward packing consistency
Previously, the forward pass that computes old logprobs before training
ran over the *full batch*, while training runs *per mini-batch*. The two
therefore used different DP partitions and THD packing, which made
`old_logprobs` differ from the logprobs recomputed at epoch 0
(PPO ratio ≠ 1) on the packed path.
`recompute_old_logprobs_per_minibatch` makes the forward use the **same
mini-batch + DP partition** as training via a new
`WorkerDispatch.forward_from_staged` (mirroring
`forward_backward_from_staged`, reusing `stage_data`), so packing — and
thus logprobs — match exactly. The trainer helper is
`_execute_forward_pass`.
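
A toy example shows why the two partitions disagree. It assumes each DP rank receives a contiguous, equal-sized chunk of whatever it is handed; the actual dispatch logic is not shown in this commit message, and the function names below are hypothetical.

```python
def dp_chunks(indices, dp_size):
    """Split indices into dp_size contiguous chunks (assumed partition scheme)."""
    per_rank = len(indices) // dp_size
    return [indices[r * per_rank:(r + 1) * per_rank] for r in range(dp_size)]


def full_batch_partition(batch_size, dp_size):
    """One DP partition over the whole batch (the old forward path)."""
    return dp_chunks(list(range(batch_size)), dp_size)


def per_minibatch_partition(batch_size, mini_batch_size, dp_size):
    """One DP partition per mini-batch, as the training step sees it."""
    return [
        dp_chunks(list(range(start, start + mini_batch_size)), dp_size)
        for start in range(0, batch_size, mini_batch_size)
    ]


full = full_batch_partition(batch_size=8, dp_size=2)
mini = per_minibatch_partition(batch_size=8, mini_batch_size=4, dp_size=2)
print(full)  # per-rank sample indices
print(mini)  # per-mini-batch, per-rank sample indices
```

Under these assumptions, sample 2 sits on rank 0 next to samples 0, 1 and 3 in the full-batch partition, but on rank 1 next to only sample 3 in the per-mini-batch partition. Different neighbors mean different packed sequences, which is the source of the logprob mismatch described above.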
### Observability
New per-step metrics when token batching is on: **`num_microbatches`**
and **`num_padding_microbatches`** (averaged across DP), for both
backends.
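
As a rough illustration of where these numbers come from, the sketch below equalizes per-rank microbatch counts and averages the padding across DP ranks. The helper names are hypothetical, `max()` stands in for a cross-rank reduction, and whether the reported `num_microbatches` includes padding is not stated here, so only the padding metric is derived.

```python
def sync_microbatch_counts(per_rank_counts):
    """Return (target, padding_per_rank) so every rank runs `target` microbatches."""
    # Stand-in for a MAX reduction across DP ranks.
    target = max(per_rank_counts)
    padding = [target - n for n in per_rank_counts]
    return target, padding


def dp_average(values):
    """Average a per-rank value across DP ranks."""
    return sum(values) / len(values)


# Rank 0 packed its samples into 3 microbatches, rank 1 into 5.
target, padding = sync_microbatch_counts([3, 5])
print(target, padding, dp_average(padding))
```

Here rank 0 appends two padding microbatches so both ranks run five forward passes, and the DP-averaged padding count is 1.0. A persistently high value suggests the token budget or the DP partition is leaving ranks unevenly loaded.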
### Tests
- **CPU unit tests** (`test_token_based_batching_utils.py`):
`balanced_binpacking` (incl. oversized-sequence behavior),
`TokenBasedBatchIterator`, multi-modal `TensorList` gathering,
padding-microbatch shape / `num_padding_microbatches`.
- **GPU tests** (`test_token_based_batching.py`): FSDP/Megatron
forward_backward + loss-equivalence vs sample-based (packed and dense,
DP=1 and DP=2 with padding microbatches), the new metrics, and
`test_megatron_per_minibatch_forward_matches_forward_backward`
validating that per-mini-batch forward matches training packing while
full-batch forward diverges (DP=2, packed).
---------
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 1002394 commit eda1127
13 files changed
Lines changed: 1579 additions & 158 deletions
File tree
- .github/workflows
- skyrl
- backends/skyrl_train
- workers
- megatron
- train
- config
- dataset
- tests/backends/skyrl_train
- gpu/gpu_ci
- megatron