diff --git a/src/liger_kernel/chunked_loss/fused_linear_distillation.py b/src/liger_kernel/chunked_loss/fused_linear_distillation.py
index 5b0ef192a..26a942512 100644
--- a/src/liger_kernel/chunked_loss/fused_linear_distillation.py
+++ b/src/liger_kernel/chunked_loss/fused_linear_distillation.py
@@ -115,6 +115,21 @@ def _compute_loss(
         student_logits_chunk /= temperature
         teacher_logits_chunk /= temperature
 
+        # If the teacher and student vocab sizes differ, pad the student logits to match the teacher's.
+        # This only applies when the two models share exactly the same vocab and tokenizer, and the
+        # teacher's logits are merely padded out for training efficiency, e.g.
+        # https://huggingface.co/Qwen/Qwen1.5-72B-Chat/discussions/1#662883f568adf59b07b176d2
+        teacher_vocab_size = teacher_weight.shape[0]
+        student_vocab_size = student_weight.shape[0]
+        if teacher_vocab_size > student_vocab_size:
+            pad_size = teacher_vocab_size - student_vocab_size
+            pad_tensor = torch.zeros(
+                (*student_logits_chunk.shape[:-1], pad_size),
+                dtype=student_logits_chunk.dtype,
+                device=student_logits_chunk.device,
+            )
+            student_logits_chunk = torch.cat([student_logits_chunk, pad_tensor], dim=-1)
+
         hard_loss /= full_target.shape[0]
 
         soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, **loss_kwargs)
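
A minimal standalone sketch of the padding step above, for trying the logic outside the fused kernel. All sizes here are hypothetical (a 32000-token student head against a 32768-token padded teacher head) and not taken from any actual model:

```python
import torch

# Hypothetical projection heads: same tokenizer and vocab, but the teacher's
# head is padded to a larger vocab size for training efficiency.
torch.manual_seed(0)
hidden = 16
student_weight = torch.randn(32000, hidden)  # (student_vocab_size, hidden)
teacher_weight = torch.randn(32768, hidden)  # (padded teacher_vocab_size, hidden)
x = torch.randn(4, hidden)                   # one chunk of hidden states

student_logits_chunk = x @ student_weight.T  # (4, 32000)
teacher_logits_chunk = x @ teacher_weight.T  # (4, 32768)

# Same logic as the diff: zero-pad the student logits along the vocab dim
# so both tensors can be fed to the same distillation loss.
pad_size = teacher_weight.shape[0] - student_weight.shape[0]
if pad_size > 0:
    pad_tensor = torch.zeros(
        (*student_logits_chunk.shape[:-1], pad_size),
        dtype=student_logits_chunk.dtype,
        device=student_logits_chunk.device,
    )
    student_logits_chunk = torch.cat([student_logits_chunk, pad_tensor], dim=-1)

assert student_logits_chunk.shape == teacher_logits_chunk.shape  # both (4, 32768)
```

As in the diff, the padding is applied after the temperature scaling and immediately before `distillation_loss_fn`, so the zero pad values themselves are never divided by `temperature`.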