@@ -301,32 +301,43 @@ def _compute_teacher_logits(self, encoded_batches: List[Dict], vp_stage: Optional[int] = None) -> None:
         self._compute_teacher_logits_local(encoded_batches, vp_stage)
 
     def _compute_teacher_logits_local(self, encoded_batches: List[Dict], vp_stage: Optional[int] = None) -> None:
-        teacher_model = self.teacher_models[vp_stage or 0]
         topk = self.gkd_logits_topk
 
-        for encoded_batch in encoded_batches:
-            # Use OPSD teacher batch if available, otherwise use student batch
-            opsd_batch = encoded_batch.get('opsd_teacher_batch')
-            source = opsd_batch if opsd_batch is not None else encoded_batch
-            teacher_batch = {
-                k: v.clone() if isinstance(v, torch.Tensor) else v
-                for k, v in source.items() if k not in ('data_source', 'opsd_teacher_batch')
-            }
-            teacher_data = self._prepare_batch(teacher_batch)
-            teacher_data.pop('loss_scale', None)
-            teacher_data.pop('labels', None)
-            with self.load_teacher_model_context(), torch.no_grad():
+        if self._is_self_distillation:
+            teacher_model = self.unwrapped_models[0]
+            adapter_contexts = []
+            if self._teacher_use_disable_adapter:
+                adapter_contexts = [m.disable_adapter() for m in self.peft_models]
+            outer_context = ContextManagers(adapter_contexts)
+        else:
+            teacher_model = self.teacher_models[vp_stage or 0]
+            outer_context = self.load_teacher_model_context()
+
+        with torch.no_grad(), outer_context:
+            for encoded_batch in encoded_batches:
+                opsd_batch = encoded_batch.get('opsd_teacher_batch')
+                source = opsd_batch if opsd_batch is not None else encoded_batch
+                teacher_batch = {
+                    k: v.clone() if isinstance(v, torch.Tensor) else v
+                    for k, v in source.items() if k not in ('data_source', 'opsd_teacher_batch')
+                }
+                teacher_data = self._prepare_batch(teacher_batch)
+                teacher_data.pop('loss_scale', None)
+                opsd_teacher_labels = teacher_data.pop('labels', None)
                 teacher_logits = forward_step_helper(self.args, teacher_model, teacher_data)
                 if teacher_logits is not None:
                     teacher_logits = teacher_logits.detach()
 
-            if topk is not None and teacher_logits is not None:
-                topk_logits, topk_indices = self._vocab_parallel_topk(teacher_logits, k=topk)
-                encoded_batch['teacher_api_logprobs'] = topk_logits
-                encoded_batch['teacher_api_indices'] = topk_indices
-                encoded_batch['teacher_logits'] = None
-            else:
-                encoded_batch['teacher_logits'] = teacher_logits
+                if topk is not None and teacher_logits is not None:
+                    topk_logits, topk_indices = self._vocab_parallel_topk(teacher_logits, k=topk)
+                    encoded_batch['teacher_api_logprobs'] = topk_logits
+                    encoded_batch['teacher_api_indices'] = topk_indices
+                    encoded_batch['teacher_logits'] = None
+                else:
+                    encoded_batch['teacher_logits'] = teacher_logits
+
+                if opsd_batch is not None:
+                    encoded_batch['opsd_teacher_labels'] = opsd_teacher_labels
 
     def _compute_teacher_logits_from_api(self, encoded_batches: List[Dict]) -> None:
         from swift.rlhf_trainers.gkd_trainer import fetch_teacher_logprobs
@@ -359,6 +370,10 @@ def _compute_teacher_logits_from_api(self, encoded_batches: List[Dict]) -> None:
             encoded_batch['teacher_api_indices'] = teacher_indices
             encoded_batch['teacher_logits'] = None
 
+            opsd_batch = encoded_batch.get('opsd_teacher_batch')
+            if opsd_batch is not None:
+                encoded_batch['opsd_teacher_labels'] = opsd_batch.get('labels')
+
     def _replace_data_iterator(self, data_iterator):
         num_microbatches = self.args.num_microbatches
 
@@ -401,8 +416,7 @@ def _replace_data_iterator(self, data_iterator):
                 encoded_batch['opsd_teacher_batch'] = self._encode_batch(teacher_global_batch[start_idx:end_idx])
             encoded_batches.append(encoded_batch)
 
-        if not self._is_self_distillation:
-            self._compute_teacher_logits(encoded_batches)
+        self._compute_teacher_logits(encoded_batches)
 
         # Increment step counter (used for deterministic random and weight sync)
         self._step += 1
@@ -721,7 +735,8 @@ def forward_step(self, data_iterator, model):
         teacher_logits = data.pop('teacher_logits', None)
         teacher_api_logprobs = data.pop('teacher_api_logprobs', None)
         teacher_api_indices = data.pop('teacher_api_indices', None)
-        opsd_teacher_batch = data.pop('opsd_teacher_batch', None)
+        opsd_teacher_labels = data.pop('opsd_teacher_labels', None)
+        data.pop('opsd_teacher_batch', None)
         data = self._prepare_batch(data, vp_stage)
 
         data.pop('loss_scale', None)
@@ -731,25 +746,6 @@ def forward_step(self, data_iterator, model):
             unwrapped_model.set_input_tensor(input_tensor)
         student_output = model(**data)
 
-        if self._is_self_distillation:
-            if opsd_teacher_batch is not None:
-                t_data = self._prepare_batch(opsd_teacher_batch, vp_stage)
-            else:
-                t_data = {k: v for k, v in data.items()}
-            t_data.pop('loss_scale', None)
-            opsd_teacher_labels = t_data.pop('labels', None)
-
-            adapter_contexts = []
-            if self._teacher_use_disable_adapter:
-                adapter_contexts = [m.disable_adapter() for m in self.peft_models]
-
-            with torch.no_grad(), ContextManagers(adapter_contexts):
-                teacher_logits = forward_step_helper(self.args, unwrapped_model, t_data)
-                if teacher_logits is not None:
-                    teacher_logits = teacher_logits.detach()
-        else:
-            opsd_teacher_labels = opsd_teacher_batch.pop('labels', None) if opsd_teacher_batch is not None else None
-
         return student_output, partial(
             self.loss_func,
             labels=labels,
0 commit comments