Commit d67b0f8

wip

1 parent a6a5aea commit d67b0f8

File tree

12 files changed: +981 −462 lines changed


docs/source_en/Instruction/GKD.md

Lines changed: 5 additions & 8 deletions

````diff
@@ -178,8 +178,8 @@ $$
 ```bash
 swift rlhf \
     --rlhf_type gkd \
-    --model Qwen/Qwen2-7B-Instruct \
-    --teacher_model Qwen/Qwen2-72B-Instruct \
+    --model Qwen/Qwen2.5-7B-Instruct \
+    --teacher_model Qwen/Qwen2.5-14B-Instruct \
     --gkd_logits_topk 64 \
     --dataset your_dataset \
     ...
@@ -204,22 +204,19 @@ When `gkd_logits_topk` is set, you can use an external teacher model API service
 
 ```bash
 # Deploy teacher model with swift deploy (recommended)
-CUDA_VISIBLE_DEVICES=0,1 swift deploy \
-    --model Qwen/Qwen2-72B-Instruct \
+swift deploy \
+    --model Qwen/Qwen2.5-14B-Instruct \
     --infer_backend vllm \
     --port 8000 \
     --vllm_engine_kwargs '{"max_logprobs": 64}'
-
-# Or use standalone vLLM server
-vllm serve Qwen/Qwen2-72B-Instruct --max-logprobs 64 --port 8000
 ```
 
 **Step 2: Start GKD Training**
 
 ```bash
 swift rlhf \
     --rlhf_type gkd \
-    --model Qwen/Qwen2-7B-Instruct \
+    --model Qwen/Qwen2.5-7B-Instruct \
     --teacher_model_server http://localhost:8000 \
     --gkd_logits_topk 20 \
     --dataset your_dataset \
````
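With `--gkd_logits_topk`, only the teacher's top-k logprobs per position are kept. A minimal sketch of how a distillation term can be computed from that truncated distribution — a hypothetical helper for illustration only, not ms-swift's actual loss code, assuming the teacher's top-k mass is used as-is without renormalization:

```python
import math

def truncated_forward_kl(teacher_topk, student_logprobs):
    """KL(teacher || student) summed over the teacher's top-k tokens only.

    Both arguments map token_id -> logprob for one sequence position.
    Hypothetical helper, not ms-swift's implementation.
    """
    kl = 0.0
    for tid, t_lp in teacher_topk.items():
        p = math.exp(t_lp)  # teacher probability of this token
        kl += p * (t_lp - student_logprobs[tid])
    return kl
```

When student and teacher agree on the kept tokens the term vanishes, which is why a larger `gkd_logits_topk` gives a tighter approximation of the full-vocabulary KL at the cost of more data over the wire.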
Lines changed: 18 additions & 37 deletions

```diff
@@ -1,58 +1,39 @@
-# Megatron GKD Training with External Teacher Model Server
-#
-# This script demonstrates using an external vLLM server as the teacher model
-# for knowledge distillation with Megatron-SWIFT. This approach is useful when:
-# - The teacher model is too large to load alongside the student model
-# - You want to separate teacher inference from training for better resource utilization
-# - You need to use different model parallelism for student vs teacher
-#
-# Prerequisites:
-# 1. Start the teacher model server first (see below)
-# 2. Ensure the server is accessible at the specified URL
-#
-# Teacher Server Setup (run in a separate terminal):
-# CUDA_VISIBLE_DEVICES=4,5,6,7 swift deploy \
-#     --model Qwen/Qwen2-72B-Instruct \
-#     --infer_backend vllm \
-#     --port 8000 \
-#     --vllm_engine_kwargs '{"max_logprobs": 64}'
-#
-# Or using vLLM directly:
-# vllm serve Qwen/Qwen2-72B-Instruct --max-logprobs 64 --port 8000
-
-TEACHER_SERVER_URL=${TEACHER_SERVER_URL:-"http://localhost:8000"}
-GKD_LOGITS_TOPK=${GKD_LOGITS_TOPK:-20}
-
 CUDA_VISIBLE_DEVICES=0,1,2,3 \
 NPROC_PER_NODE=4 \
 PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
 megatron rlhf \
     --rlhf_type gkd \
     --model Qwen/Qwen3-8B-Base \
-    --teacher_model_server $TEACHER_SERVER_URL \
-    --gkd_logits_topk $GKD_LOGITS_TOPK \
+    --teacher_model_server http://localhost:8000 \
+    --gkd_logits_topk 20 \
     --tuner_type lora \
-    --dataset 'AI-ModelScope/alpaca-gpt4-data-en#2000' 'AI-ModelScope/alpaca-gpt4-data-zh#2000' \
-    --tensor_model_parallel_size 2 \
+    --dataset AI-ModelScope/alpaca-gpt4-data-en#2000 AI-ModelScope/alpaca-gpt4-data-zh#2000 \
+    --tensor_model_parallel_size 1 \
     --expert_model_parallel_size 1 \
     --pipeline_model_parallel_size 1 \
-    --context_parallel_size 2 \
+    --context_parallel_size 1 \
     --seq_kd false \
-    --lmbda 0 \
-    --beta 0.5 \
+    --lmbda 1 \
+    --beta 1 \
     --torch_dtype bfloat16 \
    --micro_batch_size 2 \
     --global_batch_size 16 \
     --max_epochs 1 \
-    --lr 5e-6 \
-    --log_interval 5 \
-    --max_length 4096 \
-    --max_completion_length 1024 \
+    --lr 5e-5 \
+    --log_interval 1 \
+    --max_length 8192 \
+    --max_completion_length 8192 \
     --attention_backend flash \
+    --use_vllm true \
+    --vllm_mode colocate \
+    --vllm_gpu_memory_utilization 0.5 \
+    --vllm_tensor_parallel_size 1 \
+    --vllm_max_model_len 16384 \
+    --sleep_level 1 \
     --recompute_granularity selective \
     --finetune \
     --no_save_optim \
     --no_save_rng \
-    --temperature 0.9 \
+    --temperature 1.0 \
     --padding_free true \
     --sequence_parallel true
```
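The batch-size flags above are linked by the standard Megatron relation: the global batch must equal micro batch × data-parallel size × gradient-accumulation steps. A quick sketch of that arithmetic (helper name is illustrative, not a swift API):

```python
def grad_accum_steps(global_bs, micro_bs, world_size, tp=1, pp=1, cp=1):
    # Data-parallel size is the world size divided by the model-parallel
    # degrees; gradient accumulation fills the remaining factor of the
    # global batch. Sketch of the standard Megatron bookkeeping.
    dp = world_size // (tp * pp * cp)
    assert global_bs % (micro_bs * dp) == 0, 'global batch must divide evenly'
    return global_bs // (micro_bs * dp)
```

With the values in this script (4 GPUs, TP=PP=CP=1, micro batch 2, global batch 16), each rank accumulates over 2 micro-batches per optimizer step.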

examples/train/rlhf/gkd/teacher_server.sh

Lines changed: 19 additions & 25 deletions

```diff
@@ -1,27 +1,17 @@
 # GKD Training with External Teacher Model Server
 #
 # This script demonstrates using an external vLLM server as the teacher model
-# for knowledge distillation. This approach is useful when:
-# - The teacher model is too large to load alongside the student model
-# - You want to share a single teacher server across multiple training processes
-# - You need more control over the teacher model deployment
-#
-# Prerequisites:
-# 1. Start the teacher model server first (see below)
-# 2. Ensure the server is accessible at the specified URL
-#
-# Teacher Server Setup (run in a separate terminal):
-# CUDA_VISIBLE_DEVICES=0,1 swift deploy \
-#     --model Qwen/Qwen2-72B-Instruct \
-#     --infer_backend vllm \
-#     --port 8000 \
-#     --vllm_engine_kwargs '{"max_logprobs": 64}'
-#
-# Or using vLLM directly:
-# vllm serve Qwen/Qwen2-72B-Instruct --max-logprobs 64 --port 8000
+# for knowledge distillation.
+
+# Teacher Server Setup (run on a separate GPU):
+# CUDA_VISIBLE_DEVICES=5 swift deploy \
+#     --model Qwen/Qwen2.5-14B-Instruct \
+#     --infer_backend vllm \
+#     --port 8000 \
+#     --vllm_engine_kwargs '{"max_logprobs": 64}'
 
-TEACHER_SERVER_URL=${TEACHER_SERVER_URL:-"http://localhost:8000"}
-GKD_LOGITS_TOPK=${GKD_LOGITS_TOPK:-20}
+TEACHER_SERVER_URL=${TEACHER_SERVER_URL:-"http://localhost:8001"}
+GKD_LOGITS_TOPK=${GKD_LOGITS_TOPK:-64}
 
 NPROC_PER_NODE=4 \
 PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
@@ -30,12 +20,17 @@ swift rlhf \
     --rlhf_type gkd \
     --model Qwen/Qwen2.5-7B \
     --teacher_model_server $TEACHER_SERVER_URL \
+    --use_vllm true \
+    --vllm_mode colocate \
+    --vllm_gpu_memory_utilization 0.5 \
+    --vllm_tensor_parallel_size 1 \
+    --vllm_max_model_len 10240 \
     --gkd_logits_topk $GKD_LOGITS_TOPK \
-    --tuner_type full \
+    --tuner_type lora \
     --dataset 'AI-ModelScope/alpaca-gpt4-data-en' \
     --seq_kd false \
-    --lmbda 0 \
-    --beta 0.5 \
+    --lmbda 1 \
+    --beta 1 \
     --torch_dtype bfloat16 \
     --max_epochs 1 \
     --per_device_train_batch_size 1 \
@@ -47,8 +42,7 @@ swift rlhf \
     --save_total_limit 2 \
     --logging_steps 5 \
     --max_length 2048 \
-    --max_completion_length 512 \
-    --output_dir output/gkd_teacher_server \
+    --max_completion_length 2048 \
     --warmup_ratio 0.05 \
     --save_only_model true \
     --dataloader_num_workers 4 \
```

swift/infer_engine/protocol.py

Lines changed: 7 additions & 2 deletions

```diff
@@ -173,6 +173,7 @@ class RequestConfig:
     stream: bool = False
     logprobs: bool = False
     top_logprobs: Optional[int] = None
+    prompt_logprobs: Optional[int] = None  # Set to an integer to get top-k logprobs for each prompt token
 
     n: int = 1
     best_of: Optional[int] = None
@@ -192,7 +193,6 @@ def __post_init__(self):
 @dataclass
 class CompletionRequestMixin:
     model: str
-    prompt: str
 
 
 @dataclass
@@ -393,11 +393,14 @@ class ChatCompletionResponseChoice:
     finish_reason: Literal['stop', 'length', None]
     logprobs: Optional[Dict[str, List[Dict[str, Any]]]] = None
     token_ids: Optional[List[int]] = None
+    # Logprobs for prompt tokens (when prompt_logprobs is requested)
+    prompt_logprobs: Optional[List[Dict[str, Any]]] = None
 
     def to_cmpl_choice(self) -> 'CompletionResponseChoice':
         self = deepcopy(self)
         assert not self.message.tool_calls, f'message: {self.message}'
-        return CompletionResponseChoice(self.index, self.message.content, self.finish_reason, self.logprobs)
+        return CompletionResponseChoice(self.index, self.message.content, self.finish_reason, self.logprobs,
+                                        self.prompt_logprobs)
 
 
 @dataclass
@@ -423,6 +426,8 @@ class CompletionResponseChoice:
     text: str
     finish_reason: Literal['stop', 'length', None]
     logprobs: Optional[Dict[str, List[Dict[str, Any]]]] = None
+    # Logprobs for prompt tokens (when prompt_logprobs is requested)
+    prompt_logprobs: Optional[List[Dict[str, Any]]] = None
 
 
 @dataclass
```
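Both new `prompt_logprobs` fields are appended after existing defaulted fields with a `None` default, so call sites that construct these dataclasses positionally keep working. A simplified mock (not the real classes) showing why:

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

@dataclass
class ChoiceSketch:
    # Simplified stand-in for CompletionResponseChoice: the new field is
    # appended last with a None default, so older call sites that pass only
    # the first three arguments still construct fine.
    text: str
    finish_reason: Optional[str]
    logprobs: Optional[Dict[str, Any]] = None
    prompt_logprobs: Optional[List[Dict[str, Any]]] = None

old_style = ChoiceSketch('hello', 'stop', None)
new_style = ChoiceSketch('hello', 'stop', None, [{'token_id': 1, 'logprob': -0.1}])
```

This is also why `to_cmpl_choice` can simply pass `self.prompt_logprobs` as an extra positional argument without touching other callers.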

swift/infer_engine/vllm_engine.py

Lines changed: 55 additions & 1 deletion

```diff
@@ -399,6 +399,48 @@ def _get_logprobs(self,
                 logprobs[token_id] = logprob.logprob
         return super()._get_logprobs(logprobs_list, token_ids, top_logprobs)
 
+    def _get_prompt_logprobs(
+        self,
+        prompt_logprobs: Optional[List[Optional[Dict]]],
+        prompt_token_ids: List[int],
+    ) -> Optional[List[Dict[str, Any]]]:
+        if prompt_logprobs is None or not prompt_token_ids:
+            return None
+
+        result = []
+        for token_id, pos_logprobs in zip(prompt_token_ids, prompt_logprobs):
+            token = self.tokenizer.decode(token_id)
+            entry = {
+                'token_id': token_id,
+                'token': token,
+                'logprob': None,  # Will be filled if available
+                'top_logprobs': [],
+            }
+
+            if pos_logprobs is not None:
+                # Get logprob for the actual token at this position
+                if token_id in pos_logprobs:
+                    logprob_obj = pos_logprobs[token_id]
+                    entry['logprob'] = logprob_obj.logprob if hasattr(logprob_obj, 'logprob') else logprob_obj
+
+                # Get top logprobs sorted by probability (descending)
+                sorted_items = sorted(
+                    pos_logprobs.items(), key=lambda x: -(x[1].logprob if hasattr(x[1], 'logprob') else x[1]))
+                for tid, logprob_obj in sorted_items:
+                    logprob_val = logprob_obj.logprob if hasattr(logprob_obj, 'logprob') else logprob_obj
+                    if logprob_val == float('-inf'):
+                        continue
+                    t = self.tokenizer.decode(tid)
+                    entry['top_logprobs'].append({
+                        'token_id': tid,
+                        'token': t,
+                        'logprob': logprob_val,
+                    })
+
+            result.append(entry)
+
+        return result
+
     def _prepare_generation_config(self, request_config: RequestConfig) -> SamplingParams:
         kwargs = {'max_tokens': request_config.max_tokens}
         for key in ['temperature', 'top_k', 'top_p', 'repetition_penalty']:
@@ -424,6 +466,10 @@ def _prepare_generation_config(self, request_config: RequestConfig) -> SamplingP
         # Return only the sampled token's logprob
         kwargs['logprobs'] = 0
 
+        # Handle prompt_logprobs: return logprobs for prompt/input tokens
+        if request_config.prompt_logprobs is not None:
+            kwargs['prompt_logprobs'] = request_config.prompt_logprobs
+
         # TODO: beam search
         for key in ['n', 'best_of', 'frequency_penalty', 'presence_penalty', 'seed']:
             if hasattr(SamplingParams, key):
@@ -582,13 +628,21 @@ def _create_chat_completion_response(
             logprobs = self._get_logprobs(output.logprobs, output.token_ids, request_config.top_logprobs)
             toolcall = self._get_toolcall(content)  # Use content instead of response for tool calls
             token_ids = output.token_ids if request_config.return_details else None
+
+            # Get prompt logprobs if requested
+            prompt_logprobs_result = None
+            if request_config.prompt_logprobs is not None:
+                prompt_logprobs_result = self._get_prompt_logprobs(result.prompt_logprobs,
+                                                                   list(result.prompt_token_ids))
+
             choice = ChatCompletionResponseChoice(
                 index=output.index,
                 message=ChatMessage(
                     role='assistant', content=content, reasoning_content=reasoning_content, tool_calls=toolcall),
                 finish_reason=output.finish_reason,
                 logprobs=logprobs,
-                token_ids=token_ids)
+                token_ids=token_ids,
+                prompt_logprobs=prompt_logprobs_result)
             choices.append(choice)
         prompt_token_ids = None
         images_size = None
```
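The per-position logic of `_get_prompt_logprobs` can be exercised in isolation with plain float logprobs (vLLM actually returns `Logprob` objects, which the method unwraps via the `hasattr(x, 'logprob')` checks; `decode` below stands in for the tokenizer):

```python
def extract_position(token_id, pos_logprobs, decode=str):
    # Mirror of the per-position logic: record the actual token's logprob,
    # then every candidate sorted by descending probability, skipping -inf
    # entries (tokens masked out by the sampler).
    entry = {'token_id': token_id, 'token': decode(token_id),
             'logprob': None, 'top_logprobs': []}
    if pos_logprobs is not None:
        if token_id in pos_logprobs:
            entry['logprob'] = pos_logprobs[token_id]
        for tid, lp in sorted(pos_logprobs.items(), key=lambda x: -x[1]):
            if lp == float('-inf'):
                continue
            entry['top_logprobs'].append(
                {'token_id': tid, 'token': decode(tid), 'logprob': lp})
    return entry
```

Note that vLLM reports `None` for the first prompt position (no preceding context to condition on), which is why the real method tolerates `pos_logprobs is None` and emits an entry with an empty `top_logprobs` list.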

swift/megatron/pipelines/train/rlhf.py

Lines changed: 22 additions & 2 deletions

```diff
@@ -77,9 +77,29 @@ def _prepare_vllm_client(self):
         return vllm_client
 
     def _prepare_teacher_api_client(self):
-        """Prepare teacher API client for external teacher model service."""
+        """Prepare teacher API client for external teacher model service.
+
+        In Megatron with pure Data Parallel (TP=PP=CP=1), each rank processes different data
+        and needs its own API client. With model parallelism (TP/PP/CP > 1), one rank per
+        model parallel group calls the API and broadcasts results.
+        """
         from swift.rlhf_trainers.utils import create_teacher_api_client
-        return create_teacher_api_client(self.args, check_health=True, timeout=60, use_last_rank=True)
+
+        # Check if using pure data parallelism (no model parallelism)
+        tp = getattr(self.args, 'tensor_model_parallel_size', 1)
+        pp = getattr(self.args, 'pipeline_model_parallel_size', 1)
+        cp = getattr(self.args, 'context_parallel_size', 1)
+        is_pure_dp = (tp == 1 and pp == 1 and cp == 1)
+
+        # In pure DP mode, each rank has different data and needs its own client;
+        # in MP mode, only the last rank creates a client and broadcasts results.
+        return create_teacher_api_client(
+            self.args,
+            check_health=True,
+            timeout=60,
+            use_last_rank=True,
+            tokenizer=self.template.tokenizer,
+            all_ranks=is_pure_dp)
 
 
 def megatron_rlhf_main(args: Optional[Union[List[str], MegatronRLHFArguments]] = None):
```
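The parallelism check added above is self-contained and easy to verify on its own. A standalone sketch of the same predicate (using `SimpleNamespace` as a stand-in for the Megatron args object):

```python
from types import SimpleNamespace

def is_pure_data_parallel(args):
    # Same check the trainer performs: with no tensor, pipeline, or context
    # parallelism, every rank sees different data, so each rank needs its
    # own teacher API client (all_ranks=True); otherwise one rank per model
    # parallel group queries the server and broadcasts the result.
    tp = getattr(args, 'tensor_model_parallel_size', 1)
    pp = getattr(args, 'pipeline_model_parallel_size', 1)
    cp = getattr(args, 'context_parallel_size', 1)
    return tp == 1 and pp == 1 and cp == 1
```

Using `getattr` with a default of 1 means an args object that never set a parallelism flag is treated as pure data parallel, matching Megatron's defaults.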
