Skip to content

Commit 76fbef9

Browse files
authored
[BACKEND][EZ] Tighten TMA multicta verifier (#9941)
As per the PTX docs, TMAs have a very specific behaviour when executed in a 2CTA kernel: > .cta_group::1 : The mbarrier signal is also multicasted to the same offset as mbar in the shared memory of the destination CTA. .cta_group::2 : The mbarrier signal is multicasted either to all the odd numbered CTAs or the even numbered CTAs within the corresponding CTA-Pair. For each destination CTA specified in the ctaMask, the mbarrier signal is sent either to the destination CTA or its peer-CTA, based on the CTA's %cluster_ctarank parity of the shared memory where the mbarrier object mbar resides. As such, we require these CTA layouts in TMA barriers.
1 parent 89d154d commit 76fbef9

6 files changed

Lines changed: 146 additions & 128 deletions

File tree

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

Lines changed: 58 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,62 @@ LogicalResult ClusterBarrierOp::verify() {
290290
}
291291

292292
// -- TMA operation verifiers --
293+
static std::string formatCGALayout(CGAEncodingAttr cgaLayout) {
294+
std::string str;
295+
llvm::raw_string_ostream os(str);
296+
auto kBlock = StringAttr::get(cgaLayout.getContext(), "block");
297+
os << "[";
298+
llvm::interleaveComma(cgaLayout.getLinearLayout().getBases().lookup(kBlock),
299+
os, [&](const auto &basis) {
300+
os << "[";
301+
llvm::interleaveComma(basis, os);
302+
os << "]";
303+
});
304+
os << "]";
305+
return os.str();
306+
}
307+
308+
static LogicalResult verifyBarrierCGALayout(Operation *op, Value barrier,
309+
CGAEncodingAttr expectedCGALayout,
310+
StringRef barrierName) {
311+
auto barrierTy = cast<MemDescType>(barrier.getType());
312+
auto actualCGALayout = getCGALayout(barrierTy.getEncoding());
313+
if (actualCGALayout != expectedCGALayout)
314+
return op->emitOpError() << barrierName << " cga_layout must be "
315+
<< formatCGALayout(expectedCGALayout) << ", got "
316+
<< formatCGALayout(actualCGALayout);
317+
return success();
318+
}
319+
320+
static LogicalResult verifyCompletionBarrierLayout(Operation *op,
321+
Value barrier) {
322+
auto expectedCGALayout =
323+
CGAEncodingAttr::get1DLayout(op->getContext(), gpu::lookupNumCTAs(op));
324+
return verifyBarrierCGALayout(op, barrier, expectedCGALayout,
325+
"completion barrier");
326+
}
327+
328+
static LogicalResult verifyTMABarrierLayout(Operation *op, Value barrier) {
329+
auto twoCTAsAttr =
330+
op->getParentOfType<ModuleOp>()->getAttrOfType<BoolAttr>(AttrTwoCTAsName);
331+
if (!twoCTAsAttr)
332+
return success();
333+
334+
auto ctx = op->getContext();
335+
int numCTAs = gpu::lookupNumCTAs(op);
336+
CGAEncodingAttr expectedCGALayout;
337+
if (twoCTAsAttr.getValue()) {
338+
auto kBlock = StringAttr::get(ctx, "block");
339+
auto dim = standardOutDimNames(ctx, /*rank=*/1)[0];
340+
auto layout = LinearLayout::zeros1D(2, kBlock, dim) *
341+
LinearLayout::identity1D(numCTAs / 2, kBlock, dim);
342+
expectedCGALayout = CGAEncodingAttr::get(ctx, std::move(layout));
343+
} else {
344+
expectedCGALayout = CGAEncodingAttr::get1DLayout(ctx, numCTAs);
345+
}
346+
return verifyBarrierCGALayout(op, barrier, expectedCGALayout, "TMA barrier");
347+
}
348+
293349
static LogicalResult verifyTMAEncoding(Operation *op, TensorDescInterface desc,
294350
Attribute enc) {
295351
auto nvmma = dyn_cast<NVMMASharedEncodingAttr>(enc);
@@ -318,6 +374,8 @@ static LogicalResult verifyAsyncTMALoadOp(Operation *op,
318374
MemDescType resultType) {
319375
if (failed(verifyBarrierType(op, barrier.getType())))
320376
return failure();
377+
if (failed(verifyTMABarrierLayout(op, barrier)))
378+
return failure();
321379
if (!resultType.getMutableMemory())
322380
return op->emitOpError("cannot store into immutable memory");
323381
if (failed(verifyTMAEncoding(op, desc, resultType.getEncoding())))
@@ -599,34 +657,6 @@ static LogicalResult verifyMMADType(Operation *op, Type a, Type b, Type d) {
599657
return success();
600658
}
601659

602-
static std::string formatCGALayout(CGAEncodingAttr cgaLayout) {
603-
std::string str;
604-
llvm::raw_string_ostream os(str);
605-
auto kBlock = StringAttr::get(cgaLayout.getContext(), "block");
606-
os << "[";
607-
llvm::interleaveComma(cgaLayout.getLinearLayout().getBases().lookup(kBlock),
608-
os, [&](const auto &basis) {
609-
os << "[";
610-
llvm::interleaveComma(basis, os);
611-
os << "]";
612-
});
613-
os << "]";
614-
return os.str();
615-
}
616-
617-
static LogicalResult verifyCompletionBarrierLayout(Operation *op,
618-
Value barrier) {
619-
auto barrierTy = cast<MemDescType>(barrier.getType());
620-
auto expectedCGALayout =
621-
CGAEncodingAttr::get1DLayout(op->getContext(), gpu::lookupNumCTAs(op));
622-
auto actualCGALayout = getCGALayout(barrierTy.getEncoding());
623-
if (actualCGALayout != expectedCGALayout)
624-
return op->emitOpError("completion barrier cga_layout must be ")
625-
<< formatCGALayout(expectedCGALayout) << ", got "
626-
<< formatCGALayout(actualCGALayout);
627-
return success();
628-
}
629-
630660
LogicalResult TCGen5MMAOp::verify() {
631661
if (!getIsAsync() && !getBarriers().empty()) {
632662
return emitOpError("The op is synchronous but a barrier is present.");

python/test/gluon/test_consan.py

Lines changed: 24 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -158,13 +158,10 @@ def test_consan_uses_profile_scratch(device, fresh_knobs, num_ctas):
158158

159159
@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer")
160160
@pytest.mark.parametrize("FAILURE", [True, False])
161-
@pytest.mark.parametrize("TWO_CTA_BARRIER", [False, True])
162-
def test_async_tma_kernel(FAILURE, TWO_CTA_BARRIER, device, run_wrapper, monkeypatch, num_ctas):
163-
if TWO_CTA_BARRIER and num_ctas == 1:
164-
pytest.skip("Need at least 2 CTAs for a two-CTA barrier")
161+
def test_async_tma_kernel(FAILURE, device, run_wrapper, monkeypatch, num_ctas):
165162
if run_wrapper:
166-
result = run_in_process(test_async_tma_kernel, (FAILURE, TWO_CTA_BARRIER, device, False, monkeypatch, num_ctas))
167-
if FAILURE or TWO_CTA_BARRIER:
163+
result = run_in_process(test_async_tma_kernel, (FAILURE, device, False, monkeypatch, num_ctas))
164+
if FAILURE:
168165
assert_expected_cuda_failure(result.exc)
169166
assert "Buffer being accessed has outstanding writes" in result.driver_stderr_output
170167
else:
@@ -177,13 +174,13 @@ def test_async_tma_kernel(FAILURE, TWO_CTA_BARRIER, device, run_wrapper, monkeyp
177174
knobs.refresh_knobs()
178175

179176
@gluon.jit
180-
def kernel(input_desc, out, FAILURE: ttgl.constexpr, TWO_CTA_BARRIER: ttgl.constexpr):
177+
def kernel(input_desc, out, FAILURE: ttgl.constexpr):
181178
block_m: ttgl.constexpr = XBLOCK * ttgl.num_ctas()
182179
cga_layout: ttgl.constexpr = default_cga_layout(ttgl.num_ctas(), 2)
183180
blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 1], threads_per_warp=[32, 1],
184181
warps_per_cta=[4, 1], order=[0, 1], cga_layout=cga_layout)
185182
smem = ttgl.allocate_shared_memory(ttgl.float16, [block_m, XBLOCK], input_desc.layout)
186-
bar = mbarrier.allocate_mbarrier(two_ctas=TWO_CTA_BARRIER)
183+
bar = mbarrier.allocate_mbarrier()
187184
mbarrier.init(bar, count=1)
188185
mbarrier.expect(bar, input_desc.nbytes_per_cta)
189186
tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem)
@@ -203,19 +200,17 @@ def kernel(input_desc, out, FAILURE: ttgl.constexpr, TWO_CTA_BARRIER: ttgl.const
203200
shared_layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2,
204201
cga_layout=default_cga_layout(num_ctas, 2))
205202
input_desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor(input, [block_m, XBLOCK.value], shared_layout)
206-
kernel[(1, )](input_desc, output, FAILURE=FAILURE, TWO_CTA_BARRIER=TWO_CTA_BARRIER, num_warps=4, num_ctas=num_ctas)
203+
kernel[(1, )](input_desc, output, FAILURE=FAILURE, num_warps=4, num_ctas=num_ctas)
207204

208205

209206
@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer")
210207
@pytest.mark.parametrize("FAILURE", [True, False])
211-
@pytest.mark.parametrize("TWO_CTA_BARRIER", [False, True])
212-
def test_async_tma_multicast_kernel(FAILURE, TWO_CTA_BARRIER, device, run_wrapper, monkeypatch, num_ctas):
208+
def test_async_tma_multicast_kernel(FAILURE, device, run_wrapper, monkeypatch, num_ctas):
213209
if num_ctas == 1:
214210
pytest.skip("Need at least 2 CTAs for multicast in this test")
215211
if run_wrapper:
216-
result = run_in_process(test_async_tma_multicast_kernel,
217-
(FAILURE, TWO_CTA_BARRIER, device, False, monkeypatch, num_ctas))
218-
if FAILURE or TWO_CTA_BARRIER:
212+
result = run_in_process(test_async_tma_multicast_kernel, (FAILURE, device, False, monkeypatch, num_ctas))
213+
if FAILURE:
219214
assert_expected_cuda_failure(result.exc)
220215
assert "Buffer being accessed has outstanding writes" in result.driver_stderr_output
221216
else:
@@ -228,12 +223,12 @@ def test_async_tma_multicast_kernel(FAILURE, TWO_CTA_BARRIER, device, run_wrappe
228223
knobs.refresh_knobs()
229224

230225
@gluon.jit
231-
def kernel(input_desc, out, FAILURE: ttgl.constexpr, TWO_CTA_BARRIER: ttgl.constexpr):
226+
def kernel(input_desc, out, FAILURE: ttgl.constexpr):
232227
cga_layout: ttgl.constexpr = multicast_cga_layout(ttgl.num_ctas(), 2)
233228
blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 1], threads_per_warp=[32, 1],
234229
warps_per_cta=[4, 1], order=[0, 1], cga_layout=cga_layout)
235230
smem = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK], input_desc.layout)
236-
bar = mbarrier.allocate_mbarrier(two_ctas=TWO_CTA_BARRIER)
231+
bar = mbarrier.allocate_mbarrier()
237232
mbarrier.init(bar, count=1)
238233
mbarrier.expect(bar, input_desc.nbytes_per_cta)
239234
tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem, multicast=True)
@@ -252,54 +247,7 @@ def kernel(input_desc, out, FAILURE: ttgl.constexpr, TWO_CTA_BARRIER: ttgl.const
252247
shared_layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2,
253248
cga_layout=multicast_cga_layout(num_ctas, 2))
254249
input_desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor(input, [XBLOCK.value, XBLOCK.value], shared_layout)
255-
kernel[(1, )](input_desc, output, FAILURE=FAILURE, TWO_CTA_BARRIER=TWO_CTA_BARRIER, num_warps=4, num_ctas=num_ctas)
256-
257-
258-
@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer")
259-
@pytest.mark.parametrize("TWO_CTA_BARRIER", [False, True])
260-
def test_async_tma_multicast_kernel_two_cta_barrier(TWO_CTA_BARRIER, device, run_wrapper, monkeypatch, num_ctas):
261-
if num_ctas != 2:
262-
pytest.skip("This test covers a single 2-CTA multicast group")
263-
if run_wrapper:
264-
result = run_in_process(test_async_tma_multicast_kernel_two_cta_barrier,
265-
(TWO_CTA_BARRIER, device, False, monkeypatch, num_ctas))
266-
if TWO_CTA_BARRIER:
267-
assert_expected_cuda_failure(result.exc)
268-
assert "Buffer being accessed has outstanding writes" in result.driver_stderr_output
269-
else:
270-
assert result.exc is None
271-
assert result.driver_stderr_output == ""
272-
return
273-
274-
monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan")
275-
monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1")
276-
knobs.refresh_knobs()
277-
278-
@gluon.jit
279-
def kernel(input_desc, out, TWO_CTA_BARRIER: ttgl.constexpr):
280-
cga_layout: ttgl.constexpr = multicast_cga_layout(ttgl.num_ctas(), 2)
281-
blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 1], threads_per_warp=[32, 1],
282-
warps_per_cta=[4, 1], order=[0, 1], cga_layout=cga_layout)
283-
smem = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK], input_desc.layout)
284-
bar = mbarrier.allocate_mbarrier(two_ctas=TWO_CTA_BARRIER)
285-
mbarrier.init(bar, count=1)
286-
mbarrier.expect(bar, input_desc.nbytes_per_cta)
287-
tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem, multicast=True)
288-
mbarrier.wait(bar, 0, deps=[smem])
289-
val = smem.load(blocked_layout)
290-
mbarrier.invalidate(bar)
291-
292-
out_m = ttgl.arange(0, XBLOCK, ttgl.SliceLayout(1, blocked_layout))[:, None]
293-
out_n = ttgl.arange(0, XBLOCK, ttgl.SliceLayout(0, blocked_layout))[None, :]
294-
out_ptr = out + out_m * XBLOCK + out_n
295-
ttgl.store(out_ptr, val)
296-
297-
input = torch.randn((XBLOCK.value, XBLOCK.value), device=device, dtype=torch.float16)
298-
output = torch.empty((XBLOCK.value, XBLOCK.value), device=device, dtype=torch.float16)
299-
shared_layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2,
300-
cga_layout=multicast_cga_layout(num_ctas, 2))
301-
input_desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor(input, [XBLOCK.value, XBLOCK.value], shared_layout)
302-
kernel[(1, )](input_desc, output, TWO_CTA_BARRIER=TWO_CTA_BARRIER, num_warps=4, num_ctas=num_ctas)
250+
kernel[(1, )](input_desc, output, FAILURE=FAILURE, num_warps=4, num_ctas=num_ctas)
303251

