Skip to content

Commit eb263b5

Browse files
authored
[FPSan] make tests much faster (#10016)
The slowest tests in the suite previously took 200 seconds each and now take 3.3 seconds each. The slowness was caused by the reference implementation using Python loops over scalar numpy code instead of vectorised numpy code. We also prune the number of test cases for tests with large Cartesian product parametrisations. Also, to minimise the amount of LLVM IR code generated by FPSan, we emit `scf::for` loops for `sin`, `cos`, and `exp2` instead of unrolled straight-line code: this makes sense because the principal bottleneck in FPSan use is compile time rather than runtime.
1 parent df82d98 commit eb263b5

3 files changed

Lines changed: 186 additions & 82 deletions

File tree

lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp

Lines changed: 84 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,16 @@ Value castSignedIntValueToType(PatternRewriter &rewriter, Location loc, Value v,
443443
return v;
444444
}
445445

446+
Value castScalarIntToIntLike(PatternRewriter &rewriter, Location loc,
447+
Value scalar, Type targetTy) {
448+
auto elemTy = cast<IntegerType>(getElementType(targetTy));
449+
if (scalar.getType() != elemTy)
450+
scalar = castSignedIntValueToType(rewriter, loc, scalar, elemTy);
451+
if (isa<ShapedType>(targetTy))
452+
return tt::SplatOp::create(rewriter, loc, targetTy, scalar);
453+
return scalar;
454+
}
455+
446456
Value selectUIntConstantOnSign(PatternRewriter &rewriter, Location loc,
447457
Value signSource, uint64_t signMaskValue,
448458
uint64_t nonNegativeValue,
@@ -674,45 +684,60 @@ Value fpsanSRem(PatternRewriter &rewriter, Location loc, Value num, Value den) {
674684

675685
// Modular exponentiation in payload space; this preserves
676686
// exp2(a + b) = exp2(a) * exp2(b) under the integer rewrite.
677-
Value fpsanExp2FromI32(PatternRewriter &rewriter, Location loc, Value xI,
687+
Value fpsanExp2FromInt(PatternRewriter &rewriter, Location loc, Value xI,
678688
Type floatTy) {
689+
unsigned bitWidth = getIntBitwidth(xI.getType());
679690
auto one = getIntConstantLike(rewriter, loc, xI.getType(), 1);
680691
auto zero = getIntConstantLike(rewriter, loc, xI.getType(), 0);
681692
auto c = getIntConstantLike(rewriter, loc, xI.getType(), 0xa343836d);
682693

683-
Value y = one;
684-
for (int i = 0; i < 32; ++i) {
685-
y = arith::MulIOp::create(rewriter, loc, y, y);
686-
auto bit = getIntConstantLike(rewriter, loc, xI.getType(),
687-
int64_t(1ull << (31 - i)));
688-
auto masked = arith::AndIOp::create(rewriter, loc, xI, bit);
689-
auto isZero = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
690-
masked, zero);
691-
auto factor = arith::SelectOp::create(rewriter, loc, isZero, one, c);
692-
y = arith::MulIOp::create(rewriter, loc, y, factor);
693-
}
694-
695-
return unembedToFloat(rewriter, loc, y, floatTy);
694+
auto lower =
695+
arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(0));
696+
auto upper = arith::ConstantOp::create(rewriter, loc,
697+
rewriter.getI32IntegerAttr(bitWidth));
698+
auto step =
699+
arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(1));
700+
auto topBit = arith::ConstantOp::create(
701+
rewriter, loc, rewriter.getI32IntegerAttr(bitWidth - 1));
702+
auto loop = scf::ForOp::create(rewriter, loc, lower, upper, step, one);
703+
rewriter.setInsertionPointToStart(loop.getBody());
704+
705+
Value i = loop.getInductionVar();
706+
Value y = loop.getRegionIterArgs()[0];
707+
y = arith::MulIOp::create(rewriter, loc, y, y);
708+
Value bitIndex =
709+
arith::SubIOp::create(rewriter, loc, rewriter.getI32Type(), topBit, i);
710+
Value shift = castScalarIntToIntLike(rewriter, loc, bitIndex, xI.getType());
711+
Value bit = arith::ShLIOp::create(rewriter, loc, one, shift);
712+
auto masked = arith::AndIOp::create(rewriter, loc, xI, bit);
713+
auto isZero = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
714+
masked, zero);
715+
auto factor = arith::SelectOp::create(rewriter, loc, isZero, one, c);
716+
y = arith::MulIOp::create(rewriter, loc, y, factor);
717+
scf::YieldOp::create(rewriter, loc, y);
718+
rewriter.setInsertionPointAfter(loop);
719+
720+
return unembedToFloat(rewriter, loc, loop.getResult(0), floatTy);
696721
}
697722

