
Commit 7c6ea38

Merge pull request #19 from gyfffffff/main
Update the PPT; refine documentation and code details
2 parents 81699bd + 010145b commit 7c6ea38

18 files changed, +5075 −563 lines changed

.gitignore

Lines changed: 7 additions & 1 deletion
@@ -160,4 +160,10 @@ cython_debug/
 .idea/
 .DS_Store
 docs/chapter2/models/GPT-2/*
-MetaICL/
+MetaICL/
+data/
+output/
+outputs/
+
+docs/chapter2/code/BabyLlama/models/*
+*.zip

PPT/distillation.pptx

1.64 MB
Binary file not shown.

docs/chapter2/README.md

Lines changed: 2 additions & 4 deletions
@@ -1,6 +1,6 @@
 # 2.1 Distillation
 
-This chapter will introduce the mainstream distillation methods and code for Transformer-based models, and will also build an on-device deployment demo
+This chapter will introduce the mainstream distillation methods and code for large language models
 
 ## Roadmap
 ### 1. Distillation basics
@@ -18,7 +18,7 @@
 - 2.1 Overview
 - When to use white-box distillation
 - 2.2 MiniLLM
-- 2.3 GKD
+- 2.3 BabyLlama
 
 ### 3. Distillation of emergent abilities (black-box distillation)
 - 3.1 Overview
@@ -33,5 +33,3 @@
 
 
 ### 4. Summary
-- 4.1 Extensions on related frontier work
-- 4.2 Summary

docs/chapter2/chapter2_1.md

Lines changed: 4 additions & 0 deletions
@@ -54,6 +54,10 @@ https://github.com/datawhalechina/awesome-compression/blob/main/docs/ch06/ch06.m
 
 ![](images/Figure%206.png)
 
+# Prerequisites
+Familiarity with the following will help with the material that comes next:
+1. Logits and soft targets
+2. Supervised fine-tuning (SFT)
 
 
 References:
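
As a quick refresher on the first prerequisite, soft targets are simply a temperature-scaled softmax over the teacher's logits. Below is a minimal PyTorch sketch with made-up numbers; the vocabulary size, values, and temperature are illustrative assumptions, not part of the repository code:

```python
import torch
import torch.nn.functional as F

# Teacher logits for two positions over a toy 5-token vocabulary (values are invented).
teacher_logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, 0.0],
                               [0.2, 0.1, 3.0, -0.5, 1.5]])

temperature = 2.0  # higher temperature -> softer (more uniform) targets

# Soft targets: temperature-scaled softmax over the teacher's logits.
soft_targets = F.softmax(teacher_logits / temperature, dim=-1)
print(soft_targets)  # each row sums to 1 and is smoother than the T=1 softmax
```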

docs/chapter2/chapter2_2.md

Lines changed: 41 additions & 14 deletions
@@ -1,7 +1,7 @@
 # White-box distillation
 
 ## 1. What is white-box distillation
-For open-source large models, we can obtain all of the model's inference-time data, including the probability distribution over output tokens. A setting in which the token output distribution is available can be regarded as a "white-box" setting; otherwise it is a black-box setting. Distillation that uses the data a white box provides is white-box distillation.
+White-box distillation refers to distillation techniques that make use of the teacher model's parameters or logits during distillation [2].
 
 Next, we introduce classic white-box distillation methods and their code implementations.
 
@@ -47,18 +47,45 @@ The MiniLLM paper proposes another novel perspective: reverse KL can in fact be
 Since this part involves a fair amount of mathematical derivation and reinforcement learning, interested readers can study the paper on their own.
 
 # 3. BabyLlama (hands-on practice)
-[BabyLlama](http://arxiv.org/abs/2308.02019) treats distillation as an effective way to make more efficient use of training samples. As a hands-on code example, we will see that its distillation loss function uses the teacher model's soft labels.
-
-BabyLlama's code covers
-1. Data cleaning and tokenizer training
-2. Teacher model training
-3. Distilling the student model
-
-In practice, however, white-box distillation can also use off-the-shelf open-source models and tokenizers.
-
-
+[BabyLlama](http://arxiv.org/abs/2308.02019) applies classic small-model distillation directly to large language models. Its loss function is a weighted sum of the following two losses:
+- the cross-entropy with the hard labels (ground truth)
+- the KL divergence with the soft targets (the teacher's output distribution)
+
+Its loss function can be found in code/BabyLlama/3.distill.ipynb:
+```python
+def compute_loss(self, model, inputs, return_outputs=False):
+    # Hard loss: cross-entropy with the ground truth
+    outputs_student = model(**inputs)
+    student_loss = outputs_student.loss
+
+    # compute teacher output
+    with torch.no_grad():
+        all_teacher_logits = []
+        for teacher in self.teachers:
+            outputs_teacher = teacher(**inputs)
+            all_teacher_logits.append(outputs_teacher.logits)
+        avg_teacher_logits = torch.stack(all_teacher_logits).mean(dim=0)
+
+    # assert size
+    assert outputs_student.logits.size() == avg_teacher_logits.size()
+
+    # Soft loss: KL divergence with the teacher's output distribution
+    loss_function = nn.KLDivLoss(reduction="batchmean")
+    loss_logits = (
+        loss_function(
+            F.log_softmax(outputs_student.logits / self.args.temperature, dim=-1),
+            F.softmax(avg_teacher_logits / self.args.temperature, dim=-1),
+        )
+        * (self.args.temperature ** 2)
+    )
+    # Return weighted student loss
+    loss = self.args.alpha * student_loss + (1.0 - self.args.alpha) * loss_logits
+    return (loss, outputs_student) if return_outputs else loss
+```
 
 ## References
-- MiniLLM: Knowledge Distillation of Large Language Models
-- https://github.com/microsoft/LMOps/tree/main/minillm
-- https://blog.csdn.net/ningmengzhihe/article/details/130679350
+1. MiniLLM: Knowledge Distillation of Large Language Models
+2. Efficient Large Language Models: A Survey
+3. https://github.com/microsoft/LMOps/tree/main/minillm
+4. https://blog.csdn.net/ningmengzhihe/article/details/130679350
+5. Baby Llama: knowledge distillation from an ensemble of teachers trained on a small dataset with no performance penalty
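
To see how the weighted loss in the compute_loss method above behaves outside a Trainer (where self.args supplies alpha and temperature), here is a self-contained sketch with random tensors; the shapes, alpha, and temperature values are illustrative assumptions rather than the settings used in the notebook:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

batch, seq_len, vocab = 2, 4, 10
alpha, temperature = 0.5, 2.0

student_logits = torch.randn(batch, seq_len, vocab)
teacher_logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))

# Hard loss: cross-entropy between the student's logits and the ground-truth labels
hard_loss = F.cross_entropy(student_logits.view(-1, vocab), labels.view(-1))

# Soft loss: KL divergence between temperature-scaled student and teacher distributions,
# scaled by T^2 so gradient magnitudes stay comparable across temperatures
kl = nn.KLDivLoss(reduction="batchmean")
soft_loss = kl(
    F.log_softmax(student_logits / temperature, dim=-1),
    F.softmax(teacher_logits / temperature, dim=-1),
) * (temperature ** 2)

# Weighted sum, exactly as in the BabyLlama loss above
loss = alpha * hard_loss + (1.0 - alpha) * soft_loss
print(loss.item())
```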

docs/chapter2/chapter2_3.md

Lines changed: 10 additions & 10 deletions
@@ -1,4 +1,5 @@
-# Distillation of emergent abilities (black-box distillation)
+# Black-box distillation (skill distillation)
+Black-box distillation uses only the teacher model's responses (sometimes also the output probability distribution, i.e. the soft targets, but never the logits).
 
 Black-box distillation means that the teacher model's outputs are the only training resource we can obtain, so the overall approach can be divided into two steps:
 1. Collect question-answer data from the teacher model
@@ -17,9 +18,7 @@
 
 Here is a simple ICL example:
 
-<div align="center">
-<img src="images/image-1.png" alt="alt text" width="550"/>
-</div>
+![alt_text](images/image-1.png)
 
 The model successfully imitates the reasoning and answer format shown in the examples.
 
@@ -44,9 +43,9 @@ y_3
 That is, simply by prepending a few examples to the prompt, the model picks up their format and logic, and can thus learn without updating any parameters.
 
 Before training, we collect data containing prompts and labels as shown below:
-<div align="center">
-<img src="images/image-3.png" alt="alt text" width="750"/>
-</div>
+
+
+![alt_text](images/image-3.png)
 
 ## 1.2 ICL fine-tuning
 
@@ -132,7 +131,7 @@
 $$
 
 ## 1.5 Hands-on practice
-
+See code/ICL.
 
 ## 1.6 Directions for improvement
 The performance of in-context learning is closely tied to the quality of the in-context examples, so researchers have designed a dedicated example retriever to fetch high-quality demonstrations [6]
@@ -175,7 +174,7 @@
 
 
 ## 2.3 Instruction-following distillation: hands-on practice
-
+See code/InsructFollowing.
 
 ## 2.4 Adversarial distillation
 Adversarial distillation proposes that, besides letting knowledge flow one way from teacher into student, the student can also produce "feedback",
@@ -210,12 +209,13 @@
 ## Step 3: Hands-on practice
 The actual fine-tuning loss function is the cross-entropy loss.
 
+See code/CoT.
 
 
 
 # 4. Extensions and summary
 In fact, beyond the three kinds of emergent-ability distillation above, anything that collects some type of data from the teacher model and then fine-tunes the student model on that data falls within the scope of black-box distillation.
-Therefore, for tasks in specific domains or with specific requirements, a similar approach can be used to achieve the desired effect. For example, the recent [O1 replication paper](https://arxiv.org/pdf/2410.18982) from Shanghai Jiao Tong University is a good example of distilling a teacher model's reasoning ability.
+Therefore, for tasks in specific domains or with specific requirements, a similar approach can be used to achieve the desired effect. For example, with the release of OpenAI's powerful complex-reasoning model O1, distilling reasoning ability can also follow the method above. The recent [O1 replication paper](https://arxiv.org/pdf/2410.18982) from Shanghai Jiao Tong University is a good example of distilling a teacher model's reasoning ability.
 
 But some research [5] also points out that black-box distillation leads to imitation without understanding; to improve the quality of learning, the student also needs good innate ability (the capability of the base model).
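
The ICL hunks above hinge on the idea that a few demonstrations prepended to the prompt are enough for the model to pick up the format and logic without any parameter updates. Below is a minimal, self-contained sketch of assembling such a few-shot prompt; the task, examples, and labels are invented purely for illustration:

```python
# In-context learning: demonstrations go directly into the prompt,
# and the model is asked to continue the pattern for a new query.
demonstrations = [
    ("The movie was wonderful.", "positive"),
    ("I will never buy this again.", "negative"),
]
query = "The service was quick and friendly."

prompt = ""
for text, label in demonstrations:
    prompt += f"Review: {text}\nSentiment: {label}\n\n"
prompt += f"Review: {query}\nSentiment:"

print(prompt)  # send this prompt to the teacher/student model as-is; no parameters are updated
```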

docs/chapter2/chapter2_4.md

Lines changed: 5 additions & 0 deletions
@@ -1,2 +1,7 @@
 # Summary
 
+In this chapter, we studied the concept of LLM distillation, how it differs from traditional distillation, and the mainstream LLM distillation paradigms.
+In the author's view, whether the distillation is white-box or black-box, the idea running through LLM distillation is that "the training data comes from the teacher", rather than from human or machine annotation.
+
+Distillation is undoubtedly a low-cost, high-efficiency way to improve a small model's capability; one could call it a "shortcut", and its original purpose is to deploy better models under limited resources.
+However, as researchers taking the long view, we cannot improve model capability simply by relying on the distillation "shortcut"; we should start from first principles and explore technical routes that fundamentally improve model ability.

docs/chapter2/code/BabyLlama/1.clean_and_tokenize.ipynb

Lines changed: 49 additions & 32 deletions
@@ -2,7 +2,7 @@
 "cells": [
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 1,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -11,11 +11,16 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 2,
 "metadata": {},
 "outputs": [],
 "source": [
-"# Download the data: https://osf.io/rduj2"
+"# Download the two datasets: https://osf.io/5mk3x, https://osf.io/m48ed\n",
+"# Unzip both into the data folder under the current directory\n",
+"# The data directory structure looks like this:\n",
+"# data/\n",
+"# |--train_10M/\n",
+"# |--dev/"
 ]
 },
 {
@@ -36,23 +41,24 @@
 },
 {
 "cell_type": "code",
-"execution_count": 1,
+"execution_count": 3,
 "metadata": {},
 "outputs": [],
 "source": [
 "from pathlib import Path\n",
-"from mrclean import *"
+"from mrclean import *\n",
+"import os"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 15,
+"execution_count": 4,
 "metadata": {},
 "outputs": [],
 "source": [
-"DATA_ROOT = Path(\"F:/llm-deploy-data/data/Babyllama\")\n",
+"DATA_ROOT = Path(\"./data\")\n",
 "SEQ_LENGTH = 128 # this is a legacy parameter, it does not affect cleaning\n",
-"DATA_SPLITS = ['babylm_10M', 'babylm_dev']\n",
+"DATA_SPLITS = ['train_10M', 'dev']\n",
 "\n",
 "CLEANUP_FUNCTIONS = {\n",
 " 'aochildes': cleanup_aochildes,\n",
@@ -70,25 +76,25 @@
 },
 {
 "cell_type": "code",
-"execution_count": 16,
+"execution_count": 5,
 "metadata": {},
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"🧹 Cleaned 'bnc_spoken.train' (size 4883879 -> 4851676) in babylm_10M\n",
-"🧹 Cleaned 'childes.train' (size 15482927 -> 15482927) in babylm_10M\n",
-"🧹 Cleaned 'gutenberg.train' (size 13910986 -> 13910986) in babylm_10M\n",
-"🧹 Cleaned 'open_subtitles.train' (size 10806305 -> 10804026) in babylm_10M\n",
-"🧹 Cleaned 'simple_wiki.train' (size 8411630 -> 8387062) in babylm_10M\n",
-"🧹 Cleaned 'switchboard.train' (size 719322 -> 719322) in babylm_10M\n",
-"🧹 Cleaned 'bnc_spoken.dev' (size 6538139 -> 6503778) in babylm_dev\n",
-"🧹 Cleaned 'childes.dev' (size 14638378 -> 14638378) in babylm_dev\n",
-"🧹 Cleaned 'gutenberg.dev' (size 15490473 -> 15490473) in babylm_dev\n",
-"🧹 Cleaned 'open_subtitles.dev' (size 11016133 -> 11014854) in babylm_dev\n",
-"🧹 Cleaned 'simple_wiki.dev' (size 8149513 -> 8128239) in babylm_dev\n",
-"🧹 Cleaned 'switchboard.dev' (size 724013 -> 724013) in babylm_dev\n"
+"🧹 Cleaned 'childes.train' (size 15482927 -> 15482927) in train_10M\n",
+"🧹 Cleaned 'simple_wiki.train' (size 8411630 -> 8387062) in train_10M\n",
+"🧹 Cleaned 'bnc_spoken.train' (size 4883879 -> 4851676) in train_10M\n",
+"🧹 Cleaned 'gutenberg.train' (size 13910986 -> 13910986) in train_10M\n",
+"🧹 Cleaned 'switchboard.train' (size 719322 -> 719322) in train_10M\n",
+"🧹 Cleaned 'open_subtitles.train' (size 10806305 -> 10804026) in train_10M\n",
+"🧹 Cleaned 'switchboard.dev' (size 724013 -> 724013) in dev\n",
+"🧹 Cleaned 'simple_wiki.dev' (size 8149513 -> 8128239) in dev\n",
+"🧹 Cleaned 'gutenberg.dev' (size 15490473 -> 15490473) in dev\n",
+"🧹 Cleaned 'bnc_spoken.dev' (size 6538139 -> 6503778) in dev\n",
+"🧹 Cleaned 'open_subtitles.dev' (size 11016133 -> 11014854) in dev\n",
+"🧹 Cleaned 'childes.dev' (size 14638378 -> 14638378) in dev\n"
 ]
 }
 ],
@@ -117,7 +123,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 17,
+"execution_count": 6,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -129,7 +135,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 18,
+"execution_count": 7,
 "metadata": {},
 "outputs": [
 {
@@ -142,7 +148,7 @@
 ],
 "source": [
 "# We train the tokenizer on the train data only\n",
-"data_dir = Path(\"F:/llm-deploy-data/data/Babyllama/babylm_10M_clean/\")\n",
+"data_dir = Path(\"./data/train_10M_clean/\")\n",
 "\n",
 "paths = [str(f) for f in data_dir.glob(\"*\") if f.is_file() and not f.name.endswith(\".DS_Store\") and f.suffix in [\".train\"]]\n",
 "\n",
@@ -153,7 +159,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 19,
+"execution_count": 8,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -167,21 +173,32 @@
 },
 {
 "cell_type": "code",
-"execution_count": 20,
+"execution_count": 9,
 "metadata": {},
-"outputs": [],
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"\n",
+"\n",
+"\n"
+]
+}
+],
 "source": [
 "trainer = trainers.BpeTrainer(vocab_size=16000, min_frequency=2, special_tokens=[\"<pad>\", \"<s>\", \"</s>\"])\n",
 "tokenizer.train(paths, trainer)"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 22,
+"execution_count": 10,
 "metadata": {},
 "outputs": [],
 "source": [
-"tokenizer_path = DATA_ROOT / \"models/gpt-clean-16000.json\"\n",
+"tokenizer_path = \"./models/gpt-clean-16000.json\"\n",
+"os.makedirs(\"models\", exist_ok=True)\n",
 "tokenizer.save(str(tokenizer_path), pretty=True)"
 ]
 },
@@ -194,15 +211,15 @@
 },
 {
 "cell_type": "code",
-"execution_count": 23,
+"execution_count": 11,
 "metadata": {},
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
 "Encoded String: ['ĠThe', 'Ġquick', 'Ġbrown', 'Ġfox', 'Ġjumps', 'Ġover', 'Ġthe', 'Ġlazy', 'Ġdog', '.']\n",
-"Encoded IDs: [302, 1784, 3266, 5712, 15961, 541, 190, 11553, 1469, 16]\n",
+"Encoded IDs: [300, 1782, 3264, 5710, 15959, 539, 188, 11551, 1467, 16]\n",
 "Decoded String: The quick brown fox jumps over the lazy dog.\n"
 ]
 }
@@ -248,7 +265,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.9.18"
+"version": "3.9.20"
 },
 "orig_nbformat": 4
 },
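
Once the notebook has saved ./models/gpt-clean-16000.json, the trained tokenizer can be loaded back and sanity-checked roughly as follows. This is a minimal sketch using the tokenizers package; the exact token strings and IDs depend on the trained vocabulary:

```python
from tokenizers import Tokenizer

# Load the BPE tokenizer trained and saved by the notebook.
tokenizer = Tokenizer.from_file("./models/gpt-clean-16000.json")

encoding = tokenizer.encode("The quick brown fox jumps over the lazy dog.")
print("Encoded String:", encoding.tokens)
print("Encoded IDs:", encoding.ids)
print("Decoded String:", tokenizer.decode(encoding.ids))
```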