304252

305253
@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer")
@@ -741,21 +689,24 @@ def kernel(input_desc, output_desc, FAILURE: ttgl.constexpr, MEM_ACCESS_KIND: tt
741689
ttgl.NVMMASharedLayout.get_default_for([XBLOCK, block_n], ttgl.float16,
742690
cga_layout=mma_cga_layout(ttgl.num_ctas(), 1, TWO_CTAS)),
743691
)
744-
bar = mbarrier.allocate_mbarrier(batch=2)
692+
mma_bar = mbarrier.allocate_mbarrier()
745693
acc = blackwell.allocate_tensor_memory(ttgl.float32, [block_m, block_n], acc_layout)
746-
mbarrier.init(bar.index(0), count=1)
747-
mbarrier.init(bar.index(1), count=1)
694+
mbarrier.init(mma_bar, count=1)
695+
if MEM_ACCESS_KIND == "tma_cp":
696+
tma_bar = mbarrier.allocate_mbarrier(two_ctas=TWO_CTAS)
697+
mbarrier.init(tma_bar, count=1)
748698

749699
blackwell.tcgen05_mma(smemA, smemB, acc)
750-
blackwell.tcgen05_commit(bar.index(0))
700+
blackwell.tcgen05_commit(mma_bar)
751701

752702
if not FAILURE:
753-
mbarrier.wait(bar.index(0), 0)
703+
mbarrier.wait(mma_bar, 0)
754704

