|
if special_token_id is not None: |
|
special_token_mask = (labels == special_token_id) & loss_mask |
|
per_token_logps[special_token_mask][chosen_count:] = -per_token_logps[ |
|
special_token_mask |
|
][chosen_count:] |
Вот в этом куске слайсинг, похоже, работает не так как ожидается - вместо special токенов из rejected сэмплов он выбирает special токены из всего батча начиная с chosen_count.
Когда тензор per_token_logps индексируется булевой маской special_token_mask, то в результате получается 1д тензор, и размерность батча теряется.
per_token_logps shape: torch.Size([4, 972]) # (B, seq)
special_token_mask shape: torch.Size([4, 972]) # (B, seq)
per_token_logps[special_token_mask] shape: torch.Size([8]) # 1д тензор всех special токенов из всего батча
per_token_logps[special_token_mask][chosen_count:] shape: torch.Size([6]) # берём из 1д тензора все токены начиная с chosen_count
Чтобы взять special токены из rejected сэмплов, можно написать, например, так:
rejected_logps = per_token_logps[chosen_count:] # (B, seq)
loss_mask = labels != label_pad_token_id
rejected_mask = loss_mask[chosen_count:]
special_mask_rej = (labels[chosen_count:] == special_token_id) & rejected_mask
rejected_logps[special_mask_rej] = -rejected_logps[special_mask_rej]
Эта же проблема, соответственно, и в других местах, где проводится похожий слайсинг:
|
# Winsorize extremal values for rejected tokens |
|
# Winsorize extremal values for chosen tokens |
|
# Clip minimum logprob for rejected tokens |
effective_llm_alignment/src/trainers/smpo_trainer.py
Lines 845 to 849 in a03cad4
Вот в этом куске слайсинг, похоже, работает не так как ожидается - вместо special токенов из rejected сэмплов он выбирает special токены из всего батча начиная с chosen_count.
Когда тензор per_token_logps индексируется булевой маской special_token_mask, то в результате получается 1д тензор, и размерность батча теряется.
Чтобы взять special токены из rejected сэмплов, можно написать, например, так:
Эта же проблема, соответственно, и в других местах, где проводится похожий слайсинг:
effective_llm_alignment/src/trainers/smpo_trainer.py
Line 851 in a03cad4
effective_llm_alignment/src/trainers/smpo_trainer.py
Line 866 in a03cad4
effective_llm_alignment/src/trainers/smpo_trainer.py
Line 881 in a03cad4