Skip to content

Commit 746064c

Browse files
Authored by pawelszczerbuk, with co-authors root and Codex
[FPSAN] Fix crash on incorrect layout for tmem copy (#10046)
The TMEMCopy pattern was using a stale tmem encoding. This could cause a crash in the validator when the encodings mismatched.

---------

Co-authored-by: root <root@codex-gb200-0.brix.pawelszczerbuk.svc.cluster.local>
Co-authored-by: Codex <noreply@openai.com>
1 parent 6ea516a commit 746064c

File tree

2 files changed

+72
-3
lines changed

2 files changed

+72
-3
lines changed

lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1856,9 +1856,11 @@ struct TMEMCopyPattern : public OpRewritePattern<ttng::TMEMCopyOp> {
18561856

18571857
auto loc = op.getLoc();
18581858
auto srcMemTy = cast<ttg::MemDescType>(op.getSrc().getType());
1859-
auto srcRegTy =
1860-
RankedTensorType::get(srcMemTy.getShape(), srcMemTy.getElementType(),
1861-
info->tensorType.getEncoding());
1859+
auto dstMemTy = cast<ttg::MemDescType>(op.getDst().getType());
1860+
auto srcEncoding =
1861+
scratch->getScratchEncoding(rewriter, op.getDst(), dstMemTy);
1862+
auto srcRegTy = RankedTensorType::get(
1863+
srcMemTy.getShape(), srcMemTy.getElementType(), srcEncoding);
18621864
Value srcReg =
18631865
ttg::LocalLoadOp::create(rewriter, loc, srcRegTy, op.getSrc(), Value())
18641866
.getResult();

python/test/gluon/test_fpsan.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
TensorMemoryScalesLayout,
1515
allocate_tensor_memory,
1616
mbarrier,
17+
tcgen05_commit,
18+
tcgen05_copy,
1719
tcgen05_mma,
1820
tcgen05_mma_scaled,
1921
)
@@ -1723,6 +1725,71 @@ def kernel(x_ptr, out_ptr):
17231725
_assert_payload_equal(out, exp_bits)
17241726

17251727

1728+
@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
1729+
def test_tmem_copy_scales_in_warp_specialize_partition(device, fresh_knobs):
1730+
_require_cuda_backend(device)
1731+
1732+
smem_h = 64
1733+
smem_w = 16
1734+
SMEM_H = gl.constexpr(smem_h)
1735+
SMEM_W = gl.constexpr(smem_w)
1736+
1737+
fresh_knobs.compilation.instrumentation_mode = "fpsan"
1738+
1739+
@gluon.jit
1740+
def copy_partition(smem, tmem, bar):
1741+
tcgen05_copy(smem, tmem)
1742+
tcgen05_commit(bar)
1743+
1744+
@gluon.jit
1745+
def default_partition():
1746+
pass
1747+
1748+
@gluon.jit
1749+
def kernel(in_ptr, out_ptr):
1750+
blocked: gl.constexpr = gl.BlockedLayout([1, 4], [32, 1], [gl.num_warps(), 1], [1, 0])
1751+
in_ptrs = (in_ptr + gl.arange(0, SMEM_H)[:, None] * SMEM_W + gl.arange(0, SMEM_W)[None, :])
1752+
value = gl.load(gl.set_auto_layout(in_ptrs, blocked))
1753+
1754+
smem_layout: gl.constexpr = gl.SharedLinearLayout(offset_bases=[
1755+
[0, 1],
1756+
[0, 2],
1757+
[32, 0],
1758+
[0, 4],
1759+
[1, 0],
1760+
[2, 0],
1761+
[4, 0],
1762+
[8, 0],
1763+
[16, 0],
1764+
[0, 8],
1765+
])
1766+
smem = gl.allocate_shared_memory(gl.int8, (SMEM_H, SMEM_W), layout=smem_layout)
1767+
smem.store(value)
1768+
1769+
tmem_layout: gl.constexpr = TensorMemoryScalesLayout()
1770+
tmem = allocate_tensor_memory(gl.int8, (SMEM_H, SMEM_W), layout=tmem_layout)
1771+
bar = gl.allocate_shared_memory(gl.int64, [1], gl.constexpr(mbarrier.MBarrierLayout()))
1772+
mbarrier.init(bar, count=1)
1773+
1774+
gl.warp_specialize(
1775+
[
1776+
(default_partition, ()),
1777+
(copy_partition, (smem, tmem, bar)),
1778+
],
1779+
[1],
1780+
[32],
1781+
)
1782+
1783+
mbarrier.wait(bar, phase=0)
1784+
mbarrier.invalidate(bar)
1785+
gl.store(out_ptr, 1)
1786+
1787+
x = torch.randint(size=(smem_h, smem_w), low=-100, high=100, dtype=torch.int8, device=device)
1788+
out = torch.empty((), device=device, dtype=torch.int32)
1789+
kernel[(1, )](x, out, num_warps=4)
1790+
torch.testing.assert_close(out, torch.ones_like(out))
1791+
1792+
17261793
@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
17271794
def test_tmem_store_in_warp_specialize_partition_visible_to_parent(device, fresh_knobs):
17281795
_require_cuda_backend(device)

0 commit comments

Comments (0)