
[GKD] Buffer Implementation for Distillation Trainer #5137

Open
cmpatino wants to merge 36 commits into huggingface:main from cmpatino:kd-buffering

Conversation

cmpatino (Collaborator) commented Feb 20, 2026

Implement Buffer for Distillation Trainer (GOLDTrainer)

Implement generation buffering and multi-generation support for GOLDTrainer

Add a prompt-level generation buffer that decouples generation from the optimization steps. We adopt a buffering strategy similar to GRPO's: all rollouts for all mini-batches within an optimization step are generated up front, leveraging parallel inference engines. Each worker therefore handles a buffer of per_device_train_batch_size * gradient_accumulation_steps samples.

Buffer Details

We allow multiple rollouts per prompt, following Thinking Machines' Tinker example. The number of rollouts per prompt is controlled by the num_generations parameter. To keep the effective batch size constant, we introduce the generation_batch_size parameter, which controls how many unique prompts are passed to the inference engine. We enforce generation_batch_size = per_device_train_batch_size * gradient_accumulation_steps // num_generations so that the effective batch size is invariant across setups.
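The constraint above can be sketched as a small helper (hypothetical code, not the PR's actual implementation; parameter names follow the description above):

```python
def generation_batch_size(per_device_train_batch_size: int,
                          gradient_accumulation_steps: int,
                          num_generations: int) -> int:
    """Number of unique prompts sent to the inference engine per worker
    per optimizer step, chosen so the effective batch size stays constant."""
    buffer_size = per_device_train_batch_size * gradient_accumulation_steps
    if buffer_size % num_generations != 0:
        raise ValueError(
            "per_device_train_batch_size * gradient_accumulation_steps "
            "must be divisible by num_generations"
        )
    return buffer_size // num_generations

# Example: batch 4, grad accumulation 8, 2 rollouts per prompt
# -> 16 unique prompts per worker, 32 rollouts per optimizer step.
print(generation_batch_size(4, 8, 2))  # 16
```

Note that the divisibility check mirrors the "strict validation" mentioned later: any combination that would break the invariance is rejected up front.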

Benchmarks

We can replicate Thinking Machines' results using both the non-Liger and Liger losses, achieving a ~3x speedup on a setup with 8 training nodes in colocate mode.

[Figure: Tinker vs. TRL (Liger) phase timing comparison]
Phase    | Tinker (s) | TRL (s)
---------|------------|--------
Sampling | 329.83     | 130
Loss     | 37.96      | -
Training | 98.69      | 38
Total    | 492.28     | 173

Before submitting

  • Did you read the contributor guideline,
    Pull Request section?
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.


Note

Medium Risk
Touches core training-loop mechanics (dataloader sampling, buffering, and generation integration), which can affect correctness and throughput across distributed/grad-accumulation setups despite added validation and tests.

Overview
Implements prompt-level rollout buffering in GOLDTrainer so on-policy generations are produced once per optimizer window and reused across gradient_accumulation_steps, including a custom get_train_dataloader() + RepeatSampler strategy and new buffer management (_fill_buffer, slice selection, and logging).
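The RepeatSampler idea mentioned above can be illustrated with a minimal, framework-free sketch (hypothetical; the PR's actual sampler integrates with the PyTorch dataloader and the distributed setup):

```python
class RepeatSampler:
    """Yield each dataset index `num_repeats` times in a row, so that a
    prompt's buffered rollouts can be consumed across consecutive
    gradient-accumulation steps without regenerating them."""

    def __init__(self, dataset_len: int, num_repeats: int):
        self.dataset_len = dataset_len
        self.num_repeats = num_repeats

    def __iter__(self):
        for idx in range(self.dataset_len):
            for _ in range(self.num_repeats):
                yield idx

    def __len__(self):
        return self.dataset_len * self.num_repeats

# 3 prompts, 2 rollouts each -> each index appears twice, back to back.
print(list(RepeatSampler(3, 2)))  # [0, 0, 1, 1, 2, 2]
```

Repeating indices at the sampler level is what lets _fill_buffer generate once per optimizer window and then serve slices of the buffer to each micro-batch.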

Adds multi-generation support via new GOLDConfig knobs (num_generations, generation_batch_size) with strict validation, updates vLLM generation paths to handle n>1 (with prompt deduplication), and refactors completion processing to rebuild sequences/labels consistently (shared _build_sequence_batch).
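The prompt-deduplication step for n>1 can be sketched as follows (hypothetical helper, not the PR's code): collapse repeated prompts before calling the engine with n=num_generations, and keep a mapping back to the original order so completions can be reassigned.

```python
def dedup_prompts(prompts: list[str]) -> tuple[list[str], list[int]]:
    """Collapse repeated prompts to a unique list, recording for each
    original position the index of the unique prompt it came from."""
    unique: list[str] = []
    index_of: dict[str, int] = {}
    mapping: list[int] = []
    for p in prompts:
        if p not in index_of:
            index_of[p] = len(unique)
            unique.append(p)
        mapping.append(index_of[p])
    return unique, mapping

# A buffer with num_generations=2 repeats each prompt twice; only the
# unique prompts go to the inference engine.
unique, mapping = dedup_prompts(["a", "a", "b", "b"])
print(unique)   # ['a', 'b']
print(mapping)  # [0, 0, 1, 1]
```

With the engine returning n completions per unique prompt, the mapping is enough to hand each buffered slot its own rollout.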

Also tightens model/teacher revision handling (teacher_model_revision, revised init kwargs merging), adjusts Liger fused loss invocation (explicit hard/soft weights + ZeRO-3 gather context), updates docs/examples, and adds a targeted unit test ensuring prompt retokenization uses left padding when stitching vLLM completions.

Written by Cursor Bugbot for commit 1cbdc32.

@cmpatino cmpatino requested a review from qgallouedec March 3, 2026 21:24
@cmpatino cmpatino marked this pull request as ready for review March 3, 2026 21:24
Copilot AI (Contributor) left a comment

Pull request overview

Implements prompt-level rollout buffering and multi-generation support for GOLDTrainer, decoupling generation from optimization to improve throughput (similar to GRPO-style buffering).

Changes:

  • Add buffered generation across gradient-accumulation windows, including multi-generation per prompt and vLLM dedup/remapping logic.
  • Introduce new config knobs (num_generations, generation_batch_size) with validation and updated revision handling (student_model_revision vs model_revision).
  • Update docs and the example training script to reflect the new configuration behavior.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 5 comments.

File | Description
-----|------------
trl/experimental/gold/gold_trainer.py | Adds buffered dataloader strategy + vLLM multi-generation processing and related training-step changes.
trl/experimental/gold/gold_config.py | Adds num_generations / generation_batch_size and validates optimizer-window batch partitioning.
trl/experimental/gold/gold.py | Aligns model revision handling and teacher init kwargs; updates example wiring.
docs/source/gold_trainer.md | Documents new buffering knobs, revision behavior, and last-batch drop warning.



lewtun commented Mar 4, 2026

@codex review

chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 7e9cb5eb56


@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

cursor bot left a comment

Cursor Bugbot has reviewed your changes and found 2 potential issues.


gen_prompts = all_prompts_text
gen_n = 1

completion_ids = self.vllm_client.generate(
A reviewer (Member) commented:

Ideally we would like to use trl.generation.VLLMGeneration, but it can be done in a future PR if it's too hard.

@qgallouedec
(Member)

I haven't reviewed it in detail; I have a general idea of what it's about, but I'm leaving the implementation mostly up to you. In future PRs, we can try to align it better with the rest of the codebase, but what matters most right now are the results you're getting.

Make sure to run `make precommit` to make the CI happy.

sampling ratio.
* `num_generations`, `generation_batch_size` – control buffered rollout generation across gradient accumulation windows.
`generation_batch_size` is the number of unique prompts per worker per optimizer step.
* `student_model_revision` and `model_revision` – if `student_model_revision` is unset, GOLD uses `model_revision`.
A reviewer (Member) commented:

`student_model_revision` is removed in this PR, no?

@@ -365,12 +382,6 @@ class GOLDConfig(SFTConfig):
num_completions_to_print: int = field(default=5, metadata={"help": "Number of completions to print."})
A reviewer (Member) commented:

Suggested change
num_completions_to_print: int = field(default=5, metadata={"help": "Number of completions to print."})

duplicated

qgallouedec (Member) left a comment

Just ensure the CI is green before merging.
