Releases: huggingface/trl
v0.26.1
What's Changed
- Fix vLLM error for tools usage not supported when running GRPO training by @apalmas-saifh in #4663
- Fix GRPO config validation in case `num_generations_eval` is specified and different than `num_generations` by @apalmas-saifh in #4682
New Contributors
- @apalmas-saifh made their first contribution in #4663
Full Changelog: v0.26.0...v0.26.1
v0.26.0
Features
🕵️♂️ GRPO: Agent training
GRPOTrainer now supports training agents using tools. This allows language models to interact with external functions or APIs during training.
```python
from datasets import Dataset
from trl import GRPOTrainer

def multiply(a: int, b: int) -> int:
    """
    Multiplies two integers.

    Args:
        a: The first integer.
        b: The second integer.

    Returns:
        The product of the two integers.
    """
    return a * b

dataset = Dataset.from_list(
    [
        {"prompt": [{"role": "user", "content": "What is 3 multiplied by 4?"}], "answer": 12},
        {"prompt": [{"role": "user", "content": "Calculate 7 times 8."}], "answer": 56},
        {"prompt": [{"role": "user", "content": "Find the product of 5 and 6."}], "answer": 30},
        {"prompt": [{"role": "user", "content": "What do you get when you multiply 9 by 9?"}], "answer": 81},
        {"prompt": [{"role": "user", "content": "Compute 12 multiplied by 11."}], "answer": 132},
        {"prompt": [{"role": "user", "content": "What is 15 times 14?"}], "answer": 210},
    ]
)

def accuracy(completions, answer, **kwargs):
    predictions = [completion[-1]["content"] for completion in completions]
    rewards = [float(str(ans) in pred) for pred, ans in zip(predictions, answer)]
    return rewards

trainer = GRPOTrainer(
    model="Qwen/Qwen3-0.6B",
    train_dataset=dataset,
    tools=[multiply],
    reward_funcs=accuracy,
)
trainer.train()
```

by @qgallouedec in #4300
ScaleRL: Add CISPO Loss
The CISPO loss was first introduced in the MiniMax-M1 paper; the ScaleRL paper subsequently showed that CISPO scales best in terms of performance and efficiency as models are trained for longer.
GRPOTrainer now supports the CISPO loss via `loss_type="cispo"` in the `GRPOConfig`.
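A minimal configuration sketch of the option above (only the new argument is shown; add your usual training arguments around it):

```python
from trl import GRPOConfig

# Enable the CISPO loss; all other GRPO settings keep their defaults here.
training_args = GRPOConfig(loss_type="cispo")
```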
by @pramodith in #4495
Add vLLM quantization option for colocate
When the input model is quantized using bitsandbytes, vLLM will now also use quantization when in colocate mode.
by @sergiopaniego in #4496
Reasoning reward
TRL now includes a reasoning reward function:
```python
from trl.rewards import reasoning_accuracy_reward

solutions = [r"\frac{1}{3}", r"\frac{1}{3}", r"\frac{1}{3}"]
completions = [
    [
        {
            "role": "assistant",
            "content": r"<think> Reasoning content </think> The final answer is \boxed{\frac{1}{3}}",
        }
    ],
    [
        {
            "role": "assistant",
            "content": r"<think> Reasoning content </think> The final answer is \boxed{\frac{1}{2}}",
        }
    ],
    [
        {
            "role": "assistant",
            "content": r"<think> Reasoning content with partial answers \boxed{\frac{1}{3}} but no final answer",
        }
    ],
]
reasoning_accuracy_reward(completions, solutions)  # [1.0, 0.0, 0.0]
```

As with any other reward function, it can be used in GRPOTrainer or RLOOTrainer.
```python
from trl import GRPOTrainer
from trl.rewards import reasoning_accuracy_reward

trainer = GRPOTrainer(
    ...,
    reward_funcs=reasoning_accuracy_reward,
)
```

Add shuffle_dataset option to SFTTrainer
You can now shuffle the dataset in SFTTrainer by setting the shuffle_dataset argument to True in SFTConfig. This is useful when the dataset features high similarity between consecutive samples.
```python
from trl import SFTTrainer, SFTConfig

SFTConfig(shuffle_dataset=True)
```

by @qgallouedec in #4564
Add SAPO Loss in GRPO
Soft Adaptive Policy Optimization (SAPO) replaces hard clipping with a smooth, temperature-controlled gate that adaptively attenuates off-policy updates while preserving useful learning signals. Compared with GSPO and GRPO, SAPO is both sequence-coherent and token-adaptive. Like GSPO, SAPO maintains sequence-level coherence, but its soft gating forms a continuous trust region that avoids the brittle hard clipping band used in GSPO.
You can now use SAPO loss in GRPOTrainer by setting loss_type="sapo" in the GRPOConfig.
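A minimal sketch of that setting (only the new argument is shown):

```python
from trl import GRPOConfig

# Switch GRPO to the SAPO loss described above; other settings keep their defaults.
training_args = GRPOConfig(loss_type="sapo")
```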
by @pramodith in #4600
Other Features
- Support completion bootstrap for VLM in GRPO/RLOO by @SolarWindRider in #4452
- Add support for images inside tables with Trackio completions logging by @taha-yassine in #4505
- Add step time metric to GRPO Trainer for performance tracking by @qgallouedec in #4516
- Add target_parameters to LoraConfig by @jonnyli1125 in #4536
- [SFT] Log mean token accuracy from Liger kernel by @kashif in #4302
- Add `num_generations_eval` parameter for efficient evaluation by @mingxuetian in #4458
- [GRPO] Sequence-level TIS & MIS by @LeonEricsson in #4530
- TRL supports vLLM 0.11 by @qgallouedec in #4633
- feat: implement DeepSeek unbiased KL estimator for GRPO by @jlcanta in #4638
Experimental
- Move XPOTrainer to trl.experimental.xpo by @behroozazarkhalili in #4485
- Move judges to experimental submodule by @behroozazarkhalili in #4439
- Add MiniLLM Trainer by @t1101675 in #4504
- refactor: Move CPOTrainer to experimental module by @behroozazarkhalili in #4470
- Move GKDTrainer to experimental module by @behroozazarkhalili in #4474
- Move NashMDTrainer to experimental module by @behroozazarkhalili in #4477
- Move PPOTrainer to trl.experimental.ppo by @behroozazarkhalili in #4482
- [ORPO] Move ORPOTrainer to experimental by @behroozazarkhalili in #4480
- Move PRMTrainer to trl.experimental.prm by @behroozazarkhalili in #4483
- Move OnlineDPOTrainer to experimental module by @behroozazarkhalili in #4473
- Move `WinRateCallback` to experimental by @qgallouedec in #4558
- Move tests for GSPOTokenTrainer to experimental by @qgallouedec in #4572
- Raise FutureWarning for classes moved to experimental by @albertvillanova in #4605
- Move MergeModelCallback to experimental by @qgallouedec in #4608
- Raise FutureWarning for trainer moved to experimental by @albertvillanova in #4620
- Remove no longer applicable warning once BCO was moved to experimental by @albertvillanova in #4628
- Refactor suppression of warning at experimental import by @albertvillanova in #4629
- 🚚 Move KTO to trl.experimental by @neha222222 in #4575
Fixes
- Buffer samples based on group level stds. by @pramodith in #4492
- Fix bugs in CISPO conditions by @pramodith in #4499
- `device_map` and `dtype` to `"auto"` by default by @qgallouedec in #4509
- MiniLLM: Fix arguments in config & add to documentation index by @t1101675 in #4518
- [Bug Fix] OnlineDPOTrainer with vLLM Server Mode by @YangKai0616 in #4500
- Rename `flash-attn` to `flash-attn2` by @qgallouedec in #4514
- fix(GOLDTrainer): Resolve incorrect attribute access and VLLMClient.generate() output type by @fabio-sim in #4526
- Fix bug with VLM processors in prompt-completion completion text-only training by @kschwethelm in #4553
- fix+docs: `device_map=None` for DeepSpeed and add ZeRO paper (1910.02054) to Paper Index by @JenWei0312 in #4551
- Fix vLLM sleep mode: add collective RPC call to reload weights in vLLM wake-up process by @qgallouedec in #4571
- fix: use shift_labels for metrics when using CP or SP by @jue-jue-zi in #4579
- Fix 'generation_config' AttributeError by @albertvillanova in #4596
- Fix FSDP2 model key miss match when sync LoRA model to vLLM server by @Xiao-Chenguang in #4603
- Fix KTOTrainer CUDA error for large-vocab models via tensor indexing by @bhuvanprakash in #4635
Documentation and Examples
- docs: Add PEFT subsection to reducing memory usage guide by @behroozazarkhalili in #4430
- [DOCS] update and fix openenv by @burtenshaw in #4490
- Fix link to OpenEnv docs by @lukehinds in #4502
- Tweak description for vLLM sleep mode by @lewtun in #4506
- Paper Index: Change `num_completions` to `num_generations` by @pramodith in https://gi...
v0.25.1
What's Changed
- Replace accelerate logging with stdlib in CLI by @lewtun in #4512
- Add temporary workaround for `lr_scheduler_kwargs` dtype issue in Transformers 4.57.0 by @qgallouedec in #4513
Full Changelog: v0.25.0...v0.25.1
v0.25.0
Features
- 💤 Switch to sleep level=2 and split wake-ups in GRPO and RLOO trainers by @xxrjun in #4296
- Added custom `prepare_model_for_kbit_training` to save VRAM by @sergiopaniego in #4335
- Add `add_generation_prompt` to processor_kwargs in GRPO and RLOO trainer by @qgallouedec in #4361
- Add support for Trackio completions logging in GRPOTrainer by @taha-yassine in #4359
- Support chat_template_kwargs by @pramodith in #4350
- GRPO: ScaleRL -> Support casting LM Head to FP32 by @pramodith in #4303
- Support casting to fp32 when word embeddings are tied to lm_head by @pramodith in #4446
- 💬 Add chat to vLLM client and server, update trainer calls by @qgallouedec in #4450
Experimental
- 🚚 Move BCO to `trl.experimental` by @qgallouedec in #4312
- 👑 [experimental] GOLD Trainer by @kashif in #4349
- Add PAPOTrainer for preference-based optimization by @SolarWindRider in #4334
- [GFPO] fix the GFPO loss calculation error caused by unmodified old_per_token_logps by @Peter-Chou in #4454
- 🕹️ Add rollout function for OpenEnv integration by @lewtun in #4310
Fixes
- [Activation-checkpointing] add tensor dedup and param offloading by @kashif in #4247
- Fix attn_implementation name in OnlineDPO for transformers v5 by @albertvillanova in #4322
- Hotfix: Fall back to config.text_config._name_or_path if missing config._name_or_path by @albertvillanova in #4324
- Fix GRPO and RLOO trainers for continuous batching by @albertvillanova in #4348
- Fix: `add_generation_prompt=True` for conversational only by @qgallouedec in #4362
- Remove ignored max_length parameter from PRMTrainer data collator by @albertvillanova in #4355
- Fix add_generation_prompt arg for paged transformers in GRPO and RLOO trainers by @albertvillanova in #4370
- Fix GKD Liger memory spike by @qgallouedec in #4140
- Fix GRPO with replay buffer by inserting images in the prompt by @albertvillanova in #4391
- fix: Remove chat template setting from non-SFT trainer scripts by @behroozazarkhalili in #4437
- 🖼️ Fix reporting images with vLLM by @qgallouedec in #4476
Documentation and Examples
- Added SFT LoRA notebook by @sergiopaniego in #4244
- Update notebooks README with latest additions by @sergiopaniego in #4316
- Add notebooks to Examples docs and restructure by @sergiopaniego in #4317
- Highlight OpenEnv in landing docs by @sergiopaniego in #4327
- Update OpenEnv docs by @sergiopaniego in #4328
- Add OpenEnv blog to landing by @sergiopaniego in #4333
- 🗞️ Update "What's New" by @qgallouedec in #4338
- Update Reducing Memory Consumption guide with more details by @sergiopaniego in #4332
- Fixed links inside Tips in docs by @sergiopaniego in #4360
- 🔥 docs: Add RapidFire AI integration guide by @kamran-rapidfireAI in #4340
- Fix paper link for "Towards Efficient and Exact Optimization of Language Model Alignment" by @qgallouedec in #4409
- Migrate experimental trl feature docs by @ethanknights in #4411
- Update SFT QLoRA notebook with 14B model on free Colab by @sergiopaniego in #4336
- Create "Talks" subsection by @sergiopaniego in #4414
- Openenv wordle example by @burtenshaw in #4357
- docs: Remove outdated conversational dataset conversion guidance by @behroozazarkhalili in #4422
- docs: List all trainers that support Liger Kernel by @behroozazarkhalili in #4432
- Add On-Policy Distillation from thinking labs to paper index. by @pramodith in #4410
- Upload notebook with T4 selected by @sergiopaniego in #4449
- Removed outdated warning about batch contamination by @Harras3 in #4423
- Removed Sentiment Tuning Examples by @Harras3 in #4424
- docs: Remove outdated notebooks by @behroozazarkhalili in #4435
- docs: Move Multi-Adapter RL section to PEFT integration by @behroozazarkhalili in #4436
- Update `max_length` explanation for VLM in online trainers by @sergiopaniego in #4220
- Updated OpenEnv docs by @sergiopaniego in #4418
- add llasa-tutorial by @Deep-unlearning in #4456
Deprecations
- Replace deprecated AutoModelForVision2Seq with AutoModelForImageTextToText by @albertvillanova in #4353
- Replace deprecated list with tuple indexing in PPOTrainer by @albertvillanova in #4356
- Remove liger loss in favor of liger kernel by @sergiopaniego in #4364
- 🐍 Drop Python 3.9 by @qgallouedec in #4183
What's Changed
- ⬆️ Bump dev version by @qgallouedec in #4293
- Update links to docs in README to latest packaged version by @sergiopaniego in #4084
- 🧺 [4/N] Refactor `_generate` in GRPO/RLOO: Move `forward_kwargs` outside generation method by @qgallouedec in #4154
- Fix missing CI slow tests: ImportError: vLLM is not installed by @albertvillanova in #4304
- Added SFT LoRA notebook by @sergiopaniego in #4244
- ⚰️ Remove deprecated by @qgallouedec in #4301
- Silence TRL experimental warnings in CI by @albertvillanova in #4307
- Filter expected setup_chat_format deprecation warning in CI by @albertvillanova in #4306
- [Activation-checkpointing] add tensor dedup and param offloading by @kashif in #4247
- Remove parameterized as test extra dependency by @albertvillanova in #4315
- Update notebooks README with latest additions by @sergiopaniego in #4316
- 🚚 Move BCO to `trl.experimental` by @qgallouedec in #4312
- 🧺 [5/N] Refactor `_generate` in GRPO/RLOO: Insert images in the prompt by @qgallouedec in #4155
- 💤 Switch to sleep level=2 and split wake-ups in GRPO and RLOO trainers by @xxrjun in #4296
- Replace unittest skipTest from transformers with pytest.skip by @albertvillanova in #4297
- Add notebooks to Examples docs and restructure by @sergiopaniego in #4317
- Fix attn_implementation name in OnlineDPO for transformers v5 by @albertvillanova in #4322
- 🕹️ Add rollout function for OpenEnv integration by @lewtun in #4310
- Highlight OpenEnv in landing docs by @sergiopaniego in #4327
- Update OpenEnv docs by @sergiopaniego in #4328
- Move BCO tests to tests/experimental by @albertvillanova in #4326
- Hotfix: Fall back to config.text_config._name_or_path if missing config._name_or_path by @albertvillanova in #4324
- Add OpenEnv blog to landing by @sergiopaniego in #4333
- 🗞️ Update "What's New" by @qgallouedec in #4338
- Update Reducing Memory Consumption guide with more details by @sergiopaniego in #4332
- Added custom `prepare_model_for_kbit_training` to save VRAM by @sergiopaniego in #4335
- [vllm] update comment about communication group host ip by @kashif in #4337
- Fix GRPO and RLOO trainers for continuous batching by @albertvillanova in #4348
- Fixed links inside Tips in docs by @sergiopaniego in #4360
- Fix CI issue for vlm_gemma_3n model by @kaixuanliu in #4278
- Add `add_generation_prompt` to processor_kwargs ...
v0.24.0
Features
- Add accuracy reward by @pramodith in #4270
- Add support for `token_type_ids` in `DPOTrainer` by @aweers in #4285
- 💰 `RichProgressCallback` enhancement by @qgallouedec in #4245
- Include `chat_template_kwargs` in `apply_chat_template` by @cmpatino in #4233
- 🏷️ Account for `token_type_ids` in `DataCollatorForVisionLanguageModeling` by @qgallouedec in #4190
- 🎨 Support mixing image+text and text-only examples by @qgallouedec in #4203
- 🎁 `RewardTrainer` refactor by @qgallouedec in #4093
- 🎞️ Support sequence classification models in `clone_chat_template` by @qgallouedec in #4097
- ✨ Add logging for training completion and model saving in training scripts by @qgallouedec in #4048
- 🖨️ Print rich table for messages by @qgallouedec in #4160
- 😴 Add `vllm_enable_sleep_mode` to RLOO Trainer by @sergiopaniego in #4107
- 📽 Multi image support for GRPO/RLOO by @qgallouedec in #4113
- 👁️ Add VLM support to RLOO trainer by @behroozazarkhalili in #4067
- ℹ️ Enable XPU for vLLM client by @jiqing-feng in #4031
- 🧶 feat: Add WeaveCallback for W&B Weave integration by @parambharat in #4089
Fixes
- [Online-DPO] fix the completion_len == max_new_tokens crash by @kashif in #4193
- Fix entropy and accuracy calculation for prompt_tuning techniques. by @pramodith in #4196
- Fix prompt-completion labeling with add_generation_prompt and warning by @behroozazarkhalili in #4201
- 🌡️ Have vLLM return processed (temperature scaled) log probs by @YonatanGideoni in #4163
- Fix handling of f_divergence_type in DPO by @albertvillanova in #4171
- ⚡ Fix Flash Attention x Padding-Free loss by @qgallouedec in #4170
- Pass required token_type_ids by @albertvillanova in #4148
- 👩🦯 Fix usage of VLM using text only by @SamuelBarryCS in #4080
- ⚓ [vllm] ensure MASTER_ADDR/MASTER_PORT are set safely by @kashif in #4057
- 📤 Fix a dataset loading bug in scripts by @singing-cat in #4124
- 🐯 fix: use_liger_kernel with IterableDataset by @jue-jue-zi in #4087
- [GKD] Fix `batchmean` reduce op in GKDTrainer's loss by @cmpatino in #4105
- Fix get_peft_model() so that prepare_model_for_kbit_training does not reapply to an instance of PeftModel, thus freezing all the layers by @Hoesu in #4081
- Aux loss is already included in the loss returned by Transformers by @pramodith in #4078
- ♨️ [GRPO] Fix potential hang in `get_high_entropy_mask` by @akakakakakaa in #4041
Documentation
- Remove logging.md: trainer-specific metrics documentation by @behroozazarkhalili in #4269
- Remove using_llama_models.md: outdated Llama2-specific documentation by @behroozazarkhalili in #4268
- Remove how_to_train.md: outdated training FAQ by @behroozazarkhalili in #4267
- Add Qwen3-VL notebooks (SFT, GRPO) by @sergiopaniego in #4275
- Remove obsolete research_projects directory by @behroozazarkhalili in #4243
- Add Efficient Online Training with GRPO and vLLM in TRL to community tutorials by @sergiopaniego in #4219
- Add trainers taxonomy to docs by @sergiopaniego in #4195
- Updated vLLM integration guide by @sergiopaniego in #4162
- [DOCS] Lora without regret by @burtenshaw in #4181
- Add docstring for OnlineTrainerState by @albertvillanova in #4166
- ⚖️ Align SFT and DPO for model creation and deprecate `DPOConfig.padding_value` in favour of `pad_token_id` by @qgallouedec in #4006
- 🏞️ Context Parallelism benchmark guide by @sergiopaniego in #4075
- ▶️ Add video to community tutorials by @qgallouedec in #4090
- Reviewed HF jobs updated docs by @sergiopaniego in #4088
Deprecations
- Deprecate `BestOfNSampler` by @qgallouedec in #4291
- Raise deprecation warning for Python 3.9 by @albertvillanova in #4226
- Deprecate unused dataset_formatting module by @behroozazarkhalili in #4242
- Warnings pointing to RFC by @qgallouedec in #4224
- 🅰️ Remove apex by @qgallouedec in #4139
- 🗑️ Remove deprecated `AlignPropTrainer`, `DDPOTrainer` and `IterativeSFTTrainer` by @qgallouedec in #4068
Experimental
- 🧪 Add `trl.experimental` Submodule by @August-murr in #4073
- [GRPO]: Sample from a Replay Buffer To Substitute Groups with 0 std. by @pramodith in #4060
- 🪙 [Experimental] Support GSPO-token by @hjh0119 in #3820
- 🌪️ [GFPO]: implement GFPO in GRPOTrainer by @Peter-Chou in #3989
- 🌾 [Experimental] BEMA for ref model by @qgallouedec in #3898
What's Changed
- ⬆️ Bump dev version by @qgallouedec in #4054
- Remove redundant 'None' from docstrings by @albertvillanova in #4058
- Hotfix: Add ParallelismConfig fallback for transformers with old accelerate by @albertvillanova in #4063
- Fix CI failure in slow GRPO test due to missing pillow dependency by @albertvillanova in #4064
- 💡 Fix type hint to `make_parser` function in multiple scripts by @qgallouedec in #4050
- Improve docstring of AlignPropTrainer by @albertvillanova in #4059
- ♨️ [GRPO] Fix potential hang in `get_high_entropy_mask` by @akakakakakaa in #4041
- Set Ruff src for first-party imports by @albertvillanova in #4074
- 🧪 Add `trl.experimental` Submodule by @August-murr in #4073
- 🌾 [Experimental] BEMA for ref model by @qgallouedec in #3898
- ✂️ [GRPO VLM] Update split sizes to generalize by @zucchini-nlp in #4032
- 🛠️ Fix CI by @qgallouedec in #4076
- 🐳 Docker update + Simplify Jobs doc by @qgallouedec in #3931
- Aux loss is already included in the loss returned by Transformers by @pramodith in #4078
- Reviewed HF jobs updated docs by @sergiopaniego in #4088
- 🗑️ Remove deprecated `AlignPropTrainer`, `DDPOTrainer` and `IterativeSFTTrainer` by @qgallouedec in #4068
- ▶️ Add video to community tutorials by @qgallouedec in #4090
- Align slow tests with regular tests by @albertvillanova in #4085
- Add support for testing experimental features by @albertvillanova in #4082
- Community Tutorials design adaptation for videos by @sergiopaniego in #4095
- 🏞️ Context Parallelism benchmark guide by @sergiopaniego in #4075
- ⌨️ Pin num2words by @lewtun in #4094
- Add deprecation warnings to docstrings by @albertvillanova in #4083
- 📜 Convert `set` to `list` of tags by @qgallouedec in #4092
- 🧶 feat: Add WeaveCallback for W&B Weave integration by @parambharat in #4089
- ⚖️ Align SFT and DPO for model creation and deprecate `DPOConfig.padding_value` in favour of `pad_token_id` by @qgallouedec in #4006
- 🌪️ [GFPO]: implement GFPO in GRPOTrainer by @Peter-Chou in #3989
- ℹ️ feat: Add NPU and XPU support for activation offloading by @zilongzheng in #4056
- ℹ️ Enable XPU for vLLM client by @jiqing-feng in #4031
- Fix get_peft_model() so that prepare_model_for_kbit_training does not reapply to an instance of PeftModel, thus freezing all the layers by @Hoesu in https://github.com/huggingface/trl/pull/...
v0.23.1
What's Changed
- ♨️ [GRPO] Fix potential hang in `get_high_entropy_mask` by @akakakakakaa in #4041
- Aux loss is already included in the loss returned by Transformers by @pramodith in #4078
- Fix get_peft_model() so that prepare_model_for_kbit_training does not reapply to an instance of PeftModel, thus freezing all the layers by @Hoesu in #4081
- 🐯 fix: use_liger_kernel with IterableDataset by @jue-jue-zi in #4087
- [SFTrainer]: Fix DFT Loss by @pramodith in #4112
- ⚡ Fix Flash Attention x Padding-Free loss by @qgallouedec in #4170
New Contributors
Full Changelog: v0.23.0...v0.23.1
v0.23.0
Major
🥓 Context Parallelism
SFT now supports Context Parallelism (CP) for training large language models on very large sequences. You can now train with an arbitrarily long sequence length.
🧨 Dynamic Fine-Tuning
Dynamic Fine-Tuning (DFT) is now supported in TRL.
```python
from trl import SFTConfig

training_args = SFTConfig(
    loss_type="dft",
    ...
)
```
by @qgallouedec in #4042
🪵 Truncated Importance Sampling (TIS) to address rollout-training mismatch
Different implementations are used for rollout generation (vLLM) and model training. This implementation gap implicitly turns on-policy RL into off-policy RL. Truncated Importance Sampling (TIS) is a simple yet effective importance sampling technique for handling such discrepancies, and it is now implemented in GRPO.
```python
from trl import GRPOConfig

training_args = GRPOConfig(
    ...,
    use_vllm=True,
    vllm_importance_sampling_correction=True,  # default True
    vllm_importance_sampling_cap=2.0,  # hyper-parameter C
)
```

by @LeonEricsson in #3867
🥣 [SFTTrainer]: Add Aux Loss for MoE models
Mixture of Experts (MoE) models require an auxiliary loss to ensure that the different experts are used evenly. This auxiliary loss is now supported in SFTTrainer.
```python
from trl import SFTConfig

training_args = SFTConfig(
    model_init_kwargs={"output_router_logits": True},
    ...
)
```

by @pramodith in #4012
💤 [GRPO/RLOO] Adds an option to sleep vllm when running in colocated mode
When running GRPO (or RLOO) with vLLM in colocated mode, the vLLM server consumes VRAM during optimization while not being used. There is now an option to put the vLLM server to sleep during optimization to free up that VRAM.
```python
from trl import GRPOConfig

training_args = GRPOConfig(..., vllm_sleep_enabled=True)
```

by @edbeeching in #3968
⚖️ Add vLLM server mode and VLM support to OnlineDPOTrainer
You can now use vLLM server mode with OnlineDPOTrainer. Additionally, VLM models are now supported.
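A hypothetical sketch of how this could be configured, assuming `OnlineDPOConfig` exposes the same vLLM options as `GRPOConfig` (`use_vllm`, `vllm_mode`); check the OnlineDPO documentation for the exact argument names:

```python
from trl import OnlineDPOConfig

# Assumed arguments mirroring GRPOConfig's vLLM options; a separate vLLM server
# (e.g. started with `trl vllm-serve`) is expected to be running in server mode.
training_args = OnlineDPOConfig(
    use_vllm=True,
    vllm_mode="server",
)
```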
Comprehensive Paper Index Enhancement with 9 New Algorithm Implementations
The paper index has been significantly enhanced with the addition of 9+ new algorithm implementations, providing a more comprehensive resource for users.
by @behroozazarkhalili in #3990
Other Notable Changes
- 👷 Added Kernels on the Hub x TRL guide by @sergiopaniego in #3969
- 🌵 Refactor entropy_from_logits for memory efficiency by @qgallouedec in #4013
What's Changed
- ⬆️ Bump dev version by @qgallouedec in #3978
- 👮 Fix GRPO CLI by setting parameters for `get_soft_overlong_punishment` by @qgallouedec in #3972
- 🪃 `args.gradient_checkpointing = False` instead of `args = dataclasses.replace(args, gradient_checkpointing=False)` by @qgallouedec in #3981
- [GRPO] Adds an option to sleep vllm when running in colocated mode by @edbeeching in #3968
- 🎯 Add Trackio integration documentation and update TOC by @qgallouedec in #3971
- ⚖️ Fix scale_rewards issue in GRPO by @Peter-Chou in #3992
- ⏰ fix: add return to shift_tokens_right by @ginkyenglee in #3987
- Add pre-commit and hf-doc-builder as dev dependencies by @albertvillanova in #3993
- [GRPO] Truncated Importance Sampling to address rollout-training mismatch by @LeonEricsson in #3867
- Fixed tags shown problem in memory usage docs by @sergiopaniego in #3999
- ✖️ Support pad-to-multiple-of and padding-free by @qgallouedec in #3996
- 💾 [bugfix] fix PPO save_checkpoint by @hjh0119 in #3998
- [GRPO]: Fix Multi-GPU training for Entropy based masking of tokens. by @pramodith in #3964
- 📏 `torch_dtype` to `dtype` everywhere by @sergiopaniego in #4000
- Comprehensive Paper Index Enhancement with 9 New Algorithm Implementations by @behroozazarkhalili in #3990
- [SFT] fix: collator docstring by @LeonEricsson in #4011
- 👷 Added Kernels on the Hub x TRL guide by @sergiopaniego in #3969
- 🌵 Refactor entropy_from_logits for memory efficiency by @qgallouedec in #4013
- [SFTTrainer]: Add Aux Loss for MoE models. by @pramodith in #4012
- Add missing doc strings in SFTrainer by @pramodith in #4003
- ⚖️ Add vLLM server mode and VLM support to OnlineDPOTrainer by @vaelev in #3783
- Fix typo in GRPO quickstart by @dwisdom0 in #4020
- Align docstring parameters with function definitions by @albertvillanova in #4017
- Fix formatting errors in docstrings by @albertvillanova in #4025
- [doc] Paper index for Truncated Importance Sampling by @LeonEricsson in #4026
- [doc] Group paper index by trainer by @LeonEricsson in #4027
- Add missing trainer docstrings by @albertvillanova in #4030
- Add autodoc for AlignPropTrainer and AlignPropConfig by @albertvillanova in #4033
- 🥓 [docs] add CP docs by @kashif in #3994
- ⚖️ Remove `average_tokens_across_devices` default replacement by @qgallouedec in #4039
- CI hotfix: xfail test_training_with_transformers_paged by @albertvillanova in #4046
- Update transformers minimum version to 4.56.1 by @albertvillanova in #4047
- 🧨 DFT by @qgallouedec in #4042
- Update VLM arch check to `AutoModelForImageTextToText` for DPO and Online DPO by @sergiopaniego in #4049
- 🏂 Fix label shifting logic in `SFTTrainer` for compatibility with CP by @qgallouedec in #4038
- Add autodoc for BestOfNSampler and improve docstrings by @albertvillanova in #4034
- ✨ Improve SFT doc by @qgallouedec in #4005
- 💬 Remove setting chat template in sft script by @qgallouedec in #4037
- 🪪 Update SFTTrainer to handle labels correctly and add configuration example in paper index by @qgallouedec in #4051
- 🗜 Hotfix: avoid passing `quantization_config=None` by @qgallouedec in #4019
- Release: 0.23 by @qgallouedec in #4053
New Contributors
- @Peter-Chou made their first contribution in #3992
- @ginkyenglee made their first contribution in #3987
- @albertvillanova made their first contribution in #3993
- @hjh0119 made their first contribution in #3998
- @vaelev made their first contribution in #3783
- @dwisdom0 made their first contribution in #4020
Full Changelog: v0.22.0...v0.23.0
v0.22.2
What's Changed
- ⚖️ Fix scale_rewards issue in GRPO by @Peter-Chou in #3992
- ⏰ fix: add return to shift_tokens_right by @ginkyenglee in #3987
- ✖️ Support pad-to-multiple-of and padding-free by @qgallouedec in #3996
New Contributors
- @Peter-Chou made their first contribution in #3992
Full Changelog: v0.22.1...v0.22.2
v0.22.1
What changed
- Refactor version retrieval to use `importlib.metadata` by @qgallouedec
- Release: 0.22.1 by @qgallouedec
Full Changelog: v0.22.0...v0.22.1
v0.22.0
Major
🔮 Native VLM support for SFTTrainer
SFTTrainer now natively supports Vision-Language Models (VLMs). This includes support for both language modeling and prompt-completion data, as well as completion-only training.
```python
from trl import SFTConfig, SFTTrainer
from datasets import load_dataset

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-VL-3B-Instruct",
    args=SFTConfig(max_length=None),
    train_dataset=load_dataset("trl-lib/llava-instruct-mix", split="train"),
)
trainer.train()
```

by @qgallouedec in #3862, #3907 and #3908
🔥 RLOOTrainer refactor
RLOOTrainer has been refactored to align with the design principles of the other trainers in the library. You can now use this trainer exactly like GRPO.
```python
from datasets import load_dataset
from trl import RLOOConfig, RLOOTrainer

dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")

# Dummy reward function for demonstration purposes
def reward_num_unique_letters(completions, **kwargs):
    """Reward function that rewards completions with more unique letters."""
    completion_contents = [completion[0]["content"] for completion in completions]
    return [float(len(set(content))) for content in completion_contents]

trainer = RLOOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_num_unique_letters,
    train_dataset=dataset,
)
trainer.train()
```

by @shirinyamani in #3801
🧭 HF jobs x TRL guide
You can now leverage Hugging Face Jobs to easily train and deploy your models with TRL.

```sh
hf jobs uv run --flavor a100-large --secrets HF_TOKEN "https://raw.githubusercontent.com/huggingface/trl/main/trl/scripts/sft.py" --model_name_or_path Qwen/Qwen2-0.5B --dataset_name trl-lib/Capybara
```

A guide is available in the docs.
by @sergiopaniego in #3890
🏌️ DAPO loss type
GRPOTrainer now supports the DAPO loss type, which aggregates token-level losses by normalizing with the number of active tokens in the global accumulated batch. This method was introduced to eliminate length bias. Simply use:
```python
from trl import GRPOConfig, GRPOTrainer

training_args = GRPOConfig(
    loss_type="dapo",
    ...
)
```

by @qgallouedec in #3938
🪶 [GRPO] PPO Lite: Scale rewards by Std of Batch
The authors of Part I: Tricks or Traps? A Deep Dive into RL for LLM Reasoning (Lite PPO) find that the combination of:
- scaling rewards by the standard deviation computed over the entire batch and
- aggregating loss over the total number of tokens
can unlock the learning capability of critic-free policies using vanilla PPO loss. Their results demonstrate that this simple combination consistently improves performance, surpassing strategies like GRPO and DAPO.
TRL supports applying these findings when training a GRPO model:
```python
from trl import GRPOConfig

training_args = GRPOConfig(
    scale_rewards="batch",
    loss_type="dapo",
    ...
)
```

by @pramodith in #3935
🎢 [Callbacks] BEMA
Bias-Corrected Exponential Moving Average (BEMA) improves the stability and efficiency of language model fine-tuning by reducing stochasticity and eliminating bias. To use BEMA with SFT as described in the paper, you can now use the `BEMACallback`:
```python
from trl import BEMACallback, SFTTrainer

trainer = SFTTrainer(
    ...,
    callbacks=[BEMACallback()],
)
```

Minor
- 🎀 New defaults: `gradient_checkpointing=True` by @qgallouedec in #3510
- 🎚️ Add dataset mixer by @lewtun in #3791
- 💇 Add soft overlong punishment reward function and update documentation by @qgallouedec in #3804
- 🗿 [CPO] Add AlphaPO method via CPOTrainer by @kashif in #3824
- 🗳️ Extend BCO Trainer dataset format support by @reihig-ut in #3134
- 🐯 Support assistant-only training and Liger by @qgallouedec in #3914
- 🎆 Add entropy logging in SFT by @qgallouedec in #3940
- 📸 Return `position_ids` for `flash_attention_3` by @jue-jue-zi in #3942
Deprecations
- 🗑️ Deprecate `setup_chat_format` by @qgallouedec in #3929
- 🗑 Deprecate `IterativeSFTTrainer` by @qgallouedec in #3905
What's Changed
- ⬆️ Bump dev version by @qgallouedec in #3850
- 🔗 Fix collection link in doc by @qgallouedec in #3852
- Typo fix in new model description by @sergiopaniego in #3854
- Small style fix in README by @qgallouedec in #3861
- [GRPO] 👁️ Fix vLLM server mode for VLM GRPO training incompatibility for certain AutoProcessors by @ghubnerr in #3832
- 👁️ From `AutoModelForVision2Seq` to `AutoModelForImageTextToText` by @qgallouedec in #3836
- 👋 Remove `--bf16` value in scripts by @sergiopaniego in #3869
- 🎀 New defaults: `gradient_checkpointing=True` by @qgallouedec in #3510
- 🦦 Validate `vllm_mode` param in GRPO by @sergiopaniego in #3866
- 🎚️ Add dataset mixer by @lewtun in #3791
- ✨ Integrate PEFT model preparation across trainers and utilities by @qgallouedec in #3882
- ⌨️ Add py.typed by @cyyever in #3841
- 💇 Add soft overlong punishment reward function and update documentation by @qgallouedec in #3804
- 🕹️ [GRPO] Fix vllm mode validation in distributed setting by @Kirill-Kravtsov in #3886
- ⏳ Replaced `unittest.TestCase` with `TrlTestCase` that handles tmp dir by @qgallouedec in #3863
- 🔮 Native VLM support for `SFTTrainer` by @qgallouedec in #3862
- Minor optimizations in SFT. by @pramodith in #3884
- 🧩 Fix reward_processing_classes validation in GRPOTrainer by @chi2liu in #3876
- 🎢 [Callbacks] BEMA by @kashif in #3855
- 👁️ VLM blog by @qgallouedec in #3899
- 🪄 Improve quickstart documentation with updated API examples by @behroozazarkhalili in #3873
- 👔 HF Doc Builder style by @qgallouedec in #3498
- ✏️ Fix SFTTrainer token accuracy computation with PromptEncoder by @zk-quantum in #3821
- ☑️ Check eval batch size in grpo by @jp1924 in #3889
- ⚔️ Optimize truncate_with_protected_tokens to use vectorized operations by @chi2liu in #3875
- Add tests for get_position_ids_from_packed_seq_lengths by @pramodith in #3883
- 🌳 Enhance segment tree implementation for non-power-of-2 values by @MengAiDev in #3888
- ⚡ Optimize completion_ids list conversion in GRPO trainer by @chi2liu in #3874
- 🗿 [CPO] Add AlphaPO method via CPOTrainer by @kashif in #3824
- 🗳️ Extend BCO Trainer dataset format support by @reihig-ut in #3134
- 🐯 Support assistant-only training and Liger by @qgallouedec in #3914
- 🗑 Deprecate `IterativeSFTTrainer` by @qgallouedec in #3905
- ♻️ `use_cache` should be set in the forward pass by @qgallouedec in #3891
- 🌓 SFTTrainer for VLM: Support for prompt-completion data by @qgallouedec in #3907
- ➡️ SFTTrainer for VLM: support completion-only loss by @qgallouedec in #3908
- 📚 Update BEMACallback documentation to ignore docstyle and fix lag parameter description by @qgallouedec in #3917
- ✏️ Fix typos by @cyyever in #3921
- 🧹 Clean SFT tests by @qgallouedec in #3922
- 🤹♂️ Multi-image testing dataset by @qgallouedec in #3916
- 🧾 Use `logger.warning` instead of `warnings.warn` by @qgallouedec in #3923
- ♻️ Reuse multimodal message preparation from `SFTTrainer` in `GRPOTrainer` by @MengAiDev in #3919
- 🗑️ Deprecate `setup_chat_format` by @qgallouedec in #3929
- 🗞 bugfix 'TrainerState' object is not subscriptable by @ErezYosef in https://github.com/huggingf...