Skip to content

Commit 794d61f

Browse files
authored
align teacher and student logit shape
1 parent 3a5845b commit 794d61f

1 file changed

Lines changed: 15 additions & 0 deletions

File tree

src/liger_kernel/chunked_loss/fused_linear_distillation.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,21 @@ def _compute_loss(
112112
compute_ce_loss=compute_ce_loss,
113113
)
114114

115+
# If the teacher and student token size is different, pad student logits to match the teacher's.
116+
# This only applies to cases where they share exactly the same vocab and tokenizer just
117+
# that teacher logit is padded for some training efficiency such as
118+
# https://huggingface.co/Qwen/Qwen1.5-72B-Chat/discussions/1#662883f568adf59b07b176d2
119+
teacher_vocab_size = teacher_weight.shape[0]
120+
student_vocab_size = student_weight.shape[0]
121+
if teacher_vocab_size > student_vocab_size:
122+
pad_size = teacher_vocab_size - student_vocab_size
123+
pad_tensor = torch.zeros(
124+
(*student_logits_chunk.shape[:-1], pad_size),
125+
dtype=student_logits_chunk.dtype,
126+
device=student_logits_chunk.device
127+
)
128+
student_logits_chunk = torch.cat([student_logits_chunk, pad_tensor], dim=-1)
129+
115130
student_logits_chunk /= temperature
116131
teacher_logits_chunk /= temperature
117132

0 commit comments

Comments
 (0)