Add conditional distillation support for GKD trainer #6542
Conversation
Summary of Changes

Hello @woshixiaobai2019, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances the GKD (Generative Knowledge Distillation) trainer by introducing conditional distillation. This new capability allows for more flexible knowledge transfer by enabling the teacher and student models to process different input contexts or prompts while distilling knowledge. The feature is implemented through an extensible TeacherAdapter plugin system.
Code Review
This PR successfully adds conditional distillation to the GKD trainer, which is a very valuable new feature. The overall implementation is cleanly structured; in particular, the plugin-based TeacherAdapter design makes it easy to extend the teacher model's context-transformation logic, which is commendable. The core logic in gkd_trainer.py for handling different prompts and aligning logits when computing the loss is complex, but the implementation appears correct. The documentation updates are clear and help users understand and use the new feature. The PR also includes extensive global linting fixes for f-string quote style, which helps keep the code style consistent.
```python
if teacher_history and teacher_history[0]['role'] == 'system':
    teacher_history[0]['content'] += '\n\nYou are an expert with extensive knowledge.'
else:
    teacher_history.insert(0, {
        'role': 'system',
        'content': 'You are an expert with extensive knowledge.'
    })
```
The example code for MyTeacherAdapter in the documentation uses teacher_history[0]['content'] += ... to modify the system prompt. Since history.copy() performs a shallow copy, this approach will unintentionally modify the content of the original history object. While this might not cause issues in the current GKD workflow, it is a risky practice. It's recommended to update the example to a safer implementation by creating a new dictionary to replace teacher_history[0], similar to how MathTeacherAdapter is implemented in swift/plugin/teacher_adapter.py, to prevent potential side effects.
Suggested change:

```python
teacher_history = history.copy()
if teacher_history and teacher_history[0]['role'] == 'system':
    # More robust way: create a new dict to avoid side effects
    teacher_history[0] = {
        'role': 'system',
        'content': teacher_history[0]['content'] + '\n\nYou are an expert with extensive knowledge.'
    }
else:
    teacher_history.insert(0, {
        'role': 'system',
        'content': 'You are an expert with extensive knowledge.'
    })
```
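The shallow-copy hazard flagged in this review can be demonstrated in a few lines of plain Python (a standalone illustration, not code from the PR):

```python
# list.copy() duplicates the list itself, but the dicts inside it are shared.
history = [{'role': 'system', 'content': 'base prompt'}]
teacher_history = history.copy()

# In-place mutation of the copied list's first element...
teacher_history[0]['content'] += ' (teacher extras)'

# ...also changes the original history, because both lists hold the same dict.
print(history[0]['content'])  # prints "base prompt (teacher extras)"

# Replacing the element with a new dict leaves the original untouched.
history2 = [{'role': 'system', 'content': 'base prompt'}]
teacher_history2 = history2.copy()
teacher_history2[0] = {
    'role': 'system',
    'content': history2[0]['content'] + ' (teacher extras)',
}
print(history2[0]['content'])  # prints "base prompt"
```

This is exactly why the suggested change builds a new dict instead of using `+=` on the shared one.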
```python
if 'teacher_prompt_attention_mask' in inputs:
    teacher_prompts_len = inputs['teacher_prompt_attention_mask'].sum(dim=1)  # [batch_size]
else:
    teacher_prompts_len = torch.full((inputs['teacher_prompts'].shape[0],),
                                     inputs['teacher_prompts'].shape[1],
                                     device=inputs['teacher_prompts'].device)
```
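The bookkeeping above can be sketched in plain Python (lists instead of tensors; the token values are made up for illustration): summing each row of the attention mask recovers the unpadded prompt length, and slicing each sequence at its own prompt length leaves identical response spans to align between student and teacher.

```python
# Each row: 1 = real token, 0 = padding. Row sums give true prompt lengths.
teacher_mask = [
    [1, 1, 1, 1, 1],  # teacher prompt of length 5
    [1, 1, 1, 0, 0],  # teacher prompt of length 3, padded to 5
]
teacher_prompt_lens = [sum(row) for row in teacher_mask]  # [5, 3]

# Different prompts, shared response: slicing at each prompt's length
# yields the same response tokens for both models.
student_ids = [11, 12, 13, 7, 8, 9]          # student prompt (len 3) + response
teacher_ids = [11, 12, 13, 14, 15, 7, 8, 9]  # teacher prompt (len 5) + same response
response_student = student_ids[3:]
response_teacher = teacher_ids[teacher_prompt_lens[0]:]
print(response_student == response_teacher)  # prints True
```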
- Change shape_context() to accept the complete data dict instead of just messages
- Allow the adapter to access all fields (dataset, images, etc.) for flexible usage
- Follow GRPO's reward_model_plugin design pattern
PR type
PR information
Conditional Distillation

Conditional distillation allows the teacher and student models to be trained with different contexts or prompts, enabling more flexible knowledge-transfer strategies.
TeacherAdapter Plugin System

By implementing the TeacherAdapter interface, you can customize the context-transformation logic for the teacher model.

Built-in Adapters

SWIFT provides two built-in teacher adapters: default and example.

Usage
```shell
swift rlhf \
    --rlhf_type gkd \
    --model Qwen/Qwen2.5-0.5B-Instruct \
    --teacher_model Qwen/Qwen2.5-7B-Instruct \
    --teacher_adapter example \
    --dataset your_dataset.jsonl \
    ...
```

How It Works
In conditional distillation:

- the student model sees [prompt_student] + [response]
- the teacher model sees [prompt_teacher] + [response]

Here prompt_teacher is produced from prompt_student by teacher_adapter.shape_context(), while the response part remains unchanged. A reference training script can be found here.
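As a rough illustration of the flow described above, a custom adapter might look like the following. This is a hypothetical sketch: the class name, the absence of a base class, and the exact shape_context() signature are assumptions based on this PR's description (shape_context() receiving the complete data dict), not the final ms-swift API.

```python
class DomainExpertTeacherAdapter:
    """Hypothetical adapter that prepends a domain-expert system prompt
    for the teacher model while leaving the student's input untouched."""

    SYSTEM_SUFFIX = 'You are an expert with extensive knowledge.'

    def shape_context(self, data: dict) -> dict:
        # Per the PR, shape_context() receives the complete data dict
        # (messages, dataset, images, ...), not just the messages.
        data = dict(data)
        # Copy each message dict so in-place edits cannot leak back into
        # the original history (avoids the shallow-copy issue above).
        teacher_history = [dict(m) for m in data['messages']]
        if teacher_history and teacher_history[0]['role'] == 'system':
            teacher_history[0]['content'] += '\n\n' + self.SYSTEM_SUFFIX
        else:
            teacher_history.insert(0, {'role': 'system',
                                       'content': self.SYSTEM_SUFFIX})
        data['messages'] = teacher_history
        return data
```

Because every message dict is copied before mutation, the in-place `+=` here is safe, unlike the shallow `history.copy()` pattern flagged in the review.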
Experimental Results