src/liger_kernel/chunked_loss/fused_linear_distillation.py (15 additions, 0 deletions)
@@ -115,6 +115,21 @@ def _compute_loss(
student_logits_chunk /= temperature
teacher_logits_chunk /= temperature

# If the teacher and student vocab sizes differ, pad the student logits to match the teacher's.
# This only applies when the two models share exactly the same vocabulary and tokenizer, and
# the teacher's logits are merely padded out for training efficiency, as in e.g.
# https://huggingface.co/Qwen/Qwen1.5-72B-Chat/discussions/1#662883f568adf59b07b176d2
teacher_vocab_size = teacher_weight.shape[0]
student_vocab_size = student_weight.shape[0]
if teacher_vocab_size > student_vocab_size:
    pad_size = teacher_vocab_size - student_vocab_size
    pad_tensor = torch.zeros(
        (*student_logits_chunk.shape[:-1], pad_size),
        dtype=student_logits_chunk.dtype,
        device=student_logits_chunk.device,
    )
    student_logits_chunk = torch.cat([student_logits_chunk, pad_tensor], dim=-1)

hard_loss /= full_target.shape[0]

soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, **loss_kwargs)
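
For reference, here is a minimal standalone sketch of the padding step above. The tensor names and vocab sizes are illustrative, not taken from the PR: the student's logits are zero-padded along the last (vocab) dimension so their shape matches the teacher's padded vocab before the soft loss is computed.

```python
import torch

# Illustrative sizes: the student uses the true vocab, while the teacher's
# vocab has been padded out (e.g. to a multiple of 64) for training efficiency.
batch, seq_len = 2, 4
student_vocab, teacher_vocab = 100, 128

student_logits = torch.randn(batch, seq_len, student_vocab)
teacher_logits = torch.randn(batch, seq_len, teacher_vocab)

# Same operation as in the diff: append zeros on the vocab dimension.
pad_size = teacher_vocab - student_vocab
pad_tensor = torch.zeros(
    (*student_logits.shape[:-1], pad_size),
    dtype=student_logits.dtype,
    device=student_logits.device,
)
student_logits = torch.cat([student_logits, pad_tensor], dim=-1)

assert student_logits.shape == teacher_logits.shape  # both (2, 4, 128)
```

Zero-padding only makes sense here because the extra teacher positions correspond to unused filler tokens, so aligning shapes this way does not change which real tokens the distillation loss compares.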