
Normalization issue with the packing loss #12

@Chandler-Bing

Description


Shouldn't the loss computation here be normalized?

`loss = (loss * shift_weights).sum()` -> `loss = (loss * shift_weights).sum() / shift_weights.sum()`

i.e., normalize the loss to per-token granularity.
With the current form, the loss scale is inflated and the back-propagated gradients are correspondingly larger. In the extreme case, if every sample contributed only one token, the loss for the batch would blow up.
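For concreteness, a minimal PyTorch sketch of the two reductions side by side; `shift_logits`, `shift_labels`, and `shift_weights` here are illustrative stand-ins following the usual causal-LM shifting convention, not the repo's actual variables:

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: a packed batch flattened to (num_tokens,).
vocab_size = 32000
num_tokens = 10
shift_logits = torch.randn(num_tokens, vocab_size)
shift_labels = torch.randint(0, vocab_size, (num_tokens,))
shift_weights = torch.ones(num_tokens)  # e.g. 0 for padding/prompt tokens

# Per-token cross-entropy, no reduction yet.
loss = F.cross_entropy(shift_logits, shift_labels, reduction="none")

# Current behavior: plain weighted sum. The scale grows with the
# number of weighted tokens packed into the batch.
loss_sum = (loss * shift_weights).sum()

# Proposed fix: divide by the total weight so the result is a
# per-token average, independent of how many tokens were packed.
loss_mean = (loss * shift_weights).sum() / shift_weights.sum()
```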
