Skip to content

[SDPA] iree_linalg_ext.attention produces NaN for fully masked inputs, instead of zeros #24175

Description

@rkayaith

`iree_linalg_ext.attention` outputs NaNs instead of zeros when the input is fully masked out (every key position masked) with bf16 inputs.

sdpa_mask_nan.mlir:

// Repro: torch.aten.scaled_dot_product_attention on bf16 inputs where every
// key position is masked out.
// Shapes: q, k, v are [1,1,1,2] (a single query, a single key, head_dim = 2);
// the mask is [1,1,1,1]. The mask is a floating-point (bf16) tensor, which
// PyTorch SDPA adds to the attention scores, so a mask value of -Inf removes
// the only key position.
// NOTE: @main takes its arguments in the order (attn_mask, k, q, v), not
// (q, k, v, mask); inputs passed to iree-run-module must follow this order.
module @module {
  func.func @main(%attn_mask_4: !torch.vtensor<[1,1,1,1],bf16>, %k_1: !torch.vtensor<[1,1,1,2],bf16>, %q_0: !torch.vtensor<[1,1,1,2],bf16>, %v_2: !torch.vtensor<[1,1,1,2],bf16>) -> !torch.vtensor<[1,1,1,2],bf16> attributes {torch.assume_strict_symbolic_shapes} {
    // Q: permute with the identity order [0, 1, 2, 3], i.e. no dimension is
    // actually reordered (presumably emitted by the graph generator as a
    // generic layout step; kept as-is so the repro is unchanged).
    %permute_Q_val_0_sdpa_fprop = torch.constant.int 0
    %permute_Q_val_1_sdpa_fprop = torch.constant.int 1
    %permute_Q_val_2_sdpa_fprop = torch.constant.int 2
    %permute_Q_val_3_sdpa_fprop = torch.constant.int 3
    %permute_Q_sdpa_fprop = torch.prim.ListConstruct %permute_Q_val_0_sdpa_fprop, %permute_Q_val_1_sdpa_fprop, %permute_Q_val_2_sdpa_fprop, %permute_Q_val_3_sdpa_fprop : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %q_0_sdpa_fprop_perm = torch.aten.permute %q_0, %permute_Q_sdpa_fprop : !torch.vtensor<[1,1,1,2],bf16>, !torch.list<int> -> !torch.vtensor<[1,1,1,2],bf16>

    // K: same identity permutation.
    %permute_K_val_0_sdpa_fprop = torch.constant.int 0
    %permute_K_val_1_sdpa_fprop = torch.constant.int 1
    %permute_K_val_2_sdpa_fprop = torch.constant.int 2
    %permute_K_val_3_sdpa_fprop = torch.constant.int 3
    %permute_K_sdpa_fprop = torch.prim.ListConstruct %permute_K_val_0_sdpa_fprop, %permute_K_val_1_sdpa_fprop, %permute_K_val_2_sdpa_fprop, %permute_K_val_3_sdpa_fprop : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %k_1_sdpa_fprop_perm = torch.aten.permute %k_1, %permute_K_sdpa_fprop : !torch.vtensor<[1,1,1,2],bf16>, !torch.list<int> -> !torch.vtensor<[1,1,1,2],bf16>

    // V: same identity permutation.
    %permute_V_val_0_sdpa_fprop = torch.constant.int 0
    %permute_V_val_1_sdpa_fprop = torch.constant.int 1
    %permute_V_val_2_sdpa_fprop = torch.constant.int 2
    %permute_V_val_3_sdpa_fprop = torch.constant.int 3
    %permute_V_sdpa_fprop = torch.prim.ListConstruct %permute_V_val_0_sdpa_fprop, %permute_V_val_1_sdpa_fprop, %permute_V_val_2_sdpa_fprop, %permute_V_val_3_sdpa_fprop : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %v_2_sdpa_fprop_perm = torch.aten.permute %v_2, %permute_V_sdpa_fprop : !torch.vtensor<[1,1,1,2],bf16>, !torch.list<int> -> !torch.vtensor<[1,1,1,2],bf16>

    // Mask: same identity permutation.
    %permute_mask_val_0_sdpa_fprop = torch.constant.int 0
    %permute_mask_val_1_sdpa_fprop = torch.constant.int 1
    %permute_mask_val_2_sdpa_fprop = torch.constant.int 2
    %permute_mask_val_3_sdpa_fprop = torch.constant.int 3
    %permute_mask_sdpa_fprop = torch.prim.ListConstruct %permute_mask_val_0_sdpa_fprop, %permute_mask_val_1_sdpa_fprop, %permute_mask_val_2_sdpa_fprop, %permute_mask_val_3_sdpa_fprop : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %attn_mask_4_sdpa_fprop_perm = torch.aten.permute %attn_mask_4, %permute_mask_sdpa_fprop : !torch.vtensor<[1,1,1,1],bf16>, !torch.list<int> -> !torch.vtensor<[1,1,1,1],bf16>

    // SDPA parameters: dropout 0.0, non-causal, no GQA. The explicit scale
    // 7.071068e-01 is ~1/sqrt(2), i.e. 1/sqrt(head_dim) for head_dim = 2
    // (the last dimension of q and k).
    %dropout_sdpa_fprop = torch.constant.float 0.000000e+00
    %is_causal_sdpa_fprop = torch.constant.bool false
    %scale_sdpa_fprop = torch.constant.float 7.071068e-01
    %enable_gqa_sdpa_fprop = torch.constant.bool false
    %o_3_sdpa_fprop_perm = torch.aten.scaled_dot_product_attention %q_0_sdpa_fprop_perm, %k_1_sdpa_fprop_perm, %v_2_sdpa_fprop_perm, %attn_mask_4_sdpa_fprop_perm, %dropout_sdpa_fprop, %is_causal_sdpa_fprop, %scale_sdpa_fprop, %enable_gqa_sdpa_fprop : !torch.vtensor<[1,1,1,2],bf16>, !torch.vtensor<[1,1,1,2],bf16>, !torch.vtensor<[1,1,1,2],bf16>, !torch.vtensor<[1,1,1,1],bf16>, !torch.float, !torch.bool, !torch.float, !torch.bool -> !torch.vtensor<[1,1,1,2],bf16>

    // Output: identity permutation back to the returned layout.
    %permute_O_val_0_sdpa_fprop = torch.constant.int 0
    %permute_O_val_1_sdpa_fprop = torch.constant.int 1
    %permute_O_val_2_sdpa_fprop = torch.constant.int 2
    %permute_O_val_3_sdpa_fprop = torch.constant.int 3
    %permute_O_sdpa_fprop = torch.prim.ListConstruct %permute_O_val_0_sdpa_fprop, %permute_O_val_1_sdpa_fprop, %permute_O_val_2_sdpa_fprop, %permute_O_val_3_sdpa_fprop : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %o_3 = torch.aten.permute %o_3_sdpa_fprop_perm, %permute_O_sdpa_fprop : !torch.vtensor<[1,1,1,2],bf16>, !torch.list<int> -> !torch.vtensor<[1,1,1,2],bf16>

    return %o_3 : !torch.vtensor<[1,1,1,2],bf16>
  }
}
# Compile the repro for the ROCm backend, targeting mi300x, at -O3.
iree-compile sdpa_mask_nan.mlir \
  --iree-hal-target-backends=rocm --iree-rocm-target=mi300x \
  --iree-opt-level=O3 \
  -o sdpa_mask_nan.vmfb

# Run @main on the HIP device. The --input flags are positional and must follow
# the @main signature order: (attn_mask, k, q, v). The single mask value -Inf
# masks out the only key position, so the whole attention row is masked.
# (Each `# ...` in backticks is an inline shell comment and expands to nothing.)
iree-run-module --module=sdpa_mask_nan.vmfb --device=hip --function=main \
  --input="1x1x1x1xbf16=-Inf"    `# attn_mask (fully masked)` \
  --input="1x1x1x2xbf16=1,2"     `# k` \
  --input="1x1x1x2xbf16=3,4"     `# q` \
  --input="1x1x1x2xbf16=5,6"     `# v`

Output:

result[0]: hal.buffer_view
1x1x1x2xbf16=[[[NAN NAN]]]

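Expected: 1x1x1x2xbf16=[[[0 0]]]

This matches PyTorch, whose test (linked below) expects zeros when every position in an attention row is masked. The NaNs likely come from the softmax. When every score is -Inf, the row maximum is also -Inf, so computing exp(score - max) evaluates exp(-Inf - (-Inf)) = exp(NaN) = NaN.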

The PyTorch program this graph was generated from, minimized from this PyTorch test: https://github.com/pytorch/pytorch/blob/8cb62a8b4caf7a473870064c05d5de5ae511a2c5/test/test_transformers.py#L3148

import torch

# Minimal PyTorch repro: SDPA in bf16 where every key position is masked out.
# Use the GPU when one is available (ROCm builds of PyTorch also expose the GPU
# under the "cuda" device type); otherwise fall back to CPU so the script still
# runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

q = torch.randn(1, 1, 1, 2, dtype=torch.bfloat16, device=device)
k = torch.randn(1, 1, 1, 2, dtype=torch.bfloat16, device=device)
v = torch.randn(1, 1, 1, 2, dtype=torch.bfloat16, device=device)
# Boolean mask: True = take part in attention, False = masked out. All-False
# therefore masks every key position. The (1, 1) shape broadcasts to the
# (L, S) = (1, 1) score matrix (one query, one key).
mask = torch.zeros(1, 1, dtype=torch.bool, device=device)

out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
# Reported output (the NaN this issue is about): [[[[nan, nan]]]].
# Expected: all zeros, as asserted by the linked PyTorch test.

Environment

  • iree-base-compiler==3.12.0rc20260417

Metadata

Labels

No labels

Type

No type
No fields configured for issues without a type.

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions