[GKD] Buffer Implementation for Distillation Trainer #5137
cmpatino wants to merge 36 commits into huggingface:main from
Conversation
Avoid crashing when using DeepSpeed ZeRO-3 and set up the correct values for `weight_hard_loss` and `weight_soft_loss`
KD Buffer Simplification
Add scripts to run GOLD
Pull request overview
Implements prompt-level rollout buffering and multi-generation support for GOLDTrainer, decoupling generation from optimization to improve throughput (similar to GRPO-style buffering).
Changes:
- Add buffered generation across gradient-accumulation windows, including multi-generation per prompt and vLLM dedup/remapping logic.
- Introduce new config knobs (`num_generations`, `generation_batch_size`) with validation and updated revision handling (`student_model_revision` vs `model_revision`).
- Update docs and the example training script to reflect the new configuration behavior.
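The buffered-generation flow these changes describe can be sketched roughly as follows. This is a simplified illustration with hypothetical function names, not the trainer's actual internals (the real implementation wires this through `get_train_dataloader()` and vLLM):

```python
def train_with_buffer(prompts, per_device_batch, grad_accum_steps, generate_fn, optimize_fn):
    """Sketch: generate rollouts once per optimizer window, then reuse them
    across gradient-accumulation micro-steps instead of generating per micro-batch."""
    buffer_size = per_device_batch * grad_accum_steps  # one optimizer window per worker
    for start in range(0, len(prompts), buffer_size):
        window = prompts[start:start + buffer_size]
        rollouts = generate_fn(window)            # single generation pass for the window
        for step in range(grad_accum_steps):      # reuse the buffer across micro-batches
            lo = step * per_device_batch
            optimize_fn(rollouts[lo:lo + per_device_batch])
```

The point of the pattern is that `generate_fn` runs once per optimizer window rather than once per micro-batch, which is what lets parallel inference engines be used efficiently.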
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| `trl/experimental/gold/gold_trainer.py` | Adds buffered dataloader strategy + vLLM multi-generation processing and related training-step changes. |
| `trl/experimental/gold/gold_config.py` | Adds `num_generations` / `generation_batch_size` and validates optimizer-window batch partitioning. |
| `trl/experimental/gold/gold.py` | Aligns model revision handling and teacher init kwargs; updates example wiring. |
| `docs/source/gold_trainer.md` | Documents new buffering knobs, revision behavior, and last-batch drop warning. |
@codex review
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 7e9cb5eb56
Cursor Bugbot has reviewed your changes and found 2 potential issues.
```python
gen_prompts = all_prompts_text
gen_n = 1

completion_ids = self.vllm_client.generate(
```
Ideally we would like to use `trl.generation.VLLMGeneration`, but that can be done in a future PR if it's too hard.
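The dedup/remapping logic the PR describes for the vLLM path (deduplicate repeated prompts, generate `n` completions per unique prompt, map results back to the original order) can be sketched like this. Function names here are hypothetical, and the sketch assumes each prompt appears exactly `num_generations` times in the buffer:

```python
def generate_with_dedup(prompts, num_generations, vllm_generate):
    """Sketch: call the engine once per unique prompt with n=num_generations,
    then remap completions back onto the (repeated) original prompt list."""
    unique = list(dict.fromkeys(prompts))  # order-preserving deduplication
    # flat list: num_generations completions per unique prompt, in order
    completions = vllm_generate(unique, n=num_generations)
    by_prompt = {
        p: completions[i * num_generations:(i + 1) * num_generations]
        for i, p in enumerate(unique)
    }
    # give each repeated occurrence of a prompt its own distinct generation
    counters = {p: 0 for p in unique}
    out = []
    for p in prompts:
        out.append(by_prompt[p][counters[p]])
        counters[p] += 1
    return out
```

The dedup step matters because the buffer holds each prompt `num_generations` times; without it, the engine would redundantly regenerate for identical prompts instead of using its native `n > 1` sampling.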
I haven't reviewed it in detail; I have a general idea of what it's about, but I'm leaving the implementation mostly up to you. In future PRs we can try to align it better with the rest of the codebase, but what matters most right now are the results you're getting. Make sure to run
sampling ratio.
* `num_generations`, `generation_batch_size` – control buffered rollout generation across gradient accumulation windows. `generation_batch_size` is the number of unique prompts per worker per optimizer step.
* `student_model_revision` and `model_revision` – if `student_model_revision` is unset, GOLD uses `model_revision`.
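The validation the config performs on these knobs can be sketched as a standalone check. The function name is hypothetical; the field names mirror the documented config options:

```python
def check_partition(per_device_train_batch_size, gradient_accumulation_steps, num_generations):
    """Sketch: derive generation_batch_size from the optimizer-window size and
    validate that num_generations partitions the window evenly."""
    window = per_device_train_batch_size * gradient_accumulation_steps
    if window % num_generations != 0:
        raise ValueError(
            f"num_generations={num_generations} must evenly divide the "
            f"optimizer window ({window} samples per worker)"
        )
    # unique prompts per worker per optimizer step
    return window // num_generations
```

For example, with `per_device_train_batch_size=4`, `gradient_accumulation_steps=4`, and `num_generations=4`, each worker generates for 4 unique prompts per optimizer step.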
`student_model_revision` is removed in this PR, no?
```diff
@@ -365,12 +382,6 @@ class GOLDConfig(SFTConfig):
     num_completions_to_print: int = field(default=5, metadata={"help": "Number of completions to print."})
```
```python
num_completions_to_print: int = field(default=5, metadata={"help": "Number of completions to print."})
```
duplicated
qgallouedec left a comment:
Just ensure the CI is green before merging.
Implement Buffer for Distillation Trainer (`GOLDTrainer`)
Implement generation buffering and multi-generation support for `GOLDTrainer`
Add a prompt-level generation buffer that decouples generation from the optimization steps. We adopt a buffer similar to GRPO's to generate all rollouts for all mini-batches within an optimization step, leveraging parallel inference engines. This means each worker handles a buffer of `per_device_train_batch_size * gradient_accumulation_steps` samples.

Buffer Details
We allow multiple rollouts per prompt, following Thinking Machines' Tinker example. The number of rollouts per prompt is determined by the `num_generations` parameter. To keep the effective batch size constant, we introduce the `generation_batch_size` parameter, which controls how many unique prompts we pass to the inference engine. We enforce `generation_batch_size = per_device_train_batch_size * gradient_accumulation_steps // num_generations` to ensure the effective batch size is invariant across setups.

Benchmarks
We can replicate Thinking Machines' results using both non-Liger and Liger losses, achieving a 3x speedup on a setup with 8 training nodes in colocate mode.
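As a concrete instance of the batch-size constraint above (illustrative numbers, not measurements from the PR):

```python
# The effective per-worker batch (unique prompts x generations per prompt)
# stays fixed as num_generations varies, because generation_batch_size is
# derived from the optimizer-window size.
per_device_train_batch_size = 4
gradient_accumulation_steps = 8
window = per_device_train_batch_size * gradient_accumulation_steps  # 32 samples per worker

for num_generations in (1, 2, 4, 8):
    generation_batch_size = window // num_generations  # unique prompts per optimizer step
    assert generation_batch_size * num_generations == window  # effective batch invariant
```

So raising `num_generations` trades prompt diversity for per-prompt rollout diversity while the optimizer sees the same number of samples per step.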
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
Note: Medium Risk
Touches core training-loop mechanics (dataloader sampling, buffering, and generation integration), which can affect correctness and throughput across distributed/grad-accumulation setups despite added validation and tests.
Overview
Implements prompt-level rollout buffering in `GOLDTrainer` so on-policy generations are produced once per optimizer window and reused across `gradient_accumulation_steps`, including a custom `get_train_dataloader()` + `RepeatSampler` strategy and new buffer management (`_fill_buffer`, slice selection, and logging).

Adds multi-generation support via new `GOLDConfig` knobs (`num_generations`, `generation_batch_size`) with strict validation, updates vLLM generation paths to handle `n > 1` (with prompt deduplication), and refactors completion processing to rebuild sequences/labels consistently (shared `_build_sequence_batch`).

Also tightens model/teacher revision handling (`teacher_model_revision`, revised init kwargs merging), adjusts Liger fused loss invocation (explicit hard/soft weights + ZeRO-3 gather context), updates docs/examples, and adds a targeted unit test ensuring prompt retokenization uses left padding when stitching vLLM completions.

Written by Cursor Bugbot for commit 1cbdc32.
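The left-padding behavior that the new unit test targets can be sketched as follows. This is a hypothetical helper, not the trainer's actual function: the idea is that variable-length prompts must be left-padded so each stitched completion starts immediately after its prompt's last real token.

```python
def stitch_left_padded(prompt_ids, completion_ids, pad_id):
    """Sketch: left-pad retokenized prompts to a common length, then append
    each prompt's completion so prompt-end and completion-start stay adjacent."""
    max_prompt = max(len(p) for p in prompt_ids)
    rows = []
    for p, c in zip(prompt_ids, completion_ids):
        # padding goes on the left; real prompt tokens end flush against the completion
        rows.append([pad_id] * (max_prompt - len(p)) + list(p) + list(c))
    return rows
```

With right padding, pad tokens would sit between the prompt and the stitched completion, corrupting the sequence the labels are built over; left padding avoids that.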