Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
ae3f3d9
Vectorize splitting of sequences longer than `seq_length` in BFD packing
mariosasko Dec 19, 2025
ce371f0
Minor improvements
mariosasko Feb 26, 2026
d5b5ec3
Merge with upstream
mariosasko Feb 26, 2026
b393b4e
Nit
mariosasko Feb 26, 2026
a32b5a5
More nits
mariosasko Feb 26, 2026
c216184
Merge branch 'main' into vectorized-bfd-chunking
mariosasko Mar 4, 2026
819c2f8
Merge branch 'main' into vectorized-bfd-chunking
qgallouedec Mar 7, 2026
25b0502
Improvements and fixes to make the bot happy
mariosasko Mar 9, 2026
fb6ffc9
Tests
mariosasko Mar 9, 2026
dc23b74
Merge branch 'main' of github.com:huggingface/trl into vectorized-bfd…
mariosasko Mar 9, 2026
5caa917
Nit
mariosasko Mar 9, 2026
e08d8fe
Account for possible empty sequences
mariosasko Mar 11, 2026
86e89a2
Resolve conflict
mariosasko Mar 11, 2026
a766059
Update docs/source/data_utils.md
mariosasko Mar 11, 2026
15a8367
Fix BFD formatting
mariosasko Mar 13, 2026
ccc3e8e
Merge branch 'vectorized-bfd-chunking' of github.com:mariosasko/trl i…
mariosasko Mar 13, 2026
9ae9266
Merge branch 'main' into vectorized-bfd-chunking
qgallouedec Mar 13, 2026
39ba691
Don't mutate in-place
mariosasko Mar 13, 2026
7508775
Add paper reference
mariosasko Mar 13, 2026
f47060b
Merge branch 'vectorized-bfd-chunking' of github.com:mariosasko/trl i…
mariosasko Mar 13, 2026
4356e20
drop enum; deprecate bfd-requeue; `bfd_split` -> `bfd-split`, re-add…
qgallouedec Mar 13, 2026
4737aee
paper index and revisit doc
qgallouedec Mar 13, 2026
42d4be7
`-` -> `_`
qgallouedec Mar 13, 2026
37b168d
fix word
qgallouedec Mar 13, 2026
5007bec
fix backward compat
qgallouedec Mar 13, 2026
b265bbf
better doc rendering
qgallouedec Mar 13, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 53 additions & 8 deletions tests/test_data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from transformers import AutoProcessor, AutoTokenizer, is_vision_available

from trl.data_utils import (
PackingStrategy,
apply_chat_template,
extract_prompt,
is_conversational,
Expand Down Expand Up @@ -1056,57 +1057,82 @@ def test_with_dataset(self):
"attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
}
dataset = Dataset.from_dict(examples)
dataset = dataset.with_format("numpy", dtype="float32")
format = dataset.format
seq_length = 3
expected_output = {
"input_ids": [[1, 2, 3], [4, 5, 6], [7, 8]],
"attention_mask": [[0, 1, 1], [0, 0, 1], [1, 1]],
}
dataset = pack_dataset(dataset, seq_length, strategy="wrapped")
assert dataset.to_dict() == expected_output
assert format == dataset.format

def test_with_iterable_dataset(self):
# Packs a numpy-formatted iterable dataset with strategy="wrapped" and
# seq_length=3, then checks both the packed rows and that the dataset's
# `_formatting` is the same before and after packing.
examples = {
"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
"attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
}
dataset = Dataset.from_dict(examples).to_iterable_dataset()
dataset = dataset.with_format("numpy")
# Captured before packing so it can be compared with the packed dataset below.
formatting = dataset._formatting
seq_length = 3
expected_output = {
"input_ids": [[1, 2, 3], [4, 5, 6], [7, 8]],
"attention_mask": [[0, 1, 1], [0, 0, 1], [1, 1]],
}
dataset = pack_dataset(dataset, seq_length, strategy="wrapped")
# Number of rows in the first column of `examples`; used as the batch size
# so the first batch can be compared against `expected_output`.
num_examples = len(examples[next(iter(examples))])
# NOTE(review): the next two assertions differ only by `.with_format(None)`.
# This listing looks like a diff with its +/- markers stripped, so the first
# is presumably the superseded line — confirm that only the second remains.
assert next(iter(dataset.batch(batch_size=num_examples))) == expected_output
assert next(iter(dataset.with_format(None).batch(batch_size=num_examples))) == expected_output
assert formatting == dataset._formatting