698723
Value fpsanExp2(PatternRewriter &rewriter, Location loc, Value input) {
699724
auto elemTy = dyn_cast<FloatType>(getElementType(input.getType()));
700-
if (!elemTy || elemTy.getWidth() != 32)
725+
if (!elemTy)
701726
return Value();
702-
return fpsanExp2FromI32(rewriter, loc, embedToInt(rewriter, loc, input),
727+
return fpsanExp2FromInt(rewriter, loc, embedToInt(rewriter, loc, input),
703728
input.getType());
704729
}
705730

706731
Value fpsanExp(PatternRewriter &rewriter, Location loc, Value input) {
707732
auto elemTy = dyn_cast<FloatType>(getElementType(input.getType()));
708-
if (!elemTy || elemTy.getWidth() != 32)
733+
if (!elemTy)
709734
return Value();
710735

711736
auto inputI = embedToInt(rewriter, loc, input);
712737
auto rcpLog2 =
713738
getU32ConstantLike(rewriter, loc, inputI.getType(), 0x236ee9bfu);
714739
auto scaledI = arith::MulIOp::create(rewriter, loc, inputI, rcpLog2);
715-
return fpsanExp2FromI32(rewriter, loc, scaledI, input.getType());
740+
return fpsanExp2FromInt(rewriter, loc, scaledI, input.getType());
716741
}
717742

718743
struct FpSanCosSin {
@@ -735,32 +760,47 @@ FpSanCosSin fpsanCosSinPayload(PatternRewriter &rewriter, Location loc,
735760
auto a = getUIntConstantLike(rewriter, loc, intTy, aValue);
736761
auto b = getUIntConstantLike(rewriter, loc, intTy, bValue);
737762

738-
Value c = one;
739-
Value s = zero;
740-
for (int bit = static_cast<int>(bitWidth) - 1; bit >= 0; --bit) {
741-
Value cc = arith::MulIOp::create(rewriter, loc, c, c);
742-
Value ss = arith::MulIOp::create(rewriter, loc, s, s);
743-
Value cDouble = arith::SubIOp::create(rewriter, loc, cc, ss);
744-
Value cs = arith::MulIOp::create(rewriter, loc, c, s);
745-
Value sDouble = arith::MulIOp::create(rewriter, loc, two, cs);
746-
747-
Value ac = arith::MulIOp::create(rewriter, loc, a, cDouble);
748-
Value bs = arith::MulIOp::create(rewriter, loc, b, sDouble);
749-
Value cInc = arith::SubIOp::create(rewriter, loc, ac, bs);
750-
Value as = arith::MulIOp::create(rewriter, loc, a, sDouble);
751-
Value bc = arith::MulIOp::create(rewriter, loc, b, cDouble);
752-
Value sInc = arith::AddIOp::create(rewriter, loc, as, bc);
753-
754-
auto bitMask =
755-
getUIntConstantLike(rewriter, loc, intTy, uint64_t{1} << bit);
756-
auto masked = arith::AndIOp::create(rewriter, loc, xI, bitMask);
757-
auto isZero = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
758-
masked, zero);
759-
c = arith::SelectOp::create(rewriter, loc, isZero, cDouble, cInc);
760-
s = arith::SelectOp::create(rewriter, loc, isZero, sDouble, sInc);
761-
}
762-
763-
return {c, s};
763+
auto lower =
764+
arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(0));
765+
auto upper = arith::ConstantOp::create(rewriter, loc,
766+
rewriter.getI32IntegerAttr(bitWidth));
767+
auto step =
768+
arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(1));
769+
auto topBit = arith::ConstantOp::create(
770+
rewriter, loc, rewriter.getI32IntegerAttr(bitWidth - 1));
771+
SmallVector<Value> initArgs{one, zero};
772+
auto loop = scf::ForOp::create(rewriter, loc, lower, upper, step, initArgs);
773+
rewriter.setInsertionPointToStart(loop.getBody());
774+
775+
Value bit = loop.getInductionVar();
776+
Value c = loop.getRegionIterArgs()[0];
777+
Value s = loop.getRegionIterArgs()[1];
778+
Value cc = arith::MulIOp::create(rewriter, loc, c, c);
779+
Value ss = arith::MulIOp::create(rewriter, loc, s, s);
780+
Value cDouble = arith::SubIOp::create(rewriter, loc, cc, ss);
781+
Value cs = arith::MulIOp::create(rewriter, loc, c, s);
782+
Value sDouble = arith::MulIOp::create(rewriter, loc, two, cs);
783+
784+
Value ac = arith::MulIOp::create(rewriter, loc, a, cDouble);
785+
Value bs = arith::MulIOp::create(rewriter, loc, b, sDouble);
786+
Value cInc = arith::SubIOp::create(rewriter, loc, ac, bs);
787+
Value as = arith::MulIOp::create(rewriter, loc, a, sDouble);
788+
Value bc = arith::MulIOp::create(rewriter, loc, b, cDouble);
789+
Value sInc = arith::AddIOp::create(rewriter, loc, as, bc);
790+
791+
Value bitIndex =
792+
arith::SubIOp::create(rewriter, loc, rewriter.getI32Type(), topBit, bit);
793+
Value shift = castScalarIntToIntLike(rewriter, loc, bitIndex, intTy);
794+
Value bitMask = arith::ShLIOp::create(rewriter, loc, one, shift);
795+
auto masked = arith::AndIOp::create(rewriter, loc, xI, bitMask);
796+
auto isZero = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
797+
masked, zero);
798+
c = arith::SelectOp::create(rewriter, loc, isZero, cDouble, cInc);
799+
s = arith::SelectOp::create(rewriter, loc, isZero, sDouble, sInc);
800+
scf::YieldOp::create(rewriter, loc, ValueRange{c, s});
801+
rewriter.setInsertionPointAfter(loop);
802+
803+
return {loop.getResult(0), loop.getResult(1)};
764804
}
765805

