Commit fd35140 ("update")

1 parent fc1b673

File tree: 11 files changed, +249 −155 lines

docs/source/Instruction/GKD.md

Lines changed: 11 additions & 14 deletions

@@ -196,39 +196,36 @@ swift rlhf \
 | `--gkd_logits_topk` | int | **Required** | Must be set when using an external API; corresponds to the number of top_logprobs returned by the API |

 **Supported backends**
-- `swift deploy` (vLLM backend)
-- Standalone vLLM server (`vllm serve`)
+- `vllm serve` (recommended)
+
+> **Note**: Only `vllm serve` is supported as the teacher server backend. The training code passes token IDs directly through the `/v1/completions` endpoint and uses the `prompt_logprobs` parameter to obtain log probabilities for the input tokens, which is a vLLM-native feature.

 **Step 1: Deploy the teacher model service**

 ```bash
-# Deploy the teacher model with swift deploy
-CUDA_VISIBLE_DEVICES=0,1 swift deploy \
-    --model Qwen/Qwen2-72B-Instruct \
-    --infer_backend vllm \
+# Deploy the teacher model with vllm serve
+CUDA_VISIBLE_DEVICES=0 vllm serve Qwen/Qwen2.5-14B-Instruct \
     --port 8000 \
-    --vllm_engine_kwargs '{"max_logprobs": 64}'
-
-# Or use a standalone vLLM server
-vllm serve Qwen/Qwen2-72B-Instruct --max-logprobs 64 --port 8000
+    --max-logprobs 64 \
+    --gpu-memory-utilization 0.9
 ```

 **Step 2: Start GKD training**

 ```bash
 swift rlhf \
     --rlhf_type gkd \
-    --model Qwen/Qwen2-7B-Instruct \
+    --model Qwen/Qwen2.5-7B \
     --teacher_model_server http://localhost:8000 \
-    --gkd_logits_topk 20 \
+    --gkd_logits_topk 64 \
     --dataset your_dataset \
     --lmbda 1.0 \
-    --beta 0.5 \
+    --beta 1.0 \
     ...
 ```

 > **vLLM max_logprobs limit**
-> - vLLM defaults to `max_logprobs=20`; adjust it via `--vllm_engine_kwargs '{"max_logprobs": N}'`
+> - vLLM defaults to `max_logprobs=20`; adjust it via `--max-logprobs N`
 > - `gkd_logits_topk` must not exceed the server-side `max_logprobs` setting

 ## Sampling Acceleration
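To make the `gkd_logits_topk` contract above concrete: the teacher API returns a ragged set of top log-probabilities per input position, which the trainer must align into fixed-width top-k arrays. A minimal sketch of that alignment, assuming per-position `{token_id: logprob}` dicts as input (the helper name `to_topk_arrays` is hypothetical, not part of swift); padding with `-inf` mirrors how masked positions carry zero probability after softmax:

```python
import math

def to_topk_arrays(prompt_logprobs, topk):
    """Convert per-position logprob dicts {token_id: logprob} into fixed-width
    (logprobs, indices) lists of length `topk`, padded with -inf / 0."""
    logprobs, indices = [], []
    for entry in prompt_logprobs:
        # Keep the k highest-probability entries, most likely first.
        pairs = sorted(entry.items(), key=lambda kv: kv[1], reverse=True)[:topk]
        pad = topk - len(pairs)
        logprobs.append([lp for _, lp in pairs] + [-math.inf] * pad)
        indices.append([tid for tid, _ in pairs] + [0] * pad)
    return logprobs, indices

# Example: two positions with ragged top_logprobs from the teacher API.
lp, ix = to_topk_arrays([{5: -0.1, 9: -2.3}, {7: -0.5}], topk=3)
print(ix[0])  # [5, 9, 0]
```

Exceeding the server's `max_logprobs` makes `topk` unreachable, which is why the note above ties `gkd_logits_topk` to the server setting.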

docs/source/Megatron-SWIFT/GKD.md

Lines changed: 1 addition & 1 deletion

@@ -34,7 +34,7 @@ Megatron GKD currently supports the following features:
 | Parameter | Type | Default | Description |
 |------|------|--------|------|
 | `--teacher_model` | str | - | Teacher model path or model ID<br>*Can be omitted when `teacher_model_server` is used |
-| `--teacher_model_server` | str | None | Teacher model server URL, e.g. `http://localhost:8000` |
+| `--teacher_model_server` | str | None | Teacher model server URL (only `vllm serve` is supported), e.g. `http://localhost:8000` |
 | `--gkd_logits_topk` | int | None | Number of top-K logits; required when using an external teacher API |
 | `--beta` | float | 0.5 | JSD interpolation coefficient:<br>• 0.0: forward KL<br>• 0.5: symmetric JSD<br>• 1.0: reverse KL |
 | `--lmbda` | float | 0.5 | Probability of triggering on-policy learning:<br>• 0.0: pure off-policy<br>• 1.0: pure on-policy |

docs/source_en/Instruction/GKD.md

Lines changed: 11 additions & 11 deletions

@@ -197,36 +197,36 @@ When `gkd_logits_topk` is set, you can use an external teacher model API service
 | `--gkd_logits_topk` | int | **Required** | Must be set when using external API; corresponds to the top_logprobs returned by the API |

 **Supported Backends**:
-- `swift deploy` (vLLM backend)
-- Standalone vLLM server (`vllm serve`)
+- `vllm serve` (recommended)
+
+> **Note**: Only `vllm serve` is supported as the teacher server backend. The training code sends raw token IDs via the `prompt` field and uses the `prompt_logprobs` parameter in the `/v1/completions` API to obtain input token log-probabilities. This is a vLLM-native feature.

 **Step 1: Deploy Teacher Model Service**

 ```bash
-# Deploy teacher model with swift deploy (recommended)
-swift deploy \
-    --model Qwen/Qwen2.5-14B-Instruct \
-    --infer_backend vllm \
+# Deploy teacher model with vllm serve
+CUDA_VISIBLE_DEVICES=0 vllm serve Qwen/Qwen2.5-14B-Instruct \
     --port 8000 \
-    --vllm_engine_kwargs '{"max_logprobs": 64}'
+    --max-logprobs 64 \
+    --gpu-memory-utilization 0.9
 ```

 **Step 2: Start GKD Training**

 ```bash
 swift rlhf \
     --rlhf_type gkd \
-    --model Qwen/Qwen2.5-7B-Instruct \
+    --model Qwen/Qwen2.5-7B \
     --teacher_model_server http://localhost:8000 \
-    --gkd_logits_topk 20 \
+    --gkd_logits_topk 64 \
     --dataset your_dataset \
     --lmbda 1.0 \
-    --beta 0.5 \
+    --beta 1.0 \
     ...
 ```

 > **vLLM max_logprobs Limitation**:
-> - vLLM default `max_logprobs=20`, adjustable via `--vllm_engine_kwargs '{"max_logprobs": N}'` parameter
+> - vLLM default `max_logprobs=20`, adjustable via `--max-logprobs N` parameter
 > - `gkd_logits_topk` cannot exceed the server's `max_logprobs` setting

 ## Sampling Acceleration
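The note in the hunk above describes the teacher request shape: raw token IDs in the `prompt` field plus the `prompt_logprobs` parameter of `/v1/completions`. A minimal sketch of building such a request body (no network call; the model name is a placeholder, and `max_tokens` set to 1 is an assumption about the cheapest way to trigger prompt scoring):

```python
import json

def build_teacher_request(token_ids, topk, model='Qwen/Qwen2.5-14B-Instruct'):
    """Build a JSON body for a vLLM /v1/completions call that returns
    log-probabilities for the *input* tokens (no real generation needed)."""
    return {
        'model': model,           # must match the model served by vllm serve
        'prompt': token_ids,      # vLLM accepts raw token IDs here
        'max_tokens': 1,          # we only need the prompt-side logprobs
        'prompt_logprobs': topk,  # top-k logprobs per input position
    }

body = build_teacher_request([151644, 872, 198], topk=64)
print(json.dumps(body))
```

The server rejects `prompt_logprobs` values above its `--max-logprobs` setting, which is the limitation called out in the note above.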

docs/source_en/Megatron-SWIFT/GKD.md

Lines changed: 1 addition & 1 deletion

@@ -34,7 +34,7 @@ Megatron GKD currently supports the following features:
 | Parameter | Type | Default | Description |
 |-----------|------|---------|-------------|
 | `--teacher_model` | str | - | Path or model ID of the teacher model<br>*Can be omitted when using `teacher_model_server` |
-| `--teacher_model_server` | str | None | Teacher model service URL, e.g. `http://localhost:8000` |
+| `--teacher_model_server` | str | None | Teacher model service URL (`vllm serve` only), e.g. `http://localhost:8000` |
 | `--gkd_logits_topk` | int | None | Number of Top-K logits; required when using external API |
 | `--beta` | float | 0.5 | JSD divergence interpolation coefficient:<br>• 0.0: Forward KL<br>• 0.5: Symmetric JSD<br>• 1.0: Reverse KL |
 | `--lmbda` | float | 0.5 | On-Policy learning probability:<br>• 0.0: Pure Off-Policy<br>• 1.0: Pure On-Policy |
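The `--beta` semantics in the table above can be made concrete with the generalized JSD on two small distributions. A pure-Python sketch (the function names are illustrative, not swift's API); at β=0.5 the divergence is symmetric, while the β→0 and β→1 limits recover forward and reverse KL up to normalization:

```python
import math

def kl(p, q):
    """KL(p || q) for discrete distributions given as probability lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def generalized_jsd(p, q, beta):
    """JSD_beta(p || q) = beta*KL(p||m) + (1-beta)*KL(q||m),
    with the mixture m = beta*p + (1-beta)*q."""
    m = [beta * pi + (1 - beta) * qi for pi, qi in zip(p, q)]
    return beta * kl(p, m) + (1 - beta) * kl(q, m)

teacher = [0.7, 0.2, 0.1]
student = [0.4, 0.4, 0.2]
# Symmetric at beta=0.5: swapping arguments changes nothing.
print(abs(generalized_jsd(teacher, student, 0.5)
          - generalized_jsd(student, teacher, 0.5)) < 1e-12)
```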

examples/megatron/rlhf/gkd/teacher_server.sh

Lines changed: 5 additions & 0 deletions

@@ -1,3 +1,8 @@
+# GKD Training with External Teacher Model Server (Megatron)
+#
+# Start teacher server first (in a separate terminal / GPU):
+# CUDA_VISIBLE_DEVICES=4 vllm serve Qwen/Qwen3-8B --port 8000 --max-logprobs 64
+
 CUDA_VISIBLE_DEVICES=0,1,2,3 \
 NPROC_PER_NODE=4 \
 PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
Lines changed: 77 additions & 0 deletions

@@ -0,0 +1,77 @@
+# GKD on GSM8K: Teacher Server Mode with Top-K Logits
+#
+# This script validates GKD effectiveness on mathematical reasoning using GSM8K.
+# Student: Qwen2.5-1.5B-Instruct, Teacher: Qwen2.5-7B-Instruct (via vllm serve)
+#
+# Expected outcome: GSM8K accuracy should improve after GKD training, as the student
+# learns the teacher's reasoning distribution on math problems.
+#
+# ===================== Step 1: Start Teacher Server =====================
+# Run in a separate terminal / GPU:
+#
+# CUDA_VISIBLE_DEVICES=0 vllm serve Qwen/Qwen2.5-7B-Instruct \
+#     --port 8000 \
+#     --max-logprobs 64 \
+#     --gpu-memory-utilization 0.9
+#
+# Wait until the server is ready, then verify:
+#   curl http://localhost:8000/v1/models
+# ========================================================================
+#
+# ===================== Step 2: Prepare GSM8K Dataset =====================
+# The dataset uses the standard GSM8K train split from Hugging Face:
+#   openai/gsm8k (7473 training samples)
+# Swift will auto-download it via the HuggingFace dataset name.
+# ========================================================================
+#
+# ===================== Step 3: Evaluation =====================
+# After training, evaluate on the GSM8K test set:
+#
+# CUDA_VISIBLE_DEVICES=0 swift eval \
+#     --model <output_dir>/checkpoint-xxx \
+#     --eval_backend OpenCompass \
+#     --infer_backend vllm \
+#     --eval_dataset gsm8k
+#
+# Compare with the base model to verify improvement:
+# CUDA_VISIBLE_DEVICES=0 swift eval \
+#     --model Qwen/Qwen2.5-1.5B-Instruct \
+#     --eval_backend OpenCompass \
+#     --infer_backend vllm \
+#     --eval_dataset gsm8k
+# ========================================================================

+TEACHER_SERVER_URL=${TEACHER_SERVER_URL:-"http://localhost:8000"}
+GKD_LOGITS_TOPK=${GKD_LOGITS_TOPK:-64}

+CUDA_VISIBLE_DEVICES=1 \
+PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
+swift rlhf \
+    --rlhf_type gkd \
+    --model Qwen/Qwen2.5-1.5B-Instruct \
+    --teacher_model_server $TEACHER_SERVER_URL \
+    --gkd_logits_topk $GKD_LOGITS_TOPK \
+    --tuner_type lora \
+    --lora_rank 64 \
+    --lora_alpha 128 \
+    --dataset 'openai/gsm8k#train' \
+    --seq_kd false \
+    --lmbda 0 \
+    --beta 0.5 \
+    --torch_dtype bfloat16 \
+    --num_train_epochs 3 \
+    --per_device_train_batch_size 2 \
+    --per_device_eval_batch_size 2 \
+    --learning_rate 5e-5 \
+    --gradient_accumulation_steps 8 \
+    --eval_steps 200 \
+    --save_steps 200 \
+    --save_total_limit 3 \
+    --logging_steps 5 \
+    --max_length 1024 \
+    --warmup_ratio 0.05 \
+    --save_only_model true \
+    --dataloader_num_workers 4 \
+    --dataset_num_proc 4 \
+    --deepspeed zero2 \
+    --attn_impl flash_attn
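For orientation, the optimizer-step batch implied by the flags in the script above, as a back-of-envelope sketch (it ignores any dropped last batch; the single-GPU count follows from `CUDA_VISIBLE_DEVICES=1` exposing one device):

```python
# Values taken from the script above.
per_device_train_batch_size = 2
gradient_accumulation_steps = 8
num_gpus = 1  # CUDA_VISIBLE_DEVICES=1 exposes a single device

# Samples consumed per optimizer step.
effective_batch = per_device_train_batch_size * gradient_accumulation_steps * num_gpus
# GSM8K train size quoted in the script header: 7473 samples.
steps_per_epoch = 7473 // effective_batch
print(effective_batch, steps_per_epoch)  # 16 467
```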

examples/train/rlhf/gkd/teacher_server.sh

Lines changed: 21 additions & 10 deletions

@@ -1,21 +1,32 @@
-# GKD Training with External Teacher Model Server
+# GKD Training with External Teacher Model Server (vLLM)
 #
 # This script demonstrates using an external vLLM server as the teacher model
-# for knowledge distillation.
+# for knowledge distillation. The teacher server provides prompt_logprobs via
+# the /v1/completions endpoint, which requires native vLLM serving (vllm serve).
+#
+# NOTE: Only `vllm serve` is supported as the teacher server backend, because
+# the training code sends raw token IDs via the `prompt` field and uses the
+# `prompt_logprobs` parameter in the /v1/completions API. This is a vLLM-native
+# feature not available through swift deploy.

-# Teacher Server Setup (run on a separate GPU):
-# CUDA_VISIBLE_DEVICES=5 swift deploy \
-#     --model Qwen/Qwen2.5-14B-Instruct \
-#     --infer_backend vllm \
-#     --port 8000 \
-#     --vllm_engine_kwargs '{"max_logprobs": 64}'
+# ===================== Step 1: Start Teacher Server =====================
+# Run in a separate terminal / GPU:
+#
+# CUDA_VISIBLE_DEVICES=0 vllm serve Qwen/Qwen2.5-14B-Instruct \
+#     --port 8000 \
+#     --max-logprobs 64 \
+#     --gpu-memory-utilization 0.9
+#
+# Wait until the server is ready (shows "Uvicorn running on ...").
+# Verify with: curl http://localhost:8000/v1/models
+# ========================================================================

-TEACHER_SERVER_URL=${TEACHER_SERVER_URL:-"http://localhost:8001"}
+TEACHER_SERVER_URL=${TEACHER_SERVER_URL:-"http://localhost:8000"}
 GKD_LOGITS_TOPK=${GKD_LOGITS_TOPK:-64}

 NPROC_PER_NODE=4 \
 PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
-CUDA_VISIBLE_DEVICES=0,1,2,3 \
+CUDA_VISIBLE_DEVICES=1,2,3,4 \
 swift rlhf \
     --rlhf_type gkd \
     --model Qwen/Qwen2.5-7B \

swift/megatron/trainers/gkd_trainer.py

Lines changed: 11 additions & 8 deletions

@@ -285,8 +285,7 @@ def _compute_teacher_logits_local(self, encoded_batches: List[Dict], vp_stage: O
             teacher_logits = teacher_logits.detach()

             if topk is not None and teacher_logits is not None:
-                scaled = teacher_logits / self.temperature
-                topk_logits, topk_indices = torch.topk(scaled, k=topk, dim=-1)
+                topk_logits, topk_indices = torch.topk(teacher_logits, k=topk, dim=-1)
                 encoded_batch['teacher_api_logprobs'] = topk_logits
                 encoded_batch['teacher_api_indices'] = topk_indices
                 encoded_batch['teacher_logits'] = None

@@ -295,12 +294,16 @@ def _compute_teacher_logits_local(self, encoded_batches: List[Dict], vp_stage: O
     def _compute_teacher_logits_from_api(self, encoded_batches: List[Dict]) -> None:
         """Fetch teacher logprobs from external API service."""
-        from swift.rlhf_trainers.teacher_api_client import fetch_teacher_logprobs
+        from swift.rlhf_trainers.gkd_trainer import fetch_teacher_logprobs
         topk = self.gkd_logits_topk
         for encoded_batch in encoded_batches:
             input_ids = encoded_batch['input_ids']
             teacher_logprobs, teacher_indices = fetch_teacher_logprobs(
                 self.teacher_model_server, input_ids.tolist(), topk=topk)
+            # fetch_teacher_logprobs returns [batch, seq_len-1, topk] (shifted).
+            # Pad last position with -inf to match student [batch, seq_len, topk].
+            teacher_logprobs = F.pad(teacher_logprobs, (0, 0, 0, 1), value=float('-inf'))
+            teacher_indices = F.pad(teacher_indices, (0, 0, 0, 1), value=0)
             encoded_batch['teacher_api_logprobs'] = teacher_logprobs.to(input_ids.device)
             encoded_batch['teacher_api_indices'] = teacher_indices.to(input_ids.device)
             encoded_batch['teacher_logits'] = None

@@ -474,14 +477,14 @@ def generalized_jsd_loss(
     def _jsd_topk(self, student_logits, teacher_topk_logprobs, teacher_topk_indices, mask, beta):
         """Compute JSD on teacher's top-k distribution.

-        Handles both local top-k (raw logits) and API top-k (raw logprobs) by
-        normalizing both teacher and student over the top-k subset via log_softmax.
+        Both local and API teacher are handled uniformly: gather student logits at
+        teacher's top-k indices, scale by 1/T, and log_softmax over the top-k subset.
+        By shift-invariance of log_softmax, this gives identical results whether
+        teacher_topk_logprobs contains raw logits (local) or raw logprobs (API).
         """
         s_scaled = student_logits / self.temperature
         s_topk = torch.gather(s_scaled, dim=-1, index=teacher_topk_indices)
-
-        # Normalize both over top-k subset (handles both raw logits and API logprobs)
-        t_log_p = F.log_softmax(teacher_topk_logprobs, dim=-1)
+        t_log_p = F.log_softmax(teacher_topk_logprobs / self.temperature, dim=-1)
         s_log_p = F.log_softmax(s_topk, dim=-1)
         t_p = torch.exp(t_log_p)

swift/rlhf_trainers/__init__.py

Lines changed: 0 additions & 2 deletions

@@ -15,7 +15,6 @@
     from .ppo_trainer import PPOTrainer
     from .reward_trainer import RewardTrainer
     from .rlhf_mixin import RLHFTrainerMixin
-    from .teacher_api_client import fetch_teacher_logprobs
     from .utils import _ForwardRedirection, patch_lora_merge, patch_lora_unmerge, round_robin
     from .vllm_client import VLLMClient
 else:

@@ -32,7 +31,6 @@
         'args_mixin': ['VllmArguments', 'GRPOArgumentsMixin'],
         'utils': ['patch_lora_merge', 'patch_lora_unmerge', 'round_robin', '_ForwardRedirection'],
         'vllm_client': ['VLLMClient'],
-        'teacher_api_client': ['fetch_teacher_logprobs'],
         'arguments':
         ['DPOConfig', 'CPOConfig', 'KTOConfig', 'ORPOConfig', 'PPOConfig', 'RewardConfig', 'GRPOConfig', 'GKDConfig']
     }
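The `else:` branch edited above maintains an `_import_structure` dict, which suggests the lazy-import pattern (PEP 562 module `__getattr__`): symbols are resolved on first attribute access rather than at import time, so removing `teacher_api_client` needs a change in both the eager and lazy paths. A toy sketch of the mechanism, with stand-in stdlib modules instead of swift's submodules:

```python
import importlib
import sys
import types

# Maps submodule name -> symbols it provides (stand-ins for the real entries).
_import_structure = {'math': ['sqrt'], 'json': ['dumps']}

lazy = types.ModuleType('lazy_demo')

def _lazy_getattr(name):
    """Resolve `name` from _import_structure on first access, then cache it."""
    for module_name, symbols in _import_structure.items():
        if name in symbols:
            value = getattr(importlib.import_module(module_name), name)
            setattr(lazy, name, value)  # cache so later lookups skip this hook
            return value
    raise AttributeError(name)

lazy.__getattr__ = _lazy_getattr  # PEP 562: module-level __getattr__
sys.modules['lazy_demo'] = lazy

import lazy_demo
print(lazy_demo.sqrt(9))  # → 3.0, resolved lazily from math
```

Deleting an entry from `_import_structure` (as this commit does for `fetch_teacher_logprobs`) makes the corresponding attribute raise `AttributeError` instead of importing the removed submodule.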
