Skip to content

Commit e9352e2

Browse files
authored
[BACKEND] Model async TMA variants in ConSan (#10015)
We model the async TMA variants (gather, scatter, reduce) closely after the existing TMA load / store handling.
1 parent f62f95b commit e9352e2

File tree

4 files changed

+91
-27
lines changed

4 files changed

+91
-27
lines changed

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,11 @@ def TMALoadLikeOpInterface : OpInterface<"TMALoadLikeOpInterface", [TMAOpInterfa
4242
/*retType=*/"::mlir::Value",
4343
/*methodName=*/"getPred",
4444
/*args=*/(ins)>,
45+
InterfaceMethod<
46+
/*desc=*/"Return true if this load uses multicast",
47+
/*retType=*/"bool",
48+
/*methodName=*/"getMulticast",
49+
/*args=*/(ins)>,
4550
];
4651
}
4752

lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -254,9 +254,9 @@ SmallVector<uint16_t> getTensorCoreBarrierBroadcastMasks(Operation *op) {
254254
Value getBarrierRecipientCTAs(ImplicitLocOpBuilder &b, Operation *op);
255255

256256
Value getMemEffectRecipientCTAs(ImplicitLocOpBuilder &b, Operation *op) {
257-
if (auto copyOp = dyn_cast<ttng::AsyncTMACopyGlobalToLocalOp>(op)) {
258-
if (copyOp.getMulticast())
259-
return getMulticastRecipientCTAs(b, copyOp.getResult());
257+
if (auto tmaLoad = dyn_cast<ttng::TMALoadLikeOpInterface>(op)) {
258+
if (tmaLoad.getMulticast())
259+
return getMulticastRecipientCTAs(b, tmaLoad.getResult());
260260
return currentCTAMask(b);
261261
}
262262
if (isTensorCoreOp(op))
@@ -272,14 +272,12 @@ Value getBarrierRecipientCTAs(ImplicitLocOpBuilder &b, Operation *op) {
272272
return getLeaderCTA(b, arriveOp.getAlloc());
273273
if (auto arriveOp = dyn_cast<ttng::AsyncCopyMbarrierArriveOp>(op))
274274
return getLeaderCTA(b, arriveOp.getBarrier());
275-
if (auto copyOp = dyn_cast<ttng::AsyncTMACopyGlobalToLocalOp>(op)) {
276-
if (copyOp.getMulticast())
277-
return getMulticastBarrierRecipientCTAs(b, copyOp.getResult(),
278-
copyOp.getBarrier());
279-
return getLeaderCTA(b, copyOp.getBarrier());
280-
}
281-
if (auto tmaLoad = dyn_cast<ttng::TMALoadLikeOpInterface>(op))
275+
if (auto tmaLoad = dyn_cast<ttng::TMALoadLikeOpInterface>(op)) {
276+
if (tmaLoad.getMulticast())
277+
return getMulticastBarrierRecipientCTAs(b, tmaLoad.getResult(),
278+
tmaLoad.getBarrier());
282279
return getLeaderCTA(b, tmaLoad.getBarrier());
280+
}
283281

284282
if (isTensorCoreOp(op))
285283
return getRecipientCTAsForBroadcastMasks(

lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ Value getLeaderCTAPredicate(ImplicitLocOpBuilder &b, uint32_t broadcastMask) {
2828
arith::ConstantIntOp::create(b, 0, 32));
2929
}
3030

31+
uint32_t getBlockBroadcastMask(Type type) {
32+
auto memDescTy = cast<ttg::MemDescType>(type);
33+
auto kBlock = StringAttr::get(type.getContext(), "block");
34+
return toLinearLayout(memDescTy).getFreeVariableMasks().lookup(kBlock);
35+
}
36+
3137
} // namespace
3238

3339
class NVIDIAConSanHooks : public tti::ConSanTargetHooks {
@@ -89,13 +95,12 @@ class NVIDIAConSanHooks : public tti::ConSanTargetHooks {
8995
mask = getBarrierMask(waitOp.getAlloc());
9096
if (auto invalOp = dyn_cast<ttng::InvalBarrierOp>(op))
9197
mask = getBarrierMask(invalOp.getAlloc());
92-
if (auto copyOp = dyn_cast<ttng::AsyncTMACopyGlobalToLocalOp>(op)) {
93-
if (copyOp.getMulticast()) {
94-
auto dstTy = cast<ttg::MemDescType>(copyOp.getResult().getType());
95-
auto kBlock = StringAttr::get(op->getContext(), "block");
96-
mask = toLinearLayout(dstTy).getFreeVariableMasks().lookup(kBlock);
97-
}
98+
if (auto loadOp = dyn_cast<ttng::TMALoadLikeOpInterface>(op)) {
99+
if (loadOp.getMulticast())
100+
mask = getBlockBroadcastMask(loadOp.getResult().getType());
98101
}
102+
if (auto storeOp = dyn_cast<ttng::TMAStoreLikeOpInterface>(op))
103+
mask = getBlockBroadcastMask(storeOp.getSrc().getType());
99104

100105
// In 2CTA tcgen05 and tmem_copy, only the even CTA in each (i, i^1) pair
101106
// issues the op.
@@ -205,16 +210,12 @@ class NVIDIAConSanHooks : public tti::ConSanTargetHooks {
205210
info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier;
206211
info->pred = loadOp.getPred();
207212
int txCount = tti::getMemDescLength(loadOp.getResult());
208-
if (auto copyOp = dyn_cast<ttng::AsyncTMACopyGlobalToLocalOp>(op);
209-
copyOp && copyOp.getMulticast()) {
210-
auto resultTy = cast<ttg::MemDescType>(loadOp.getResult().getType());
211-
auto barrierTy = cast<ttg::MemDescType>(loadOp.getBarrier().getType());
212-
auto kBlock = StringAttr::get(op->getContext(), "block");
213-
uint16_t resultMask =
214-
toLinearLayout(resultTy).getFreeVariableMasks().lookup(kBlock);
215-
uint16_t barrierMask =
216-
toLinearLayout(barrierTy).getFreeVariableMasks().lookup(kBlock);
217-
uint16_t collapsedMask = resultMask & barrierMask;
213+
if (loadOp.getMulticast()) {
214+
uint32_t resultMask =
215+
getBlockBroadcastMask(loadOp.getResult().getType());
216+
uint32_t barrierMask =
217+
getBlockBroadcastMask(loadOp.getBarrier().getType());
218+
uint32_t collapsedMask = resultMask & barrierMask;
218219
for (; collapsedMask; collapsedMask &= collapsedMask - 1)
219220
txCount *= 2;
220221
}

test/TritonGPU/consan.mlir

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,40 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 1 : i32, ttg.shar
293293

294294
// -----
295295

296+
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32, CGALayout = [[0, 0]]}>
297+
#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}>
298+
#smem = #ttg.shared_memory
299+
#offset_parent = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [1, 0], CGALayout = [[0, 0]]}>
300+
#offsets = #ttg.slice<{dim = 0, parent = #offset_parent}>
301+
module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 1 : i32, "ttng.two-ctas" = true, ttg.shared = 65544 : i32, ttg.target = "cuda:100", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} {
302+
// CHECK-LABEL: @multicast_gather_two_cta_tx_count
303+
tt.func public @multicast_gather_two_cta_tx_count(%desc: !tt.tensordesc<1x32xf32, #shared>) {
304+
%true = arith.constant true
305+
%c0_i32 = arith.constant 0 : i32
306+
%c0 = arith.constant 0 : index
307+
%c1 = arith.constant 1 : index
308+
%c2 = arith.constant 2 : index
309+
%x_offsets = arith.constant dense<0> : tensor<32xi32, #offsets>
310+
%bar = ttg.local_alloc {allocation.offset = 65536 : i32} : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable>
311+
%result = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable>
312+
ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable>
313+
// CHECK: scf.for
314+
scf.for %i = %c0 to %c2 step %c1 {
315+
// CHECK: arith.constant 8192 : i64
316+
// CHECK: tt.call @__triton_consan_verify_barrier_arrive
317+
// CHECK: ttng.barrier_expect
318+
ttng.barrier_expect %bar, 4096, %true : !ttg.memdesc<1xi64, #shared1, #smem, mutable>
319+
// CHECK: arith.constant -8192 : i64
320+
// CHECK: tt.call @__triton_consan_verify_barrier_arrive
321+
// CHECK: ttng.async_tma_gather
322+
ttng.async_tma_gather %desc[%x_offsets, %c0_i32] %result, %bar, %true {multicast} : !tt.tensordesc<1x32xf32, #shared>, tensor<32xi32, #offsets>, i32, !ttg.memdesc<1xi64, #shared1, #smem, mutable>, !ttg.memdesc<32x32xf32, #shared, #smem, mutable>, i1
323+
}
324+
tt.return
325+
}
326+
}
327+
328+
// -----
329+
296330
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}>
297331
#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
298332
#smem = #ttg.shared_memory
@@ -479,9 +513,35 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
479513
ttng.warp_group_dot %shmem, %shmem, %acc : !ttg.memdesc<128x128xf16, #shared, #smem, mutable> * !ttg.memdesc<128x128xf16, #shared, #smem, mutable> -> tensor<128x128xf16, #mma>
480514
// CHECK: ttng.warp_group_dot
481515

516+
// CHECK: tt.call @__triton_consan_verify_write_visibility
517+
// CHECK: tt.call @__triton_consan_check_outstanding_commits
518+
// CHECK: tt.call @__triton_consan_stage_access_for_commit
519+
// CHECK: tt.call @__triton_consan_commit_accesses
520+
ttng.async_tma_scatter %arg0[%x_offsets, %c0_i32] %0 : !tt.tensordesc<1x32xf32, #shared>, tensor<32xi32>, i32, !ttg.memdesc<32x32xf32, #shared, #smem, mutable>
521+
tt.return
522+
}
523+
}
524+
525+
// -----
526+
527+
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}>
528+
#smem = #ttg.shared_memory
529+
#blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
530+
#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 32, 16]}>
531+
532+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65536 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} {
533+
// CHECK-LABEL: @async_tma_reduce
534+
tt.func public @async_tma_reduce(%arg0: !tt.tensordesc<32x32xf32, #shared>, %ptr: tensor<128x128x!tt.ptr<f16>, #blocked>, %acc: tensor<128x128xf16, #mma>) {
535+
%c0_i32 = arith.constant 0 : i32
536+
%0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable>
537+
%shmem = ttg.local_alloc {allocation.offset = 4096 : i32} : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable>
538+
ttg.async_copy_global_to_local %ptr, %shmem : tensor<128x128x!tt.ptr<f16>, #blocked> -> <128x128xf16, #shared, #smem, mutable>
539+
482540
// CHECK: tt.call @__triton_consan_verify_write_visibility
483541
// CHECK: tt.call @__triton_consan_check_outstanding_commits
484-
ttng.async_tma_scatter %arg0[%x_offsets, %c0_i32] %0 : !tt.tensordesc<1x32xf32, #shared>, tensor<32xi32>, i32, !ttg.memdesc<32x32xf32, #shared, #smem, mutable>
542+
// CHECK: tt.call @__triton_consan_stage_access_for_commit
543+
// CHECK: tt.call @__triton_consan_commit_accesses
544+
ttng.async_tma_reduce add, %arg0[%c0_i32, %c0_i32] %0 : !tt.tensordesc<32x32xf32, #shared>, !ttg.memdesc<32x32xf32, #shared, #smem, mutable>
485545
tt.return
486546
}
487547
}

0 commit comments

Comments (0)