`torch.aten.scaled_dot_product_attention` outputs NaNs instead of zeros for fully masked-out inputs in bf16.
// Minimized reproducer: bf16 torch.aten.scaled_dot_product_attention with an
// explicit bf16 attention-mask operand.
//
// Every tensor operand, and the result, goes through torch.aten.permute with
// dims [0, 1, 2, 3]. That permutation is the identity, so @main returns
// sdpa(q, k, v, attn_mask) with dropout_p = 0.0, is_causal = false,
// scale = 7.071068e-01 (~1/sqrt(2), matching the size-2 last dimension of
// q and k) and enable_gqa = false.
module @module {
  func.func @main(%attn_mask_4: !torch.vtensor<[1,1,1,1],bf16>, %k_1: !torch.vtensor<[1,1,1,2],bf16>, %q_0: !torch.vtensor<[1,1,1,2],bf16>, %v_2: !torch.vtensor<[1,1,1,2],bf16>) -> !torch.vtensor<[1,1,1,2],bf16> attributes {torch.assume_strict_symbolic_shapes} {
    // A single dims list, [0, 1, 2, 3], shared by all five permutes below.
    %dim0 = torch.constant.int 0
    %dim1 = torch.constant.int 1
    %dim2 = torch.constant.int 2
    %dim3 = torch.constant.int 3
    %identity_dims = torch.prim.ListConstruct %dim0, %dim1, %dim2, %dim3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>

    // Tensor inputs, each permuted by the identity.
    %q = torch.aten.permute %q_0, %identity_dims : !torch.vtensor<[1,1,1,2],bf16>, !torch.list<int> -> !torch.vtensor<[1,1,1,2],bf16>
    %k = torch.aten.permute %k_1, %identity_dims : !torch.vtensor<[1,1,1,2],bf16>, !torch.list<int> -> !torch.vtensor<[1,1,1,2],bf16>
    %v = torch.aten.permute %v_2, %identity_dims : !torch.vtensor<[1,1,1,2],bf16>, !torch.list<int> -> !torch.vtensor<[1,1,1,2],bf16>
    %mask = torch.aten.permute %attn_mask_4, %identity_dims : !torch.vtensor<[1,1,1,1],bf16>, !torch.list<int> -> !torch.vtensor<[1,1,1,1],bf16>

    // Scalar SDPA operands, declared in the op's operand order.
    %dropout_p = torch.constant.float 0.000000e+00
    %is_causal = torch.constant.bool false
    %scale = torch.constant.float 7.071068e-01
    %enable_gqa = torch.constant.bool false
    %attn = torch.aten.scaled_dot_product_attention %q, %k, %v, %mask, %dropout_p, %is_causal, %scale, %enable_gqa : !torch.vtensor<[1,1,1,2],bf16>, !torch.vtensor<[1,1,1,2],bf16>, !torch.vtensor<[1,1,1,2],bf16>, !torch.vtensor<[1,1,1,1],bf16>, !torch.float, !torch.bool, !torch.float, !torch.bool -> !torch.vtensor<[1,1,1,2],bf16>

    // Result, permuted by the identity before returning.
    %out = torch.aten.permute %attn, %identity_dims : !torch.vtensor<[1,1,1,2],bf16>, !torch.list<int> -> !torch.vtensor<[1,1,1,2],bf16>
    return %out : !torch.vtensor<[1,1,1,2],bf16>
  }
}
result[0]: hal.buffer_view
1x1x1x2xbf16=[[[NAN NAN]]]
import torch
# PyTorch-side repro: one query attends to one key, and that key is masked
# out, so the query's attention row is fully masked.
# The device was previously hard-coded to "cuda"; fall back to CPU so the
# snippet also runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16

# Shapes follow SDPA's (N, ..., L, E) convention: L = S = 1 and head dim E = 2.
q = torch.randn(1, 1, 1, 2, dtype=dtype, device=device)
k = torch.randn(1, 1, 1, 2, dtype=dtype, device=device)
v = torch.randn(1, 1, 1, 2, dtype=dtype, device=device)
# Boolean SDPA mask: True = position takes part in attention, False = masked.
# An all-False (1, 1) mask broadcasts to the (L, S) scores and masks everything.
mask = torch.zeros(1, 1, dtype=torch.bool, device=device)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Expected: zeros of shape (1, 1, 1, 2), matching the IREE "Expected" output.
# Print the real result rather than hard-coding it in a comment; whether NaNs
# appear may depend on the PyTorch version and the SDPA backend selected.
print(out)
if torch.isnan(out).any():
    print("NaN in output for a fully masked row; expected zeros")
Attention outputs NaNs instead of zeros for fully masked-out inputs in bf16.
Actual output of `sdpa_mask_nan.mlir` (the module above): `1x1x1x2xbf16=[[[NAN NAN]]]`
Expected:
1x1x1x2xbf16=[[[0 0]]]
The module was generated from a PyTorch graph minimized from this test: https://github.com/pytorch/pytorch/blob/8cb62a8b4caf7a473870064c05d5de5ae511a2c5/test/test_transformers.py#L3148
Environment
iree-base-compiler==3.12.0rc20260417