Skip to content

Commit fc1b673

Browse files
committed
clean
1 parent 44f0e4e commit fc1b673

File tree

8 files changed

+71
-143
lines changed

8 files changed

+71
-143
lines changed

swift/megatron/arguments/megatron_args.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,7 @@ class RLHFMegatronArgumentsMixin:
5252
'URL of the teacher model server (e.g., http://localhost:8000). '
5353
'When set, teacher logprobs are fetched via API instead of loading a local model.'
5454
})
55-
gkd_logits_topk: Optional[int] = field(
56-
default=None,
57-
metadata={
58-
'help':
59-
'Number of top-k logits for KL computation in GKD. '
60-
'None = full vocabulary, positive integer = top-k only. '
61-
'When using teacher_model_server, limited by server max_logprobs (vLLM default: 20).'
62-
})
55+
gkd_logits_topk: Optional[int] = None
6356
lmbda: float = 0.5 # On-policy probability: with prob lmbda, use student-generated responses
6457
seq_kd: bool = False # Sequential KD: use teacher-generated responses when not on-policy
6558
offload_teacher_model: bool = False # Offload teacher model to CPU to save GPU memory

swift/megatron/pipelines/train/rlhf.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@ def prepare_trainer(self):
3131
kwargs = {}
3232
if args.rlhf_type in ('grpo', 'gkd'):
3333
kwargs['vllm_client'] = self._prepare_vllm_client()
34-
if args.rlhf_type == 'gkd':
35-
kwargs['teacher_api_client'] = self._prepare_teacher_api_client()
3634
return trainer_cls(args, self.template, **kwargs)
3735

3836
def _prepare_template(self) -> None:
@@ -70,19 +68,6 @@ def _prepare_vllm_client(self):
7068
logger.info('Connected to vLLM server')
7169
return vllm_client
7270

73-
def _prepare_teacher_api_client(self):
74-
"""Prepare teacher API client for external teacher model service.
75-
76-
In Megatron with pure Data Parallel (TP=PP=CP=1), each rank processes different data
77-
and needs its own API client. With model parallelism (TP/PP/CP > 1), one rank per
78-
model parallel group calls the API and broadcasts results.
79-
"""
80-
from swift.rlhf_trainers.utils import create_teacher_api_client
81-
from swift.utils import is_last_rank
82-
if is_last_rank():
83-
return create_teacher_api_client(self.args, check_health=True, timeout=60)
84-
return None
85-
8671

8772
def megatron_rlhf_main(args: Optional[Union[List[str], MegatronRLHFArguments]] = None):
8873
return MegatronRLHF(args).main()

swift/megatron/trainers/gkd_trainer.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ class MegatronGKDTrainer(MegatronRolloutMixin, MegatronRLHFTrainer):
3434

3535
def __init__(self, args: MegatronArguments, template, **kwargs):
3636
self.vllm_client = kwargs.pop('vllm_client', None)
37-
self.teacher_api_client = kwargs.pop('teacher_api_client', None)
3837

3938
# GKD-specific parameters
4039
self.beta = args.beta # JSD interpolation coefficient
@@ -50,7 +49,8 @@ def __init__(self, args: MegatronArguments, template, **kwargs):
5049
self.gkd_logits_topk = getattr(args, 'gkd_logits_topk', None)
5150
# Check use_teacher_api based on args, not client existence
5251
# (API client is only created on last rank, but all ranks need to know the mode)
53-
self.use_teacher_api = getattr(args, 'teacher_model_server', None) is not None
52+
self.teacher_model_server = getattr(args, 'teacher_model_server', None)
53+
self.use_teacher_api = self.teacher_model_server is not None
5454

