Upload testing suite for DistillationTrainer #5615
Conversation
qgallouedec
left a comment
Thanks for adding tests to `DistillationTrainer`; it had none, so this is directionally welcome. A few things worth addressing before merge:

- Heavy use of `DistillationTrainer.__new__(...)` + manual attribute assignment, plus five `Dummy*`/custom mock classes. Every test assembles a fake trainer by bypassing `__init__` and setting ~10 private attributes inline. This couples every test to internal attribute names: any rename or new attribute used in `compute_loss` silently breaks the suite. It also cuts against two principles: consistency (the rest of `tests/experimental/`, e.g. `test_gkd_trainer.py`, loads a tiny real model and exercises the real `__init__`) and simplicity (the mock scaffolding and duplicated attribute setup in `test_server_teacher_path_handles_variable_prompt_lengths`/`..._padded_completions` is exactly that). A tiny-model fixture would remove most of it.
- No end-to-end test. One short `trainer.train()` on a tiny student+teacher would catch far more than the current mocked suite combined, and would make most of the attribute-juggling tests redundant. It would also align better with the principle of testing behavior over implementation.
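The brittleness described above can be illustrated with a toy class (a hypothetical stand-in, not the real trainer): constructing via `__new__` and poking private attributes ties the test to internal names, while going through the real `__init__` does not.

```python
# Hedged illustration (toy class, not the real DistillationTrainer) of why
# `__new__` + manual attribute pokes is brittle compared to the real `__init__`.
class Trainer:
    def __init__(self, beta):
        self._beta = beta  # private name the class is free to rename

    def compute_loss(self, x):
        return self._beta * x

# Brittle: bypasses __init__ and hardcodes the private attribute name;
# this setup breaks the moment `_beta` is renamed internally.
t = Trainer.__new__(Trainer)
t._beta = 0.5
assert t.compute_loss(2.0) == 1.0

# Robust: exercise the real constructor; internals stay encapsulated.
t2 = Trainer(beta=0.5)
print(t2.compute_loss(2.0))
```

The second form survives any internal refactor that keeps the constructor signature stable, which is exactly the property the review is asking the suite to have.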
> torch.testing.assert_close(local_loss, server_loss)

> def test_sampled_mode_keeps_teacher_argmax_for_forward_support():
This is tautological. The expected value is computed by calling the same private helper `_jsd_divergence`, with the same support/mask construction that `compute_loss` uses internally, so it tests wiring, not semantics. Any bug inside `_jsd_divergence` is invisible. Derive the expected value from first principles (the plain JSD formula) or drop it.
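A first-principles expected value can be computed without touching any trainer helper. The sketch below assumes one common convention for the generalized JSD (the GKD-style formulation, with mixture weight `beta`); `kl` and `generalized_jsd` are hypothetical names for this illustration:

```python
import math

# Generalized JSD from the plain formula, independent of the code under test:
# JSD_beta(p, q) = beta * KL(p || m) + (1 - beta) * KL(q || m),
# with mixture m = beta * p + (1 - beta) * q (assumed convention).
def kl(p, q):
    # Discrete KL divergence; assumes p and q are probability vectors.
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def generalized_jsd(p, q, beta=0.5):
    m = [beta * pi + (1 - beta) * qi for pi, qi in zip(p, q)]
    return beta * kl(p, m) + (1 - beta) * kl(q, m)

p = [0.7, 0.2, 0.1]
q = [0.1, 0.2, 0.7]
print(generalized_jsd(p, p))                            # identity case: 0.0
print(generalized_jsd(p, q) == generalized_jsd(q, p))   # symmetric at beta=0.5
```

Comparing the trainer's output against a hand-rolled reference like this catches bugs inside the helper itself, which the wiring-only comparison cannot.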
> report_to="none",
> )

> assert caught == []
Fragile: any unrelated `DeprecationWarning` from a dependency will break it. The `pytest.raises` is already the real assertion.
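If a warning assertion is kept at all, filtering the recorded warnings to the category under test avoids the fragility. A minimal sketch (the `dependency_call` function is a hypothetical stand-in for a third-party library):

```python
import warnings

def dependency_call():
    # Stand-in for an unrelated dependency emitting its own warning.
    warnings.warn("old dependency API", DeprecationWarning)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    dependency_call()  # the code under test happens to touch the dependency

# Brittle form (would fail here): assert caught == []
# Robust form: count only the warning category the test is about.
relevant = [w for w in caught if issubclass(w.category, FutureWarning)]
assert relevant == []
print(len(caught), len(relevant))  # → 1 0
```

The brittle form fails on the dependency's noise; the filtered form only trips when the specific warning category the test cares about actually fires.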
Thank you for the feedback! I addressed your comments with the following changes:
- Followed GKD's tests and used a small model and dataset to test the trainer.
- Improved the Liger tests by following the testing approach from the SFTTrainer.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Cursor Bugbot has reviewed your changes and found 1 potential issue.
Reviewed by Cursor Bugbot for commit 1033b7c.
What does this PR do?
Uploads tests for the `DistillationTrainer`.

Before submitting
AI writing disclosure
We welcome the use of AI tools to help with contributions. For transparency and to help us improve our review process, please indicate the level of AI involvement in this PR.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
Note
Low Risk
Test-only changes; primary risk is increased CI runtime/flakiness due to added integration-style training tests and dataset/model downloads.
Overview
Adds a much broader test suite for `DistillationTrainer`, covering config validation, teacher-server request shaping via `build_teacher_request_inputs`, and correct handling of ragged/padded teacher logprob responses (including `-inf` padding behavior and finite gradients through reverse-KL/JSD paths).

Introduces unit tests for `DistillationTrainer.generalized_jsd_loss` (beta/temperature/reduction/symmetry/identity cases), a lightweight end-to-end `train()` smoke test (including checkpoint save), an optional Liger-kernel training test, parity checks between local-teacher and vLLM-teacher-server loss computation (mocked `VLLMClient`), and a regression test that `_RepeatBatchDataLoader` forwards `set_epoch` correctly.

Reviewed by Cursor Bugbot for commit b3340a2.
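The `set_epoch`-forwarding regression mentioned above follows a simple pattern. `_RepeatBatchDataLoader` is internal to TRL, so the classes below are minimal hypothetical stand-ins that only illustrate why the wrapper must delegate `set_epoch` to the inner loader (distributed samplers rely on `set_epoch` to reshuffle between epochs):

```python
# Hedged stand-ins, not TRL's real classes: the point is only the delegation.
class InnerLoader:
    def __init__(self):
        self.epochs_seen = []

    def set_epoch(self, epoch):
        self.epochs_seen.append(epoch)

class RepeatBatchDataLoader:
    def __init__(self, loader, repeats=2):
        self.loader = loader
        self.repeats = repeats

    def set_epoch(self, epoch):
        # The forwarding under test: without it, the epoch never reaches
        # the sampler and shuffling is identical every epoch.
        self.loader.set_epoch(epoch)

inner = InnerLoader()
wrapper = RepeatBatchDataLoader(inner)
for epoch in range(3):
    wrapper.set_epoch(epoch)
print(inner.epochs_seen)  # → [0, 1, 2]
```

A regression test for this pattern just asserts that every epoch passed to the wrapper shows up on the inner loader, which is cheap and does not require any model or dataset.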
DistillationTrainer.generalized_jsd_loss(beta/temperature/reduction/symmetry/identity cases), a lightweight end-to-endtrain()smoke test (including checkpoint save), an optional Liger-kernel training test, parity checks between local-teacher vs vLLM-teacher-server loss computation (mockedVLLMClient), and a regression test that_RepeatBatchDataLoaderforwardsset_epochcorrectly.Reviewed by Cursor Bugbot for commit b3340a2. Bugbot is set up for automated code reviews on this repo. Configure here.