
Conversation


@shenxiangzhuang commented on Oct 24, 2025

Fix the attention pad mask: with .unsqueeze(1).unsqueeze(3) we get a wrong broadcast result. Currently the code works despite this bug only because we use -10000 rather than the usual -inf:

score = score.masked_fill(mask == 0, -10000)
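
A quick toy check of that claim (the tensors below are illustrative, not the repository's code): when every key in a row is masked out, -10000 leaves a finite, uniform softmax, while -inf produces nan.

import torch
import torch.nn.functional as F

# One attention row in which every key position is masked out.
score = torch.tensor([1.0, 2.0, 3.0, 4.0])
mask = torch.tensor([False, False, False, False])

# -10000 keeps the row finite: softmax degenerates to a uniform distribution.
print(F.softmax(score.masked_fill(mask == 0, -10000.0), dim=-1))
# tensor([0.2500, 0.2500, 0.2500, 0.2500])

# -inf makes the whole row -inf and softmax returns nan.
print(F.softmax(score.masked_fill(mask == 0, float("-inf")), dim=-1))
# tensor([nan, nan, nan, nan])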

If we use -inf as usual, we get errors. Here is the analysis:

pad_mask_bad = (input_ids != padding_idx).unsqueeze(1).unsqueeze(3)

which yields shape (batch, 1, seq_len, 1). For a toy batch with padding at the end:

input_ids = torch.tensor([[5, 7, 0, 0]])
pad_mask_bad[0, 0]
# tensor([[ True],
#         [ True],
#         [False],
#         [False]])

During attention this mask must broadcast to (batch, heads, seq_q, seq_k). The third query row (index 2) carries only a single False, so broadcasting replicates it across every key position:

row_2_after_broadcast = [False, False, False, False]

After applying the causal mask everything in that row stays False, so the logits become -inf, and softmax turns the row into nan. The failure only appeared when an entire suffix was padding, which explains why it slipped through basic smoke tests.

Step-by-step view.

  1. Build the naïve mask: pad_mask_bad.shape == (1, 1, 4, 1).
  2. Combine with the lower-triangular causal mask:
     [[ True, False, False, False],
      [ True,  True, False, False],
      [False, False, False, False],
      [False, False, False, False]]
  3. Apply to the logits → rows 2 and 3 contain only -inf, so softmax turns them into nan attention weights (reproduced in the sketch below).
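
A minimal sketch reproducing these three steps on the toy batch (the names and the causal-mask construction below are illustrative, not taken from the repository; only the pad-mask line mirrors the code under discussion):

import torch
import torch.nn.functional as F

padding_idx = 0
input_ids = torch.tensor([[5, 7, 0, 0]])   # (batch=1, seq_len=4)
seq_len = input_ids.size(1)

# Step 1: the naive pad mask, shape (1, 1, 4, 1).
pad_mask_bad = (input_ids != padding_idx).unsqueeze(1).unsqueeze(3)

# Step 2: combine with a lower-triangular causal mask of shape (1, 1, 4, 4).
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).view(1, 1, seq_len, seq_len)
combined = pad_mask_bad & causal           # rows 2 and 3 end up all False
print(combined[0, 0].int())

# Step 3: apply to logits with -inf; softmax turns rows 2 and 3 into nan.
score = torch.randn(1, 1, seq_len, seq_len)
score = score.masked_fill(combined == 0, float("-inf"))
print(F.softmax(score, dim=-1)[0, 0])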

The fix. Keep the key axis explicit:

pad_mask_good = (input_ids != padding_idx).unsqueeze(1).unsqueeze(2)

Now the mask starts at (batch, 1, 1, seq_len) and broadcasting preserves the column-wise padding information:

[[ True, False, False, False],
 [ True,  True, False, False],
 [ True,  True, False, False],
 [ True,  True, False, False]]

Rows 2 and 3 still attend to the earlier valid tokens, so the logits stay finite and the model trains normally.
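
Repeating the sketch with the corrected mask (same illustrative toy setup as above) reproduces the matrix shown and confirms that no nan appears:

import torch
import torch.nn.functional as F

padding_idx = 0
input_ids = torch.tensor([[5, 7, 0, 0]])
seq_len = input_ids.size(1)

# Fixed pad mask: the key axis stays explicit, shape (batch, 1, 1, seq_len).
pad_mask_good = (input_ids != padding_idx).unsqueeze(1).unsqueeze(2)

causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).view(1, 1, seq_len, seq_len)
combined = pad_mask_good & causal
print(combined[0, 0].int())                # matches the matrix above

score = torch.randn(1, 1, seq_len, seq_len)
score = score.masked_fill(combined == 0, float("-inf"))
print(torch.isnan(F.softmax(score, dim=-1)).any())   # tensor(False): every row stays finite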

@shenxiangzhuang (Author) commented:

Hi @hyunwoongko, could you review this when you have spare time? I want to make sure that I don't misunderstand the purpose of the original implementation.