class TestPackingStrategy(TrlTestCase):
    """Checks how `PackingStrategy` resolves string values."""

    def test_aliases(self):
        """The "bfd-split" and "bfd-truncate" strings resolve to existing members."""
        split_member = PackingStrategy("bfd-split")
        assert split_member is PackingStrategy.BFD_SPLIT

        truncate_member = PackingStrategy("bfd-truncate")
        assert truncate_member is PackingStrategy.BFD

    def test_missing_value_raises_value_error(self):
        """An unknown string raises `ValueError` mentioning `PackingStrategy`."""
        expect_invalid = pytest.raises(ValueError, match="not a valid PackingStrategy")
        with expect_invalid:
            PackingStrategy("missing")


class TestPackDatasetBfd(TrlTestCase):
def test_simple(self):
def test_with_dataset(self):
    """Pack a numpy-formatted `Dataset` with the "bfd" strategy.

    Checks the packed rows and that, apart from the extra `seq_lengths`
    column produced by packing, the dataset's format is unchanged.
    """
    examples = {
        "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
    }
    dataset = Dataset.from_dict(examples)
    dataset = dataset.with_format("numpy", dtype="float32")
    # Format before packing (named to avoid shadowing the builtin `format`).
    original_format = dataset.format
    seq_length = 4
    expected_output = {
        "input_ids": [[4, 5, 6, 7], [1, 2, 3, 8]],
        "seq_lengths": [[4], [3, 1]],
    }
    dataset = pack_dataset(dataset, seq_length, strategy="bfd")
    assert dataset.to_dict() == expected_output

    # Snapshot the packed format and drop `seq_lengths` from a *new* list, so
    # the comparison does not depend on whether `dataset.format["columns"]`
    # is the dataset's own internal list, and the dataset is never mutated.
    packed_format = dataset.format
    assert "seq_lengths" in packed_format["columns"]
    packed_format["columns"] = [
        column for column in packed_format["columns"] if column != "seq_lengths"
    ]
    # Compare against the trimmed snapshot; re-reading `dataset.format` here
    # would leave the snapshot unused.
    assert original_format == packed_format

def test_with_iterable_dataset(self):
# Packs a numpy-formatted iterable dataset with strategy="bfd" and
# seq_length=4, then checks the packed rows (including the `seq_lengths`
# column) and that `_formatting` is the same before and after packing.
examples = {
"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
}
dataset = Dataset.from_dict(examples).to_iterable_dataset()
dataset = dataset.with_format("numpy")
# Captured before packing so it can be compared with the packed dataset below.
formatting = dataset._formatting
seq_length = 4
expected_output = {
"input_ids": [[4, 5, 6, 7], [1, 2, 3, 8]],
"seq_lengths": [[4], [3, 1]],
}
dataset = pack_dataset(dataset, seq_length, strategy="bfd")
# Number of rows in the first column of `examples`; used as the batch size
# so the first batch can be compared against `expected_output`.
num_examples = len(examples[next(iter(examples))])
# NOTE(review): the next two assertions differ only by `.with_format(None)`.
# This listing looks like a diff with its +/- markers stripped, so the first
# is presumably the superseded line — confirm that only the second remains.
assert next(iter(dataset.batch(batch_size=num_examples))) == expected_output
assert next(iter(dataset.with_format(None).batch(batch_size=num_examples))) == expected_output
assert formatting == dataset._formatting

def test_with_overlong_0(self):
examples = {
Expand All @@ -1118,7 +1144,7 @@ def test_with_overlong_0(self):
"input_ids": [[1, 2, 3, 4], [8, 9, 10, 11], [6, 7, 5, 12]],
"seq_lengths": [[4], [4], [2, 1, 1]],
}
dataset = pack_dataset(dataset, seq_length, strategy="bfd-requeue")
dataset = pack_dataset(dataset, seq_length, strategy="bfd_split")
assert dataset.to_dict() == expected_output

