Merged
Commits
26 commits
ae3f3d9
Vectorize splitting of sequences longer than `seq_length` in BFD packing
mariosasko Dec 19, 2025
ce371f0
Minor improvements
mariosasko Feb 26, 2026
d5b5ec3
Merge with upstream
mariosasko Feb 26, 2026
b393b4e
Nit
mariosasko Feb 26, 2026
a32b5a5
More nits
mariosasko Feb 26, 2026
c216184
Merge branch 'main' into vectorized-bfd-chunking
mariosasko Mar 4, 2026
819c2f8
Merge branch 'main' into vectorized-bfd-chunking
qgallouedec Mar 7, 2026
25b0502
Improvements and fixes to make the bot happy
mariosasko Mar 9, 2026
fb6ffc9
Tests
mariosasko Mar 9, 2026
dc23b74
Merge branch 'main' of github.com:huggingface/trl into vectorized-bfd…
mariosasko Mar 9, 2026
5caa917
Nit
mariosasko Mar 9, 2026
e08d8fe
Account for possible empty sequences
mariosasko Mar 11, 2026
86e89a2
Resolve conflict
mariosasko Mar 11, 2026
a766059
Update docs/source/data_utils.md
mariosasko Mar 11, 2026
15a8367
Fix BFD formatting
mariosasko Mar 13, 2026
ccc3e8e
Merge branch 'vectorized-bfd-chunking' of github.com:mariosasko/trl i…
mariosasko Mar 13, 2026
9ae9266
Merge branch 'main' into vectorized-bfd-chunking
qgallouedec Mar 13, 2026
39ba691
Don't mutate in-place
mariosasko Mar 13, 2026
7508775
Add paper reference
mariosasko Mar 13, 2026
f47060b
Merge branch 'vectorized-bfd-chunking' of github.com:mariosasko/trl i…
mariosasko Mar 13, 2026
4356e20
drop enum; deprecate bfd-requeue; `bfd_split` -> `bfd-split``, re-add…
qgallouedec Mar 13, 2026
4737aee
paper index and revisit doc
qgallouedec Mar 13, 2026
42d4be7
`-` -> `_`
qgallouedec Mar 13, 2026
37b168d
fix word
qgallouedec Mar 13, 2026
5007bec
fix backward compt
qgallouedec Mar 13, 2026
b265bbf
better doc rendering
qgallouedec Mar 13, 2026
16 changes: 16 additions & 0 deletions docs/source/paper_index.md
@@ -1140,6 +1140,22 @@ SFTConfig(
)
```

### Fewer Truncations Improve Language Modeling

**📜 Paper**: https://huggingface.co/papers/2404.10830

The paper shows that the standard concatenate-then-split preprocessing (`packing_strategy="wrapped"`) used for LLM training causes many documents to be arbitrarily truncated, which harms learning. It proposes packing document chunks into context windows using a Best-Fit Decreasing bin-packing algorithm, greatly reducing truncation while keeping high token utilization and improving model performance. TRL implements this as the `"bfd_split"` packing strategy in [`SFTConfig`]. For more details on packing, see the [SFT documentation](sft_trainer#packing).

```python
from trl import SFTConfig

training_args = SFTConfig(
packing=True,
packing_strategy="bfd_split",
max_length=4096,
)
```
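The Best-Fit Decreasing heuristic named above is simple enough to sketch in plain Python. This is an illustrative `bfd_pack` helper operating on sequence lengths only, not TRL's actual implementation:

```python
# Sketch of Best-Fit Decreasing (BFD) bin packing over sequence lengths.
# Hypothetical helper for illustration; not TRL's internal API.

def bfd_pack(lengths, max_length):
    """Pack sequence lengths into bins of capacity `max_length`."""
    bins = []  # each bin: [remaining_capacity, [packed lengths]]
    for length in sorted(lengths, reverse=True):  # longest first ("Decreasing")
        # "Best fit": the bin with the least remaining room that still fits.
        best = min(
            (b for b in bins if b[0] >= length), key=lambda b: b[0], default=None
        )
        if best is None:
            bins.append([max_length - length, [length]])
        else:
            best[0] -= length
            best[1].append(length)
    return [packed for _, packed in bins]
```

For example, `bfd_pack([3, 4, 1], 4)` returns `[[4], [3, 1]]` — the same grouping the `seq_lengths` column records in this PR's `TestPackDatasetBfd` tests.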

### Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer

**📜 Paper**: https://huggingface.co/papers/1910.10683
27 changes: 20 additions & 7 deletions docs/source/reducing_memory_usage.md
@@ -67,22 +67,35 @@ To help you choose an appropriate value, we provide a utility to visualize the s
[Truncation](#truncation) has several drawbacks:

1. **Loss of information**: Key data at the end of a sequence may be discarded.
2. **Choosing truncation length**: Too short loses data; too long undermines efficiency.
1. **Loss of information**: Important tokens at the end of sequences may be discarded.
2. **Choosing truncation length**: Too short loses data; too long reduces efficiency.

Packing, introduced in [Raffel et al., 2020](https://huggingface.co/papers/1910.10683), addresses these issues by grouping sequences instead of truncating. It concatenates and splits dataset sequences into the desired lengths.
Packing mitigates these issues by grouping multiple sequences into the same training row, filling each row up to `max_length`.

![Packing](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/packing_3.png)

Packing reduces padding by merging several sequences in one row when possible. We use an advanced method to be near-optimal in the way we pack the dataset. To enable packing, use `packing=True` in the [`SFTConfig`].
TRL implements packing using **Best-Fit Decreasing (BFD)** bin packing, which groups sequences efficiently while minimizing padding. When a sequence exceeds `max_length`, different strategies determine how the overflow tokens are handled.

> [!TIP]
> In TRL 0.18 and earlier, packing used a more aggressive method that reduced padding to almost nothing, but had the downside of breaking sequence continuity for a large fraction of the dataset. To revert to this strategy, use `packing_strategy="wrapped"` in [`SFTConfig`].
TRL supports three strategies:

* `"bfd"` (default): Uses **Best-Fit Decreasing packing**. If a sequence exceeds `max_length`, the overflow tokens are discarded.

* `"bfd_split"`: Uses **Best-Fit Decreasing packing**, but long sequences are split into chunks ≤ `max_length` before packing. This preserves all tokens and follows the approach proposed in [Fewer Truncations Improve Language Modeling](https://huggingface.co/papers/2404.10830).

* `"wrapped"`: All tokens are concatenated into a stream and split into fixed-length blocks. This minimizes padding but may mix unrelated examples. This strategy corresponds to the *concatenate-then-split* preprocessing described in the literature (e.g., [Fewer Truncations Improve Language Modeling](https://huggingface.co/papers/2404.10830)). It has the downside of breaking sequence continuity for a large fraction of the dataset, which hurts performance, as discussed in the [Qwen3-Coder-Next Technical Report](https://huggingface.co/papers/2603.00729).

> [!NOTE]
> If all sequences are shorter than `max_length`, **`bfd` and `bfd_split` behave identically**, since no truncation or splitting is required.

```python
from trl import SFTConfig

training_args = SFTConfig(..., packing=True, max_length=512)
training_args = SFTConfig(
...,
packing=True,
packing_strategy="bfd",
max_length=512,
)
```
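The two non-default behaviors described above can be sketched in a few lines of plain Python. These are illustrative helpers, not TRL's internals: `wrapped_pack` mimics the concatenate-then-split strategy, and `split_overlong` mimics the chunking pre-step of `"bfd_split"`:

```python
# Sketches of the "wrapped" strategy and of the chunking step used by
# "bfd_split". Hypothetical helpers for illustration; not TRL's API.

def wrapped_pack(sequences, max_length):
    """Concatenate all tokens into one stream, then cut fixed-size blocks."""
    stream = [tok for seq in sequences for tok in seq]
    return [stream[i : i + max_length] for i in range(0, len(stream), max_length)]

def split_overlong(seq, max_length):
    """Cut an overlong sequence into chunks of at most `max_length` tokens."""
    return [seq[i : i + max_length] for i in range(0, len(seq), max_length)]
```

For example, `wrapped_pack([[1, 2, 3], [4, 5, 6, 7], [8]], 3)` yields `[[1, 2, 3], [4, 5, 6], [7, 8]]` — note how the second sequence is cut mid-document, which is exactly the continuity break the BFD strategies avoid.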

## PEFT for parameter-efficient fine-tuning
50 changes: 42 additions & 8 deletions tests/test_data_utils.py
@@ -1056,57 +1056,72 @@ def test_with_dataset(self):
"attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
}
dataset = Dataset.from_dict(examples)
dataset = dataset.with_format("numpy", dtype="float32")
format = dataset.format
seq_length = 3
expected_output = {
"input_ids": [[1, 2, 3], [4, 5, 6], [7, 8]],
"attention_mask": [[0, 1, 1], [0, 0, 1], [1, 1]],
}
dataset = pack_dataset(dataset, seq_length, strategy="wrapped")
assert dataset.to_dict() == expected_output
assert format == dataset.format

def test_with_iterable_dataset(self):
examples = {
"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
"attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
}
dataset = Dataset.from_dict(examples).to_iterable_dataset()
dataset = dataset.with_format("numpy")
formatting = dataset._formatting
seq_length = 3
expected_output = {
"input_ids": [[1, 2, 3], [4, 5, 6], [7, 8]],
"attention_mask": [[0, 1, 1], [0, 0, 1], [1, 1]],
}
dataset = pack_dataset(dataset, seq_length, strategy="wrapped")
num_examples = len(examples[next(iter(examples))])
assert next(iter(dataset.batch(batch_size=num_examples))) == expected_output
assert next(iter(dataset.with_format(None).batch(batch_size=num_examples))) == expected_output
assert formatting == dataset._formatting


class TestPackDatasetBfd(TrlTestCase):
def test_simple(self):
def test_with_dataset(self):
examples = {
"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
}
dataset = Dataset.from_dict(examples)
dataset = dataset.with_format("numpy", dtype="float32")
format = dataset.format
seq_length = 4
expected_output = {
"input_ids": [[4, 5, 6, 7], [1, 2, 3, 8]],
"seq_lengths": [[4], [3, 1]],
}
dataset = pack_dataset(dataset, seq_length, strategy="bfd")
expected_format = dataset.format
assert dataset.to_dict() == expected_output
assert "seq_lengths" in expected_format["columns"]
expected_format["columns"].remove("seq_lengths")
assert format == expected_format

def test_with_iterable_dataset(self):
examples = {
"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
}
dataset = Dataset.from_dict(examples).to_iterable_dataset()
dataset = dataset.with_format("numpy")
formatting = dataset._formatting
seq_length = 4
expected_output = {
"input_ids": [[4, 5, 6, 7], [1, 2, 3, 8]],
"seq_lengths": [[4], [3, 1]],
}
dataset = pack_dataset(dataset, seq_length, strategy="bfd")
num_examples = len(examples[next(iter(examples))])
assert next(iter(dataset.batch(batch_size=num_examples))) == expected_output
assert next(iter(dataset.with_format(None).batch(batch_size=num_examples))) == expected_output
assert formatting == dataset._formatting

def test_with_overlong_0(self):
examples = {
@@ -1118,7 +1133,7 @@ def test_with_overlong_0(self):
"input_ids": [[1, 2, 3, 4], [8, 9, 10, 11], [6, 7, 5, 12]],
"seq_lengths": [[4], [4], [2, 1, 1]],
}
dataset = pack_dataset(dataset, seq_length, strategy="bfd-requeue")
dataset = pack_dataset(dataset, seq_length, strategy="bfd_split")
assert dataset.to_dict() == expected_output

def test_with_overlong_two_columns(self):
@@ -1133,7 +1148,7 @@ def test_with_overlong_two_columns(self):
"col2": [[-1, 2, -3, 4], [-13, 14, -15, 16], [-7, 8, -9], [10, -11, 12], [-5, 6]],
"seq_lengths": [[4], [4], [3], [3], [2]],
}
dataset = pack_dataset(dataset, seq_length, strategy="bfd-requeue")
dataset = pack_dataset(dataset, seq_length, strategy="bfd_split")
assert dataset.to_dict() == expected_output

def test_with_non_power_of_2(self):
@@ -1146,10 +1161,10 @@ def test_with_non_power_of_2(self):
"input_ids": [[1, 2, 3, 4, 5], [7, 8, 9, 10, 6], [11, 12, 13]],
"seq_lengths": [[5], [4, 1], [3]],
}
dataset = pack_dataset(dataset, seq_length, strategy="bfd-requeue")
dataset = pack_dataset(dataset, seq_length, strategy="bfd_split")
assert dataset.to_dict() == expected_output

def test_default_no_requeue(self):
def test_default_no_split(self):
"""Test default 'bfd' strategy for SFT datasets (truncates overflow)."""
examples = {
"input_ids": [[1, 2, 3, 4, 5], [6, 7], [8, 9, 10, 11], [12]],
@@ -1164,6 +1179,19 @@ def test_default_no_requeue(self):
dataset = pack_dataset(dataset, seq_length, strategy="bfd")
assert dataset.to_dict() == expected_output

def test_with_empty_sequences(self):
examples = {
"input_ids": [[1, 2], [], [3, 4, 5], [], [6]],
}
dataset = Dataset.from_dict(examples)
seq_length = 4
expected_output = {
"input_ids": [[3, 4, 5, 6], [1, 2]],
"seq_lengths": [[3, 1], [2]],
}
dataset = pack_dataset(dataset, seq_length, strategy="bfd_split")
assert dataset.to_dict() == expected_output
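The behavior this test expects can be reproduced end-to-end with a plain-Python sketch of the `bfd_split` pipeline. The `bfd_split_pack` helper below is hypothetical, not TRL's implementation: empty sequences vanish during chunking, overlong sequences are split into `max_length`-sized chunks, and the chunks are then packed with Best-Fit Decreasing:

```python
# Illustrative end-to-end sketch of "bfd_split" packing (hypothetical helper,
# not TRL's implementation), run on the data from test_with_empty_sequences.

def bfd_split_pack(sequences, max_length):
    # Split overlong sequences into chunks of at most max_length tokens;
    # empty sequences produce no chunks and are thereby dropped.
    chunks = [
        seq[i : i + max_length]
        for seq in sequences
        for i in range(0, len(seq), max_length)
    ]
    # Best-Fit Decreasing: place the longest chunk first, each into the
    # tightest bin that still has room for it.
    bins = []  # each bin: [remaining_capacity, [chunks]]
    for chunk in sorted(chunks, key=len, reverse=True):
        best = min(
            (b for b in bins if b[0] >= len(chunk)),
            key=lambda b: b[0],
            default=None,
        )
        if best is None:
            bins.append([max_length - len(chunk), [chunk]])
        else:
            best[0] -= len(chunk)
            best[1].append(chunk)
    # Flatten each bin into one packed row of token ids.
    return [[tok for chunk in contents for tok in chunk] for _, contents in bins]
```

On the test's data, `bfd_split_pack([[1, 2], [], [3, 4, 5], [], [6]], 4)` returns `[[3, 4, 5, 6], [1, 2]]`, matching the expected `input_ids` above.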


class TestTruncateExamples(TrlTestCase):
def test_with_dataset(self):
@@ -1172,28 +1200,34 @@ def test_with_dataset(self):
"attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
}
dataset = Dataset.from_dict(examples)
dataset = dataset.with_format("numpy", dtype="float32")
format = dataset.format
max_length = 2
expected_output = {
"input_ids": [[1, 2], [4, 5], [8]],
"attention_mask": [[0, 1], [0, 0], [1]],
}
dataset = truncate_dataset(dataset, max_length)
assert dataset.to_dict() == expected_output
assert format == dataset.format

def test_with_iterable_dataset(self):
examples = {
"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
"attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
}
dataset = Dataset.from_dict(examples).to_iterable_dataset()
dataset = dataset.with_format("numpy")
formatting = dataset._formatting
max_length = 2
expected_output = {
"input_ids": [[1, 2], [4, 5], [8]],
"attention_mask": [[0, 1], [0, 0], [1]],
}
dataset = truncate_dataset(dataset, max_length)
num_examples = len(examples[next(iter(examples))])
assert next(iter(dataset.batch(batch_size=num_examples))) == expected_output
assert next(iter(dataset.with_format(None).batch(batch_size=num_examples))) == expected_output
assert formatting == dataset._formatting

def test_with_extra_column(self):
examples = {