766806
Value fpsanCos(PatternRewriter &rewriter, Location loc, Value input) {

python/test/conftest.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,23 @@
1-
import pytest
1+
from collections import defaultdict
2+
import hashlib
23
import tempfile
34

5+
import pytest
6+
7+
8+
def _top_level_test_key(item):
9+
nodeid = item.nodeid
10+
bracket = nodeid.find("[")
11+
return nodeid if bracket == -1 else nodeid[:bracket]
12+
13+
14+
def _case_key(item):
15+
return item.name
16+
17+
18+
def _sha256_hex(s: str) -> str:
19+
return hashlib.sha256(s.encode("utf-8")).hexdigest()
20+
421

522
def pytest_configure(config):
623
# If pytest-sugar is not active, enable instafail
@@ -10,6 +27,35 @@ def pytest_configure(config):
1027

1128
def pytest_addoption(parser):
1229
parser.addoption("--device", action="store", default="cuda")
30+
parser.addoption(
31+
"--max-cases-per-test",
32+
action="store",
33+
type=int,
34+
default=100,
35+
help="Maximum number of cases per top-level test",
36+
)
37+
38+
39+
def pytest_collection_modifyitems(config, items):
40+
max_cases = config.getoption("--max-cases-per-test")
41+
if max_cases <= 0:
42+
return
43+
44+
groups = defaultdict(list)
45+
for item in items:
46+
groups[_top_level_test_key(item)].append(item)
47+
48+
kept = []
49+
deselected = []
50+
for group in groups.values():
51+
ordered = sorted(group, key=lambda item: _sha256_hex(_case_key(item)))
52+
kept.extend(ordered[:max_cases])
53+
deselected.extend(ordered[max_cases:])
54+
55+
if deselected:
56+
config.hook.pytest_deselected(items=deselected)
57+
58+
items[:] = kept
1359

1460

1561
@pytest.fixture

python/test/gluon/test_fpsan.py

Lines changed: 55 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1333,31 +1333,47 @@ def _mm_scaled_payload_u32(a_u8: np.ndarray, b_u8: np.ndarray, a_scale_u8: np.nd
13331333
assert a_scale.shape == (m, k // 32)
13341334
assert b_scale.shape == (n, k // 32)
13351335

1336-
def unpack(data: np.ndarray, row: int, col: int, pack: int, pack_axis: int) -> np.uint16:
1336+
def unpack_payload_matrix(data: np.ndarray, pack: int, pack_axis: int) -> np.ndarray:
13371337
if pack == 1:
1338-
return np.uint16(data[row, col])
1339-
return np.uint16(_unpack_element(data, row, col, pack, pack_axis=pack_axis))
1338+
return data.astype(np.uint64)
1339+
assert pack == 2
1340+
if pack_axis == 1:
1341+
out = np.empty((data.shape[0], data.shape[1] * pack), dtype=np.uint64)
1342+
out[:, 0::2] = data.astype(np.uint64) & np.uint64(0x0F)
1343+
out[:, 1::2] = (data.astype(np.uint64) >> np.uint64(4)) & np.uint64(0x0F)
1344+
return out
1345+
out = np.empty((data.shape[0] * pack, data.shape[1]), dtype=np.uint64)
1346+
out[0::2, :] = data.astype(np.uint64) & np.uint64(0x0F)
1347+
out[1::2, :] = (data.astype(np.uint64) >> np.uint64(4)) & np.uint64(0x0F)
1348+
return out
1349+
1350+
def compute_payload_matrix(data: np.ndarray) -> np.ndarray:
1351+
if elem_type in ("e4m3", "e5m2"):
1352+
one_bits = 0x38 if elem_type == "e4m3" else 0x3C
1353+
payload = _mix_float_bits_to_payload_u64(data, 8, one_bits)
1354+
return _signed_cast_payload_u64(payload, 8, 16)
1355+
return data & np.uint64(0xFFFF)
1356+
1357+
def scale_payload_matrix(raw_scale: np.ndarray) -> np.ndarray:
1358+
raw_bf16 = (raw_scale & np.uint64(0xFF)) << np.uint64(7)
1359+
return _mix_float_bits_to_payload_u64(raw_bf16, 16, 0x3F80)
13401360

1341-
out = np.empty((m, n), dtype=np.uint64)
1342-
compute_type = "bf16"
1361+
a_payload = compute_payload_matrix(unpack_payload_matrix(a_u8, a_pack, pack_axis=1))
1362+
b_payload = compute_payload_matrix(unpack_payload_matrix(b_u8, b_pack, pack_axis=0))
1363+
a_scale_payload = scale_payload_matrix(a_scale)
1364+
b_scale_payload = scale_payload_matrix(b_scale)
1365+
1366+
out = c_u.copy() if c_u is not None else np.zeros((m, n), dtype=np.uint64)
13431367
compute_mask = np.uint64(0xFFFF)
13441368
mask32 = np.uint64(0xFFFFFFFF)
1345-
for i in range(m):
1346-
for j in range(n):
1347-
s = c_u[i, j] if c_u is not None else 0
1348-
for kk in range(k):
1349-
a_val = unpack(a_u8, i, kk, a_pack, pack_axis=1)
1350-
b_val = unpack(b_u8, kk, j, b_pack, pack_axis=0)
1351-
a_val = _dot_scaled_compute_payload_elem(np.uint64(a_val), elem_type, compute_type)
1352-
b_val = _dot_scaled_compute_payload_elem(np.uint64(b_val), elem_type, compute_type)
1353-
a_scale_val = _dot_scaled_scale_payload(a_scale[i, kk // 32], compute_type)
1354-
b_scale_val = _dot_scaled_scale_payload(b_scale[j, kk // 32], compute_type)
1355-
lhs = (a_val * a_scale_val) & compute_mask
1356-
rhs = (b_val * b_scale_val) & compute_mask
1357-
lhs = _signed_cast_payload_scalar(lhs, 16, 32)
1358-
rhs = _signed_cast_payload_scalar(rhs, 16, 32)
1359-
s = (s + ((np.uint64(lhs) * np.uint64(rhs)) & mask32)) & mask32
1360-
out[i, j] = s
1369+
for group in range(k // 32):
1370+
start = group * 32
1371+
end = start + 32
1372+
lhs = (a_payload[:, start:end] * a_scale_payload[:, group:group + 1]) & compute_mask
1373+
rhs = (b_payload[start:end, :] * b_scale_payload[:, group][None, :]) & compute_mask
1374+
lhs = _signed_cast_payload_u64(lhs, 16, 32)
1375+
rhs = _signed_cast_payload_u64(rhs, 16, 32)
1376+
out = (out + (lhs @ rhs)) & mask32
13611377
return _unmix_payload_u32_to_f32_bits_i32(out.astype(np.uint32))
13621378

13631379

@@ -1759,31 +1775,33 @@ def test_reduction(device, fresh_knobs):
17591775
_require_cuda_backend(device)
17601776

17611777
@triton.jit
1762-
def reduce_kernel(a_ptr, c_ptr, M: tl.constexpr, N: tl.constexpr, stride_am: tl.constexpr, stride_ak: tl.constexpr,
1763-
ORDER: tl.constexpr):
1764-
a_ptrs = a_ptr + (tl.arange(0, M)[:, None] * stride_am + (tl.arange(0, N)[None, :]) * stride_ak)
1778+
def reduce_kernel(a_ptr, c_ptr, M: tl.constexpr, N: tl.constexpr, stride_ak: tl.constexpr, stride_am: tl.constexpr,
1779+
stride_an: tl.constexpr, ORDER: tl.constexpr):
1780+
1781+
a_ptr += tl.program_id(0).to(tl.int64) * stride_ak
1782+
c_ptr += tl.program_id(0).to(tl.int64)
1783+
a_ptrs = a_ptr + (tl.arange(0, M)[:, None] * stride_am + (tl.arange(0, N)[None, :]) * stride_an)
17651784
a = tl.load(a_ptrs)
17661785
r1 = tl.sum(a, axis=ORDER)
1767-
r2 = tl.sum(r1, axis=ORDER - 1)
1786+
r2 = tl.sum(r1, axis=0)
17681787
tl.store(c_ptr, r2)
17691788

1770-
M, N = 512, 512
1789+
# we run K parallel tests so as to make non-associativity much more
1790+
# likely to manifest:
1791+
K, M, N = 100, 128, 128
17711792
torch.manual_seed(0)
1772-
a = torch.randn((M, N), dtype=torch.float32, device="cuda")
1773-
# Make non-associativity visible and deterministic: large + tiny magnitudes.
1774-
a[:, :64] *= 1e10
1775-
a[:, 64:] *= 1e-10
1776-
c1 = torch.empty((1, ), dtype=torch.float32).to('cuda')
1777-
c2 = torch.empty((1, ), dtype=torch.float32).to('cuda')
1778-
1779-
reduce_kernel[(1, )](a, c1, M=M, N=N, stride_am=a.stride(0), stride_ak=a.stride(1), ORDER=0)
1780-
reduce_kernel[(1, )](a, c2, M=M, N=N, stride_am=a.stride(0), stride_ak=a.stride(1), ORDER=1)
1793+
a = torch.randn((K, M, N), dtype=torch.float32, device="cuda")
1794+
c1 = torch.empty((K, ), dtype=torch.float32).to('cuda')
1795+
c2 = torch.empty((K, ), dtype=torch.float32).to('cuda')
1796+
1797+
reduce_kernel[(K, )](a, c1, M, N, a.stride(0), a.stride(1), a.stride(2), ORDER=0)
1798+
reduce_kernel[(K, )](a, c2, M, N, a.stride(0), a.stride(1), a.stride(2), ORDER=1)
17811799
assert not _payload_equal(c1, c2)
17821800

17831801
fresh_knobs.compilation.instrumentation_mode = "fpsan"
17841802

1785-
reduce_kernel[(1, )](a, c1, M=M, N=N, stride_am=a.stride(0), stride_ak=a.stride(1), ORDER=0)
1786-
reduce_kernel[(1, )](a, c2, M=M, N=N, stride_am=a.stride(0), stride_ak=a.stride(1), ORDER=1)
1803+
reduce_kernel[(K, )](a, c1, M, N, a.stride(0), a.stride(1), a.stride(2), ORDER=0)
1804+
reduce_kernel[(K, )](a, c2, M, N, a.stride(0), a.stride(1), a.stride(2), ORDER=1)
17871805
assert _payload_equal(c1, c2)
17881806

17891807

0 commit comments

Comments
 (0)