def test_with_overlong_two_coluns(self):
Expand All @@ -1133,7 +1159,7 @@ def test_with_overlong_two_coluns(self):
"col2": [[-1, 2, -3, 4], [-13, 14, -15, 16], [-7, 8, -9], [10, -11, 12], [-5, 6]],
"seq_lengths": [[4], [4], [3], [3], [2]],
}
dataset = pack_dataset(dataset, seq_length, strategy="bfd-requeue")
dataset = pack_dataset(dataset, seq_length, strategy="bfd_split")
assert dataset.to_dict() == expected_output

def test_with_non_power_of_2(self):
Expand All @@ -1146,10 +1172,10 @@ def test_with_non_power_of_2(self):
"input_ids": [[1, 2, 3, 4, 5], [7, 8, 9, 10, 6], [11, 12, 13]],
"seq_lengths": [[5], [4, 1], [3]],
}
dataset = pack_dataset(dataset, seq_length, strategy="bfd-requeue")
dataset = pack_dataset(dataset, seq_length, strategy="bfd_split")
assert dataset.to_dict() == expected_output

def test_default_no_requeue(self):
def test_default_no_split(self):
"""Test default 'bfd' strategy for SFT datasets (truncates overflow)."""
examples = {
"input_ids": [[1, 2, 3, 4, 5], [6, 7], [8, 9, 10, 11], [12]],
Expand All @@ -1164,6 +1190,19 @@ def test_default_no_requeue(self):
dataset = pack_dataset(dataset, seq_length, strategy="bfd")
assert dataset.to_dict() == expected_output

def test_with_empty_sequences(self):
    """Pack with "bfd_split" when some of the input rows are empty lists."""
    seq_length = 4
    source = Dataset.from_dict({"input_ids": [[1, 2], [], [3, 4, 5], [], [6]]})

    packed = pack_dataset(source, seq_length, strategy="bfd_split")

    # The expected rows contain no entry for the two empty inputs.
    expected_output = {
        "input_ids": [[3, 4, 5, 6], [1, 2]],
        "seq_lengths": [[3, 1], [2]],
    }
    assert packed.to_dict() == expected_output


class TestTruncateExamples(TrlTestCase):
def test_with_dataset(self):
Expand All @@ -1172,28 +1211,34 @@ def test_with_dataset(self):
"attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
}
dataset = Dataset.from_dict(examples)
dataset = dataset.with_format("numpy", dtype="float32")
format = dataset.format
max_length = 2
expected_output = {
"input_ids": [[1, 2], [4, 5], [8]],
"attention_mask": [[0, 1], [0, 0], [1]],
}
dataset = truncate_dataset(dataset, max_length)
assert dataset.to_dict() == expected_output
assert format == dataset.format

def test_with_iterable_dataset(self):
# Truncates a numpy-formatted iterable dataset to max_length=2, then checks
# the truncated rows and that `_formatting` is the same before and after.
examples = {
"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
"attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
}
dataset = Dataset.from_dict(examples).to_iterable_dataset()
dataset = dataset.with_format("numpy")
# Captured before truncation so it can be compared with the result below.
formatting = dataset._formatting
max_length = 2
expected_output = {
"input_ids": [[1, 2], [4, 5], [8]],
"attention_mask": [[0, 1], [0, 0], [1]],
}
dataset = truncate_dataset(dataset, max_length)
# Number of rows in the first column of `examples`; used as the batch size
# so the first batch can be compared against `expected_output`.
num_examples = len(examples[next(iter(examples))])
# NOTE(review): the next two assertions differ only by `.with_format(None)`.
# This listing looks like a diff with its +/- markers stripped, so the first
# is presumably the superseded line — confirm that only the second remains.
assert next(iter(dataset.batch(batch_size=num_examples))) == expected_output
assert next(iter(dataset.with_format(None).batch(batch_size=num_examples))) == expected_output
assert formatting == dataset._formatting

def test_with_extra_column(self):
examples = {
Expand Down
2 changes: 2 additions & 0 deletions trl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
_import_structure = {
"chat_template_utils": ["add_response_schema", "clone_chat_template", "get_training_chat_template"],
"data_utils": [
"PackingStrategy",
"apply_chat_template",
"extract_prompt",
"is_conversational",
Expand Down Expand Up @@ -72,6 +73,7 @@
if TYPE_CHECKING:
from .chat_template_utils import add_response_schema, clone_chat_template, get_training_chat_template
from .data_utils import (
PackingStrategy,
apply_chat_template,
extract_prompt,
is_conversational,
Expand Down
Loading
Loading