File tree Expand file tree Collapse file tree
src/liger_kernel/chunked_loss Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -112,6 +112,21 @@ def _compute_loss(
112112 compute_ce_loss = compute_ce_loss ,
113113 )
114114
115+ # If the teacher and student token size is different, pad student logits to match the teacher's.
116+ # This only applies to cases where they share exactly the same vocab and tokenizer just
117+ # that teacher logit is padded for some training efficiency such as
118+ # https://huggingface.co/Qwen/Qwen1.5-72B-Chat/discussions/1#662883f568adf59b07b176d2
119+ teacher_vocab_size = teacher_weight .shape [0 ]
120+ student_vocab_size = student_weight .shape [0 ]
121+ if teacher_vocab_size > student_vocab_size :
122+ pad_size = teacher_vocab_size - student_vocab_size
123+ pad_tensor = torch .zeros (
124+ (* student_logits_chunk .shape [:- 1 ], pad_size ),
125+ dtype = student_logits_chunk .dtype ,
126+ device = student_logits_chunk .device
127+ )
128+ student_logits_chunk = torch .cat ([student_logits_chunk , pad_tensor ], dim = - 1 )
129+
115130 student_logits_chunk /= temperature
116131 teacher_logits_chunk /= temperature
117132
You can’t perform that action at this time.
0 commit comments