5555
# Validate teacher configuration
5656
if not self.use_teacher_api:
@@ -295,11 +295,12 @@ def _compute_teacher_logits_local(self, encoded_batches: List[Dict], vp_stage: O
295295

296296
def _compute_teacher_logits_from_api(self, encoded_batches: List[Dict]) -> None:
297297
"""Fetch teacher logprobs from external API service."""
298+
from swift.rlhf_trainers.teacher_api_client import fetch_teacher_logprobs
298299
topk = self.gkd_logits_topk
299300
for encoded_batch in encoded_batches:
300301
input_ids = encoded_batch['input_ids']
301-
teacher_logprobs, teacher_indices = self.teacher_api_client.get_logprobs_sync(
302-
input_ids=input_ids.tolist(), top_logprobs=topk)
302+
teacher_logprobs, teacher_indices = fetch_teacher_logprobs(
303+
self.teacher_model_server, input_ids.tolist(), topk=topk)
303304
encoded_batch['teacher_api_logprobs'] = teacher_logprobs.to(input_ids.device)
304305
encoded_batch['teacher_api_indices'] = teacher_indices.to(input_ids.device)
305306
encoded_batch['teacher_logits'] = None

swift/pipelines/train/rlhf.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -233,18 +233,9 @@ def _get_trainer_kwargs(self):
233233
if self.args.rlhf_type == 'gkd':
234234
if self.args.teacher_deepspeed:
235235
trainer_kwargs['teacher_deepspeed_config'] = self.args.teacher_deepspeed
236-
# Pass GKD-specific args to trainer
237236
trainer_kwargs['gkd_logits_topk'] = self.args.gkd_logits_topk
238-
# Initialize teacher API client if using external teacher service
239237
if self.args.teacher_model_server:
240-
# Pass teacher_model_server so trainer knows to use API mode on all ranks
241238
trainer_kwargs['teacher_model_server'] = self.args.teacher_model_server
242-
from swift.rlhf_trainers.utils import create_teacher_api_client
243-
244-
# In DP mode (DeepSpeed/FSDP), each rank has different data and needs its own client
245-
# Use all_ranks=True so every rank can independently fetch teacher logprobs
246-
trainer_kwargs['teacher_api_client'] = create_teacher_api_client(
247-
self.args, check_health=False, timeout=60)
248239
return trainer_kwargs
249240

250241

swift/rlhf_trainers/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from .ppo_trainer import PPOTrainer
1616
from .reward_trainer import RewardTrainer
1717
from .rlhf_mixin import RLHFTrainerMixin
18-
from .teacher_api_client import TeacherAPIClient
18+
from .teacher_api_client import fetch_teacher_logprobs
1919
from .utils import _ForwardRedirection, patch_lora_merge, patch_lora_unmerge, round_robin
2020
from .vllm_client import VLLMClient
2121
else:
@@ -32,7 +32,7 @@
3232
'args_mixin': ['VllmArguments', 'GRPOArgumentsMixin'],
3333
'utils': ['patch_lora_merge', 'patch_lora_unmerge', 'round_robin', '_ForwardRedirection'],
3434
'vllm_client': ['VLLMClient'],
35-
'teacher_api_client': ['TeacherAPIClient'],
35+
'teacher_api_client': ['fetch_teacher_logprobs'],
3636
'arguments':
3737
['DPOConfig', 'CPOConfig', 'KTOConfig', 'ORPOConfig', 'PPOConfig', 'RewardConfig', 'GRPOConfig', 'GKDConfig']
3838
}

swift/rlhf_trainers/gkd_trainer.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = Non
5757
teacher_model = kwargs.pop('teacher_model', None)
5858
teacher_deepspeed_config = kwargs.pop('teacher_deepspeed_config', None)
5959
self.vllm_client = kwargs.pop('vllm_client', None)
60-
self.teacher_api_client = kwargs.pop('teacher_api_client', None)
6160
self.gkd_logits_topk = kwargs.pop('gkd_logits_topk', None)
6261
teacher_model_server = kwargs.pop('teacher_model_server', None)
6362
super().__init__(model, None, *_args, **kwargs)
@@ -69,6 +68,7 @@ def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = Non
6968
self._metrics = {'train': defaultdict(list), 'eval': defaultdict(list)}
7069
self._total_train_tokens = 0
7170

71+
self.teacher_model_server = teacher_model_server
7272
self.use_teacher_api = teacher_model_server is not None
7373

