Skip to content

Commit 5c8f323

Browse files
committed
fix
1 parent 88a6eaf commit 5c8f323

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

swift/megatron/trainers/gkd_trainer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,9 @@ def _compute_teacher_logits_from_api(self, encoded_batches: List[Dict]) -> None:
348348
rollout_src = torch.distributed.get_global_rank(rollout_group, 0)
349349

350350
for encoded_batch in encoded_batches:
351-
input_ids = encoded_batch['input_ids']
351+
opsd_batch = encoded_batch.get('opsd_teacher_batch')
352+
source = opsd_batch if opsd_batch is not None else encoded_batch
353+
input_ids = source['input_ids']
352354
device = input_ids.device
353355

354356
if rollout_rank == 0:
@@ -370,7 +372,6 @@ def _compute_teacher_logits_from_api(self, encoded_batches: List[Dict]) -> None:
370372
encoded_batch['teacher_api_indices'] = teacher_indices
371373
encoded_batch['teacher_logits'] = None
372374

373-
opsd_batch = encoded_batch.get('opsd_teacher_batch')
374375
if opsd_batch is not None:
375376
encoded_batch['opsd_teacher_labels'] = opsd_batch.get('labels')
376377

swift/rlhf_trainers/gkd_trainer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -590,7 +590,9 @@ def _fetch_teacher_logprobs_from_api(self, encoded_inputs: Dict[str, torch.Tenso
590590
Returns:
591591
Tuple of (teacher_logprobs, teacher_indices) tensors with shapes [batch, seq_len, topk]
592592
"""
593-
input_ids = encoded_inputs['input_ids']
593+
opsd_teacher_inputs = encoded_inputs.get('_opsd_teacher_inputs')
594+
source = opsd_teacher_inputs if opsd_teacher_inputs is not None else encoded_inputs
595+
input_ids = source['input_ids']
594596
teacher_logprobs, teacher_indices = fetch_teacher_logprobs(
595597
self.teacher_model_server, input_ids.tolist(), topk=self.gkd_logits_topk)
596598
return teacher_logprobs.to(input_ids.device), teacher_indices.to(input_ids.device)

0 commit comments

Comments
 (0)