Commit 88a6eaf

fix
1 parent fe928a9 commit 88a6eaf

1 file changed: +38 −42 lines changed


swift/megatron/trainers/gkd_trainer.py

Lines changed: 38 additions & 42 deletions
@@ -301,32 +301,43 @@ def _compute_teacher_logits(self, encoded_batches: List[Dict], vp_stage: Optiona
         self._compute_teacher_logits_local(encoded_batches, vp_stage)
 
     def _compute_teacher_logits_local(self, encoded_batches: List[Dict], vp_stage: Optional[int] = None) -> None:
-        teacher_model = self.teacher_models[vp_stage or 0]
         topk = self.gkd_logits_topk
 
-        for encoded_batch in encoded_batches:
-            # Use OPSD teacher batch if available, otherwise use student batch
-            opsd_batch = encoded_batch.get('opsd_teacher_batch')
-            source = opsd_batch if opsd_batch is not None else encoded_batch
-            teacher_batch = {
-                k: v.clone() if isinstance(v, torch.Tensor) else v
-                for k, v in source.items() if k not in ('data_source', 'opsd_teacher_batch')
-            }
-            teacher_data = self._prepare_batch(teacher_batch)
-            teacher_data.pop('loss_scale', None)
-            teacher_data.pop('labels', None)
-            with self.load_teacher_model_context(), torch.no_grad():
+        if self._is_self_distillation:
+            teacher_model = self.unwrapped_models[0]
+            adapter_contexts = []
+            if self._teacher_use_disable_adapter:
+                adapter_contexts = [m.disable_adapter() for m in self.peft_models]
+            outer_context = ContextManagers(adapter_contexts)
+        else:
+            teacher_model = self.teacher_models[vp_stage or 0]
+            outer_context = self.load_teacher_model_context()
+
+        with torch.no_grad(), outer_context:
+            for encoded_batch in encoded_batches:
+                opsd_batch = encoded_batch.get('opsd_teacher_batch')
+                source = opsd_batch if opsd_batch is not None else encoded_batch
+                teacher_batch = {
+                    k: v.clone() if isinstance(v, torch.Tensor) else v
+                    for k, v in source.items() if k not in ('data_source', 'opsd_teacher_batch')
+                }
+                teacher_data = self._prepare_batch(teacher_batch)
+                teacher_data.pop('loss_scale', None)
+                opsd_teacher_labels = teacher_data.pop('labels', None)
                 teacher_logits = forward_step_helper(self.args, teacher_model, teacher_data)
                 if teacher_logits is not None:
                     teacher_logits = teacher_logits.detach()
 
-            if topk is not None and teacher_logits is not None:
-                topk_logits, topk_indices = self._vocab_parallel_topk(teacher_logits, k=topk)
-                encoded_batch['teacher_api_logprobs'] = topk_logits
-                encoded_batch['teacher_api_indices'] = topk_indices
-                encoded_batch['teacher_logits'] = None
-            else:
-                encoded_batch['teacher_logits'] = teacher_logits
+                if topk is not None and teacher_logits is not None:
+                    topk_logits, topk_indices = self._vocab_parallel_topk(teacher_logits, k=topk)
+                    encoded_batch['teacher_api_logprobs'] = topk_logits
+                    encoded_batch['teacher_api_indices'] = topk_indices
+                    encoded_batch['teacher_logits'] = None
+                else:
+                    encoded_batch['teacher_logits'] = teacher_logits
+
+                if opsd_batch is not None:
+                    encoded_batch['opsd_teacher_labels'] = opsd_teacher_labels
 
     def _compute_teacher_logits_from_api(self, encoded_batches: List[Dict]) -> None:
         from swift.rlhf_trainers.gkd_trainer import fetch_teacher_logprobs
@@ -359,6 +370,10 @@ def _compute_teacher_logits_from_api(self, encoded_batches: List[Dict]) -> None:
             encoded_batch['teacher_api_indices'] = teacher_indices
             encoded_batch['teacher_logits'] = None
 
+            opsd_batch = encoded_batch.get('opsd_teacher_batch')
+            if opsd_batch is not None:
+                encoded_batch['opsd_teacher_labels'] = opsd_batch.get('labels')
+
     def _replace_data_iterator(self, data_iterator):
         num_microbatches = self.args.num_microbatches
 
@@ -401,8 +416,7 @@ def _replace_data_iterator(self, data_iterator):
                 encoded_batch['opsd_teacher_batch'] = self._encode_batch(teacher_global_batch[start_idx:end_idx])
             encoded_batches.append(encoded_batch)
 
-        if not self._is_self_distillation:
-            self._compute_teacher_logits(encoded_batches)
+        self._compute_teacher_logits(encoded_batches)
 
         # Increment step counter (used for deterministic random and weight sync)
         self._step += 1
@@ -721,7 +735,8 @@ def forward_step(self, data_iterator, model):
         teacher_logits = data.pop('teacher_logits', None)
         teacher_api_logprobs = data.pop('teacher_api_logprobs', None)
         teacher_api_indices = data.pop('teacher_api_indices', None)
-        opsd_teacher_batch = data.pop('opsd_teacher_batch', None)
+        opsd_teacher_labels = data.pop('opsd_teacher_labels', None)
+        data.pop('opsd_teacher_batch', None)
         data = self._prepare_batch(data, vp_stage)
 
         data.pop('loss_scale', None)
@@ -731,25 +746,6 @@ def forward_step(self, data_iterator, model):
             unwrapped_model.set_input_tensor(input_tensor)
         student_output = model(**data)
 
-        if self._is_self_distillation:
-            if opsd_teacher_batch is not None:
-                t_data = self._prepare_batch(opsd_teacher_batch, vp_stage)
-            else:
-                t_data = {k: v for k, v in data.items()}
-            t_data.pop('loss_scale', None)
-            opsd_teacher_labels = t_data.pop('labels', None)
-
-            adapter_contexts = []
-            if self._teacher_use_disable_adapter:
-                adapter_contexts = [m.disable_adapter() for m in self.peft_models]
-
-            with torch.no_grad(), ContextManagers(adapter_contexts):
-                teacher_logits = forward_step_helper(self.args, unwrapped_model, t_data)
-                if teacher_logits is not None:
-                    teacher_logits = teacher_logits.detach()
-        else:
-            opsd_teacher_labels = opsd_teacher_batch.pop('labels', None) if opsd_teacher_batch is not None else None
-
         return student_output, partial(
             self.loss_func,
             labels=labels,
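
What the commit does, read from the hunks: the self-distillation (OPSD) teacher forward pass moves out of forward_step and into _compute_teacher_logits_local, so the separate-teacher and self-distillation paths now share one precompute loop. The only remaining branch picks which model acts as the teacher and which context wraps the loop, and _replace_data_iterator drops its `if not self._is_self_distillation` guard accordingly.

A minimal sketch of that context-selection pattern, using contextlib.ExitStack in place of the ContextManagers helper the file imports; the function and argument names below are illustrative, not from the repo:

    import contextlib

    import torch

    def teacher_context(is_self_distillation, peft_models,
                        use_disable_adapter, load_teacher_model_context):
        # Build the single context that wraps the whole teacher loop. For
        # self-distillation the "teacher" is the student itself with its LoRA
        # adapters temporarily disabled; otherwise a separate teacher model is
        # made active for the duration of the loop.
        stack = contextlib.ExitStack()
        stack.enter_context(torch.no_grad())  # teacher logits never need grads
        if is_self_distillation:
            if use_disable_adapter:
                for m in peft_models:
                    # peft's disable_adapter() is a context manager that
                    # re-enables the adapters on exit
                    stack.enter_context(m.disable_adapter())
        else:
            stack.enter_context(load_teacher_model_context())
        return stack

Entering this once around the whole loop (`with teacher_context(...):`) rather than once per microbatch is the point of the restructuring: adapter toggling and teacher loading now happen a single time per precompute pass, and forward_step shrinks to popping the precomputed `opsd_teacher_labels`.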
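
Separately, when gkd_logits_topk is set the loop stores only a top-k slice of the teacher distribution (teacher_api_logprobs / teacher_api_indices) instead of full-vocab logits. The real _vocab_parallel_topk has to merge candidates across tensor-parallel vocab shards; a single-device stand-in, offered here only as an illustration, reduces to torch.topk over the vocab dimension:

    import torch

    def topk_compress(teacher_logits: torch.Tensor, k: int):
        # teacher_logits: [..., vocab]. Keeping only the k largest entries per
        # position avoids holding the full-vocab teacher tensor between the
        # precompute pass and the loss computation.
        topk_values, topk_indices = torch.topk(teacher_logits, k=k, dim=-1)
        return topk_values, topk_indices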