755705
if MEM_ACCESS_KIND == "tma_cp":
756-
mbarrier.expect(bar.index(1), input_desc.nbytes_per_cta)
757-
tma.async_copy_global_to_shared(input_desc, [0, 0], bar.index(1), smemA)
758-
mbarrier.wait(bar.index(1), 0)
706+
mbarrier.expect(tma_bar, input_desc.nbytes_per_cta)
707+
tma.async_copy_global_to_shared(input_desc, [0, 0], tma_bar, smemA)
708+
mbarrier.wait(tma_bar, 0)
709+
mbarrier.invalidate(tma_bar)
759710
elif MEM_ACCESS_KIND == "local_store":
760711
smemA.store(ttgl.full([block_m, XBLOCK], 42, ttgl.float16, smem_a_blocked_layout))
761712
elif MEM_ACCESS_KIND == "tmem_load":
@@ -769,8 +720,7 @@ def kernel(input_desc, output_desc, FAILURE: ttgl.constexpr, MEM_ACCESS_KIND: tt
769720
elif MEM_ACCESS_KIND == "tmem_store":
770721
acc.store(ttgl.full([block_m, block_n], 42, ttgl.float32, acc_blocked_layout))
771722

772-
mbarrier.invalidate(bar.index(0))
773-
mbarrier.invalidate(bar.index(1))
723+
mbarrier.invalidate(mma_bar)
774724

775725
block_m = mma_block_m(num_ctas)
776726
block_n = mma_block_n(num_ctas)

python/tutorials/gluon/14-multicta.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -477,8 +477,9 @@ def broadcast(b):
477477
# every CTA in the multicast group atomically, so the wait side does not need a
478478
# different API.
479479
#
480-
# The only new ingredient is the layout. The TMA destination must use a
481-
# broadcast `cga_layout`, so that both CTAs view the same shared-memory tile.
480+
# The TMA destination must use a broadcast `cga_layout`, so that both CTAs
481+
# receive the same shared-memory tile. The barrier stays a regular 1D TMA
482+
# barrier unless the kernel is in 2CTA mode.
482483
#
483484
# The example below keeps things intentionally simple: it multicasts one tile
484485
# into shared memory and then materializes that same tile back to global memory.
@@ -489,6 +490,7 @@ def tma_multicast_copy_kernel(in_desc, out_desc):
489490
gl.static_assert(gl.num_ctas() == 2)
490491

491492
smem = gl.allocate_shared_memory(in_desc.dtype, in_desc.block_shape, in_desc.layout)
493+
# This kernel is not in 2CTA mode, so the TMA barrier is per-CTA.
492494
bar = mbarrier.allocate_mbarrier()
493495
mbarrier.init(bar, count=1)
494496

test/Conversion/tritonnvidiagpu_to_llvm.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,14 +167,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
167167
#shared0_cluster = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}>
168168
#shared1_cga = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, CGALayout = [[0, 0]]}>
169169
#smem = #ttg.shared_memory
170-
module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32} {
170+
module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, "ttng.two-ctas" = true} {
171171
// CHECK-LABEL: tma_copy_barrier_mask_nonzero
172172
// Barrier pointer is modified when barrier mask != 0
173173
// CHECK: llvm.ptrtoint
174174
// CHECK: llvm.and
175175
// CHECK: llvm.inttoptr
176176
// TMA uses shared::cluster when barrier mask is non-zero
177-
// CHECK: cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes
177+
// CHECK: cp.async.bulk.tensor.2d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes
178178
// CHECK-NOT: cp.async.bulk.tensor.2d.shared::cta.global.mbarrier
179179
tt.func @tma_copy_barrier_mask_nonzero(%tma: !tt.tensordesc<128x128xf32, #shared1_cga>, %alloc: !ttg.memdesc<128x128xf32, #shared1_cga, #smem, mutable>, %x: i32, %barrier: !ttg.memdesc<1xi64, #shared0_cluster, #smem>, %pred: i1) {
180180
ttng.async_tma_copy_global_to_local %tma[%x, %x] %alloc, %barrier, %pred : !tt.tensordesc<128x128xf32, #shared1_cga>, !ttg.memdesc<1xi64, #shared0_cluster, #smem> -> !ttg.memdesc<128x128xf32, #shared1_cga, #smem, mutable>

0 commit comments

Comments
 (0)