7474
# Initialize logging components
@@ -469,12 +469,10 @@ def _fetch_teacher_logprobs_from_api(self, encoded_inputs: Dict[str, torch.Tenso
469469
Returns:
470470
Tuple of (teacher_logprobs, teacher_indices) tensors with shapes [batch, seq_len, topk]
471471
"""
472+
from .teacher_api_client import fetch_teacher_logprobs
472473
input_ids = encoded_inputs['input_ids']
473-
topk = self.gkd_logits_topk
474-
teacher_logprobs, teacher_indices = self.teacher_api_client.get_logprobs_sync(
475-
input_ids=input_ids.tolist(),
476-
top_logprobs=topk,
477-
)
474+
teacher_logprobs, teacher_indices = fetch_teacher_logprobs(
475+
self.teacher_model_server, input_ids.tolist(), topk=self.gkd_logits_topk)
478476
return teacher_logprobs.to(input_ids.device), teacher_indices.to(input_ids.device)
479477

480478
def prediction_step(self, model, inputs, *args, **kwargs):
Lines changed: 59 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Copyright (c) ModelScope Contributors. All rights reserved.
2-
"""Client for fetching teacher model logprobs from OpenAI-compatible endpoints."""
2+
"""Fetch teacher model logprobs from OpenAI-compatible endpoints."""
33
import logging
44
import requests
55
import torch
@@ -8,86 +8,72 @@
88

99
logger = logging.getLogger(__name__)
1010

11+
_model_name_cache: dict = {}
1112

12-
class TeacherAPIClient:
13-
"""Fetch teacher top-k logprobs from an OpenAI-compatible completions API.
13+
14+
def _get_model_name(base_url: str) -> str:
15+
if base_url not in _model_name_cache:
16+
try:
17+
resp = requests.get(f'{base_url}/v1/models', timeout=10)
18+
if resp.ok and resp.json().get('data'):
19+
_model_name_cache[base_url] = resp.json()['data'][0]['id']
20+
except Exception as e:
21+
logger.warning(f'Failed to detect model name: {e}')
22+
if base_url not in _model_name_cache:
23+
_model_name_cache[base_url] = 'default'
24+
return _model_name_cache[base_url]
25+
26+
27+
def fetch_teacher_logprobs(
28+
base_url: str,
29+
input_ids: List[List[int]],
30+
topk: int = 20,
31+
timeout: float = 300.0,
32+
) -> Tuple[torch.Tensor, torch.Tensor]:
33+
"""Fetch top-k logprobs from an OpenAI-compatible completions API.
1434
1535
Args:
1636
base_url: Server URL (e.g., 'http://localhost:8000').
17-
top_logprobs: Number of top log probabilities per token.
37+
input_ids: List of token ID sequences.
38+
topk: Number of top log probabilities per token.
1839
timeout: Request timeout in seconds.
19-
"""
2040
21-
def __init__(self, base_url: str, top_logprobs: int = 20, timeout: float = 300.0):
22-
self.base_url = base_url.rstrip('/')
23-
self.top_logprobs = top_logprobs
24-
self.timeout = timeout
25-
self._model_name = None
41+
Returns:
42+
(logprobs, indices) tensors of shape [batch, max_seq_len, topk].
43+
"""
44+
base_url = base_url.rstrip('/')
45+
batch_size = len(input_ids)
46+
max_seq_len = max(len(ids) for ids in input_ids)
47+
url = f'{base_url}/v1/completions'
48+
model = _get_model_name(base_url)
2649

27-
@property
28-
def model_name(self) -> str:
29-
if self._model_name is None:
30-
try:
31-
resp = requests.get(f'{self.base_url}/v1/models', timeout=10)
32-
if resp.ok and resp.json().get('data'):
33-
self._model_name = resp.json()['data'][0]['id']
34-
except Exception as e:
35-
logger.warning(f'Failed to detect model name: {e}')
36-
if self._model_name is None:
37-
self._model_name = 'default'
38-
return self._model_name
50+
logprobs_out = torch.full((batch_size, max_seq_len, topk), float('-inf'), dtype=torch.float32)
51+
indices_out = torch.zeros((batch_size, max_seq_len, topk), dtype=torch.long)
3952

40-
def check_health(self, timeout: float = 5.0) -> bool:
41-
"""Check if the teacher model server is reachable."""
53+
def _fetch_one(batch_idx: int):
54+
payload = {
55+
'model': model,
56+
'prompt': input_ids[batch_idx],
57+
'max_tokens': 0,
58+
'temperature': 0,
59+
'logprobs': topk,
60+
'echo': True,
61+
}
4262
try:
43-
resp = requests.get(f'{self.base_url}/v1/models', timeout=timeout)
44-
return resp.ok
45-
except requests.RequestException:
46-
return False
47-
48-
def get_logprobs_sync(
49-
self,
50-
input_ids: List[List[int]],
51-
top_logprobs: Optional[int] = None,
52-
) -> Tuple[torch.Tensor, torch.Tensor]:
53-
"""Fetch top-k logprobs for a batch of token sequences.
54-
55-
Returns:
56-
(logprobs, indices) tensors of shape [batch, max_seq_len, topk].
57-
"""
58-
topk = top_logprobs or self.top_logprobs
59-
batch_size = len(input_ids)
60-
max_seq_len = max(len(ids) for ids in input_ids)
61-
url = f'{self.base_url}/v1/completions'
62-
model = self.model_name
63-
64-
logprobs_out = torch.full((batch_size, max_seq_len, topk), float('-inf'), dtype=torch.float32)
65-
indices_out = torch.zeros((batch_size, max_seq_len, topk), dtype=torch.long)
66-
67-
def _fetch_one(batch_idx: int):
68-
payload = {
69-
'model': model,
70-
'prompt': input_ids[batch_idx],
71-
'max_tokens': 0,
72-
'temperature': 0,
73-
'logprobs': topk,
74-
'echo': True,
75-
}
76-
try:
77-
resp = requests.post(url, json=payload, timeout=self.timeout)
78-
resp.raise_for_status()
79-
top_logprobs_list = resp.json()['choices'][0].get('logprobs', {}).get('top_logprobs', [])
80-
for pos, pos_lp in enumerate(top_logprobs_list):
81-
if pos_lp is None:
82-
continue
83-
sorted_items = sorted(pos_lp.items(), key=lambda x: -x[1])[:topk]
84-
for k, (tid_str, lp) in enumerate(sorted_items):
85-
indices_out[batch_idx, pos, k] = int(tid_str)
86-
logprobs_out[batch_idx, pos, k] = lp
87-
except Exception as e:
88-
logger.error(f'Failed to get logprobs for sequence {batch_idx}: {e}')
63+
resp = requests.post(url, json=payload, timeout=timeout)
64+
resp.raise_for_status()
65+
top_logprobs_list = resp.json()['choices'][0].get('logprobs', {}).get('top_logprobs', [])
66+
for pos, pos_lp in enumerate(top_logprobs_list):
67+
if pos_lp is None:
68+
continue
69+
sorted_items = sorted(pos_lp.items(), key=lambda x: -x[1])[:topk]
70+
for k, (tid_str, lp) in enumerate(sorted_items):
71+
indices_out[batch_idx, pos, k] = int(tid_str)
72+
logprobs_out[batch_idx, pos, k] = lp
73+
except Exception as e:
74+
logger.error(f'Failed to get logprobs for sequence {batch_idx}: {e}')
8975

90-
with ThreadPoolExecutor(max_workers=min(batch_size, 8)) as pool:
91-
list(pool.map(_fetch_one, range(batch_size)))
76+
with ThreadPoolExecutor(max_workers=min(batch_size, 8)) as pool:
77+
list(pool.map(_fetch_one, range(batch_size)))
9278

93-
return logprobs_out, indices_out
79+
return logprobs_out, indices_out

swift/rlhf_trainers/utils.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1472,32 +1472,6 @@ def check_vllm_version_ge(min_version: str) -> bool:
14721472
return version.parse(vllm_version) >= version.parse(min_version)
14731473

14741474

1475-
def create_teacher_api_client(args, check_health: bool = True, timeout: int = 60):
1476-
"""Create TeacherAPIClient for external teacher model service.
1477-
1478-
Returns:
1479-
TeacherAPIClient instance or None if teacher_model_server is not set
1480-
"""
1481-
teacher_model_server = getattr(args, 'teacher_model_server', None)
1482-
if not teacher_model_server:
1483-
return None
1484-
1485-
from swift.rlhf_trainers import TeacherAPIClient
1486-
1487-
logger = get_logger()
1488-
gkd_logits_topk = getattr(args, 'gkd_logits_topk', None) or 20
1489-
1490-
logger.info(f'Initializing teacher API client for {teacher_model_server}')
1491-
teacher_api_client = TeacherAPIClient(
1492-
base_url=teacher_model_server,
1493-
top_logprobs=gkd_logits_topk,
1494-
)
1495-
if check_health:
1496-
teacher_api_client.check_health(timeout=timeout)
1497-
logger.info(f'Teacher API client initialized with top_logprobs={gkd_logits_topk}')
1498-
return teacher_api_client
1499-
1500-
15011475
# ============================================================================
15021476
# Padding-free utilities
15031477
# ============================================================================

0 commit comments

Comments
 (0)