Skip to content

Commit e22d426

Browse files
[FPSAN] Preserve NaN payload bits in encoding
1 parent 4b569dc commit e22d426

3 files changed

Lines changed: 48 additions & 7 deletions

File tree

lib/Conversion/TritonInstrumentToLLVM/FpSanToLLVM.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,10 +123,14 @@ Value mixFloatToInt(ConversionPatternRewriter &rewriter, Location loc, Value u,
123123
PayloadMixConfig cfg = getPayloadMixConfig(floatTy);
124124
Value signFlip =
125125
selectUIntConstantOnSign(rewriter, loc, u, cfg.signMask, 0, cfg.signMask);
126-
Value x = b.xor_(u, signFlip);
127126
Value mulA = createUIntConstant(rewriter, loc, u.getType(), cfg.mulA);
128127
Value magMask = createUIntConstant(rewriter, loc, u.getType(), cfg.magMask);
129-
Value yMul = b.mul(x, mulA);
128+
// Avoid patterns that InstCombine rewrites to `llvm.fabs`. LLVM specifies
129+
// that `llvm.fabs` preserves the NaN quiet/signaling bit and payload, but
130+
// NVPTX lowers it to PTX `abs.f32`, whose NaN result is unspecified. On
131+
// Blackwell, `abs.f32` is observed to canonicalize signaling NaNs, corrupting
132+
// FPSan payloads.
133+
Value yMul = b.mul(u, mulA);
130134
Value y = b.and_(yMul, magMask);
131135
Value z = xorShiftRight(rewriter, loc, y, cfg.shift);
132136
Value mulB = selectUIntConstantOnSign(rewriter, loc, u, cfg.signMask,
@@ -142,11 +146,10 @@ Value unmixIntToFloat(ConversionPatternRewriter &rewriter, Location loc,
142146
PayloadMixConfig cfg = getPayloadMixConfig(floatTy);
143147
Value signFlip =
144148
selectUIntConstantOnSign(rewriter, loc, v, cfg.signMask, 0, cfg.signMask);
145-
Value w = b.xor_(v, signFlip);
146149
Value magMask = createUIntConstant(rewriter, loc, v.getType(), cfg.magMask);
147150
Value mulBInv = selectUIntConstantOnSign(rewriter, loc, v, cfg.signMask,
148151
cfg.mulBPosInv, cfg.mulBNegInv);
149-
Value zMul = b.mul(w, mulBInv);
152+
Value zMul = b.mul(v, mulBInv);
150153
Value z = b.and_(zMul, magMask);
151154
Value y = inverseXorShiftRight(rewriter, loc, z, cfg);
152155
Value mulAInv = createUIntConstant(rewriter, loc, v.getType(), cfg.mulAInv);

python/test/gluon/test_fpsan.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2703,6 +2703,43 @@ def loop_sum_kernel(x_ptr, out_ptr, N: tl.constexpr):
27032703
_assert_payload_equal(reduce_out, loop_out)
27042704

27052705

2706+
def test_f32_loop_preserves_snan_payload(device, fresh_knobs):
2707+
_require_cuda_backend(device)
2708+
if not is_cuda():
2709+
pytest.skip("regression is specific to NVPTX fabs lowering")
2710+
2711+
@triton.jit
2712+
def sum_kernel(x_ptr, out_ptr, BLOCK: tl.constexpr):
2713+
offsets = tl.arange(0, BLOCK)
2714+
acc = tl.zeros((BLOCK, ), tl.float32)
2715+
for i in range(3):
2716+
acc += tl.load(x_ptr + i * BLOCK + offsets)
2717+
tl.store(out_ptr + offsets, acc)
2718+
2719+
fresh_knobs.compilation.instrumentation_mode = "fpsan"
2720+
fresh_knobs.compilation.always_compile = True
2721+
2722+
block = 128
2723+
# The first two finite rows sum to an sNaN; the all-zero third row then forces that sNaN through the embed of the next loop iteration.
2724+
input_bits = np.zeros((3, block), dtype=np.int32)
2725+
input_bits[0].fill(0x1B0F577C)
2726+
input_bits[1].fill(0x65E031B7)
2727+
assert np.isfinite(input_bits.view(np.float32)).all()
2728+
x = torch.tensor(input_bits, dtype=torch.int32, device="cuda")
2729+
out = torch.empty((block, ), dtype=torch.int32, device="cuda")
2730+
sum_kernel[(1, )](
2731+
triton.TensorWrapper(x, dtype=torch.float32),
2732+
triton.TensorWrapper(out, dtype=torch.float32),
2733+
BLOCK=block,
2734+
num_warps=1,
2735+
)
2736+
2737+
expected = _expected_add_i32(input_bits[0], input_bits[1])
2738+
expected = _expected_add_i32(expected, input_bits[2])
2739+
assert np.all(_as_u32(expected) == np.uint32(0x7FA12345))
2740+
_assert_payload_equal(out, expected)
2741+
2742+
27062743
@pytest.mark.skipif(not (is_hip_cdna3() or is_hip_cdna4()), reason="Requires CDNA3 or CDNA4")
27072744
@pytest.mark.parametrize(("type_a", "type_b", "acc_type", "m", "n", "k", "instr_m", "instr_n", "instr_k", "k_width"),
27082745
_MFMA_DOT_CASES)

test/Conversion/tritoninstrument_to_llvm.mlir

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,9 @@ tt.func private @experimental_gsan_tensordesc_info(
115115
module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:90"} {
116116
// CHECK-LABEL: @experimental_fpsan_embed
117117
// CHECK-NOT: tti.experimental_fpsan_embed
118-
// CHECK: llvm.bitcast %arg0 : f32 to i32
119-
// CHECK: llvm.mul
118+
// CHECK: %[[RAW:.*]] = llvm.bitcast %arg0 : f32 to i32
119+
// CHECK-NOT: llvm.inline_asm
120+
// CHECK: llvm.mul %[[RAW]],
120121
// CHECK: llvm.xor
121122
tt.func private @experimental_fpsan_embed(%arg0: f32) -> i32 {
122123
%0 = tti.experimental_fpsan_embed %arg0 : (f32) -> i32
@@ -129,7 +130,7 @@ tt.func private @experimental_fpsan_embed(%arg0: f32) -> i32 {
129130
module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:90"} {
130131
// CHECK-LABEL: @experimental_fpsan_unembed
131132
// CHECK-NOT: tti.experimental_fpsan_unembed
132-
// CHECK: llvm.mul
133+
// CHECK: llvm.mul %arg0,
133134
// CHECK: llvm.xor
134135
// CHECK: llvm.bitcast %{{.*}} : i32 to f32
135136
tt.func private @experimental_fpsan_unembed(%arg0: i32) -> f32 {

0 commit comments

Comments
 (0)