Commit b814e57
[MultiCTA] Implement broadcast semantics in mbarrier.arrive (#9475)
Broadcast semantics were missing from `mbarrier.arrive`, and they are necessary for multi-CTA warp-specialised kernels. We also fix `mbarrier.init` to rescale the number of arrivals on the leader CTA to match this pattern. In `mbarrier.expect` we emit the `mbarrier.expect` from the leader CTA (which acts as an `mbarrier.arrive bar, 1`), and we emit an `mbarrier.arrive bar, 1` from the non-leader CTAs, consistent with the semantics above. A nice corollary is that `expect` now also has release semantics. We do not expect this to cost performance in real kernels. Finally, we implement these semantics in two helper functions and use them across the codebase.
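To make the init rescaling concrete, here is a minimal standalone sketch (hypothetical names, not the Triton API) of the arithmetic described above: every CTA that broadcasts into the same barrier contributes its own arrivals, so the one physical barrier on the leader CTA must expect all of them.

#include <cassert>
#include <cstdint>

// count: arrival count requested by ttng.init_barrier.
// numCTAs: CTAs in the cluster.
// numBarrierCopies: distinct copies of the barrier across the CGA
// (1 when every CTA broadcasts into the leader CTA's copy).
uint32_t leaderInitCount(uint32_t count, uint32_t numCTAs,
                         uint32_t numBarrierCopies) {
  assert(numCTAs % numBarrierCopies == 0);
  return count * (numCTAs / numBarrierCopies);
}

int main() {
  // 2 CTAs broadcasting into one barrier copy, count 1: the leader CTA
  // initialises the barrier with 2 expected arrivals.
  assert(leaderInitCount(1, 2, 1) == 2);
}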
1 parent 34bdc40 commit b814e57

5 files changed: 183 additions and 57 deletions


test/Conversion/tritonnvidiagpu_to_llvm.mlir

Lines changed: 59 additions & 2 deletions
@@ -13,6 +13,20 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 
 // -----
 
+#shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32} {
+  // CHECK-LABEL: init_barrier_cluster_broadcast
+  tt.func @init_barrier_cluster_broadcast(%alloc: !ttg.memdesc<1xi64, #shared0, #smem>) {
+    // CHECK: nvg.cluster_id
+    // CHECK: @$0 mbarrier.init.shared::cta.b64 [$1], 2;
+    ttng.init_barrier %alloc, 1 : !ttg.memdesc<1xi64, #shared0, #smem>
+    tt.return
+  }
+}
+
+// -----
+
 #shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
@@ -45,30 +59,55 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 
   // CHECK-LABEL: arrive_barrier
   tt.func @arrive_barrier(%alloc: !ttg.memdesc<1xi64, #shared0, #smem>) {
+    // CHECK-NEXT: [[BASE:%.*]] = llvm.extractvalue %arg0[0] : !llvm.struct<(ptr<3>, i32)>
+    // CHECK-NEXT: llvm.extractvalue %arg0[1] : !llvm.struct<(ptr<3>, i32)>
+    // CHECK-NEXT: nvvm.barrier0
     // CHECK-NEXT: [[TID:%.*]] = nvvm.read.ptx.sreg.tid.x
     // CHECK-NEXT: [[C127:%.*]] = llvm.mlir.constant(127 : i32)
     // CHECK-NEXT: [[RTID:%.*]] = llvm.and [[TID]], [[C127]]
     // CHECK-NEXT: [[C0:%.*]] = llvm.mlir.constant(0 : i32)
     // CHECK-NEXT: [[IS_ZERO:%.*]] = llvm.icmp "eq" [[RTID]], [[C0]]
-    // CHECK-NEXT: "@$0 mbarrier.arrive.shared::cta.b64 _, [$1], 2;", "b,r" [[IS_ZERO]], %arg0
+    // CHECK-NEXT: "@$0 mbarrier.arrive.shared::cta.b64 _, [$1], 2;", "b,r" [[IS_ZERO]], [[BASE]]
     ttng.arrive_barrier %alloc, 2 : !ttg.memdesc<1xi64, #shared0, #smem>
     tt.return
   }
 
   // CHECK-LABEL: arrive_barrier_pred
   tt.func @arrive_barrier_pred(%alloc: !ttg.memdesc<1xi64, #shared0, #smem>, %pred: i1) {
+    // CHECK-NEXT: [[BASE:%.*]] = llvm.extractvalue %arg0[0] : !llvm.struct<(ptr<3>, i32)>
+    // CHECK-NEXT: llvm.extractvalue %arg0[1] : !llvm.struct<(ptr<3>, i32)>
+    // CHECK-NEXT: nvvm.barrier0
    // CHECK-NEXT: [[TID:%.*]] = nvvm.read.ptx.sreg.tid.x
    // CHECK-NEXT: [[C127:%.*]] = llvm.mlir.constant(127 : i32)
    // CHECK-NEXT: [[RTID:%.*]] = llvm.and [[TID]], [[C127]]
    // CHECK-NEXT: [[C0:%.*]] = llvm.mlir.constant(0 : i32)
    // CHECK-NEXT: [[IS_ZERO:%.*]] = llvm.icmp "eq" [[RTID]], [[C0]]
    // CHECK-NEXT: [[PRED:%.*]] = llvm.and [[IS_ZERO]], %arg1
-    // CHECK-NEXT: "@$0 mbarrier.arrive.shared::cta.b64 _, [$1], 2;", "b,r" [[PRED]], %arg0
+    // CHECK-NEXT: "@$0 mbarrier.arrive.shared::cta.b64 _, [$1], 2;", "b,r" [[PRED]], [[BASE]]
     ttng.arrive_barrier %alloc, 2, %pred : !ttg.memdesc<1xi64, #shared0, #smem>
     tt.return
   }
 }
 
+// -----
+
+#shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32} {
+  // CHECK-LABEL: arrive_barrier_cluster_broadcast
+  tt.func @arrive_barrier_cluster_broadcast(%alloc: !ttg.memdesc<1xi64, #shared0, #smem>) {
+    // CHECK: nvvm.barrier0
+    // CHECK: nvg.cluster_id
+    // CHECK: llvm.ptrtoint
+    // CHECK: llvm.and
+    // CHECK: llvm.inttoptr
+    // CHECK: mbarrier.arrive.shared::cluster.b64
+    // CHECK-NOT: mbarrier.arrive.shared::cta.b64
+    ttng.arrive_barrier %alloc, 1 : !ttg.memdesc<1xi64, #shared0, #smem>
+    tt.return
+  }
+}
+
 
 // -----
 
@@ -218,6 +257,24 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 
 // -----
 
+#shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32} {
+  // CHECK-LABEL: expect_barrier_cluster_broadcast
+  // CHECK: nvg.cluster_id
+  // CHECK: llvm.ptrtoint
+  // CHECK: llvm.and
+  // CHECK: llvm.inttoptr
+  // CHECK: @$0 mbarrier.arrive.expect_tx.shared::cta.b64 _, [$1], 32768;
+  // CHECK: @$0 mbarrier.arrive.shared::cluster.b64 _, [$1], 1;
+  tt.func @expect_barrier_cluster_broadcast(%barrier: !ttg.memdesc<1xi64, #shared0, #smem, mutable>, %pred: i1) {
+    ttng.barrier_expect %barrier, 16384, %pred : !ttg.memdesc<1xi64, #shared0, #smem, mutable>
+    tt.return
+  }
+}
+
+// -----
+
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
   // CHECK-LABEL: byval_tma_desc
   // CHECK: llvm.align = 64
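The constants in the two new cluster tests above follow the same rescaling; here is a worked check (hypothetical helper, mirroring expectedBytes = op.getSize() * (numCTAs / barrierTy.getNumElements()) from the lowering below):

#include <cassert>
#include <cstdint>

uint32_t leaderExpectedBytes(uint32_t bytesPerCTA, uint32_t numCTAs,
                             uint32_t numBarrierCopies) {
  return bytesPerCTA * (numCTAs / numBarrierCopies);
}

int main() {
  // ttng.barrier_expect of 16384 bytes, 2 CTAs, one barrier copy: the leader
  // registers 32768 bytes, the constant in the expect_tx CHECK line above.
  assert(leaderExpectedBytes(16384, 2, 1) == 32768);
}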

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp

Lines changed: 76 additions & 47 deletions
@@ -111,17 +111,28 @@ struct InitBarrierOpConversion
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
     auto b = TritonLLVMOpBuilder(loc, rewriter);
+    auto barrierTy = op.getAlloc().getType();
     auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
         loc, adaptor.getAlloc(),
-        typeConverter->convertType(op.getAlloc().getType().getElementType()),
-        rewriter);
+        typeConverter->convertType(barrierTy.getElementType()), rewriter);
 
     // We use an elect predicate to tell ptxas that the operation is uniform,
     // which results in better codegen.
     Value pred = getElectWarp0OrThread0(*targetInfo, b);
+
+    if (auto leaderPred =
+            LLVM::NVIDIA::getLeaderCTAPredicate(loc, rewriter, barrierTy))
+      pred = b.and_(pred, *leaderPred);
+
+    auto numCTAs = triton::gpu::lookupNumCTAs(op);
+    auto initCount = op.getCount();
+    // The lead barrier accounts for all arrives from CTAs that broadcast into
+    // the same barrier.
+    initCount *= numCTAs / barrierTy.getNumElements();
+
     ::mlir::triton::PTXBuilder ptxBuilder;
     const std::string ptx = "@$0 mbarrier.init.shared::cta.b64 [$1], " +
-                            std::to_string(op.getCount()) + ";";
+                            std::to_string(initCount) + ";";
     auto &barSyncOp = *ptxBuilder.create(ptx);
     barSyncOp({ptxBuilder.newOperand(pred, "b"),
                ptxBuilder.newOperand(smemObj.getBase(), "r")},
@@ -188,31 +199,40 @@ struct BarrierExpectConversion
     auto expectedBytes = op.getSize() * (numCTAs / barrierTy.getNumElements());
 
     auto id = getThreadId(rewriter, loc);
-    Value pred = b.icmp_eq(id, b.i32_val(0));
-    pred = b.and_(pred, adaptor.getPred());
-
-    auto kBlock = StringAttr::get(op->getContext(), "block");
-    auto maskCGABroadcast =
-        toLinearLayout(barrierTy).getFreeVariableMasks().lookup(kBlock);
-    if (maskCGABroadcast) {
-      // If several CTAs cast to the same barrier, as when we do a TMA into a
-      // tcgen05.mma 2CTA, we just register the expect in the lead barrier, as
-      // it is the only one that will receive the mbarrier signals
-      auto ctaId = nvgpu::ClusterCTAIdOp::create(rewriter, loc);
-      auto ctaIdInGroup = b.and_(ctaId, b.i32_val(maskCGABroadcast));
-      pred = b.and_(pred, b.icmp_eq(ctaIdInGroup, b.i32_val(0)));
-    }
-
-    ::mlir::triton::PTXBuilder ptxBuilder;
-    const std::string ptx =
+    Value basePred = b.icmp_eq(id, b.i32_val(0));
+    basePred = b.and_(basePred, adaptor.getPred());
+    auto leaderCTAPred =
+        LLVM::NVIDIA::getLeaderCTAPredicate(loc, rewriter, barrierTy);
+    bool crossCluster = leaderCTAPred.has_value();
+    Value leaderPred =
+        leaderCTAPred ? b.and_(basePred, *leaderCTAPred) : basePred;
+    Value leaderBarrierPtr = LLVM::NVIDIA::getLeaderAddress(
+        loc, rewriter, smemObj.getBase(), barrierTy);
+
+    ::mlir::triton::PTXBuilder expectPtxBuilder;
+    const std::string expectPtx =
         "@$0 mbarrier.arrive.expect_tx.shared::cta.b64 _, [$1], " +
         std::to_string(expectedBytes) + ";";
-    auto &barSyncOp = *ptxBuilder.create(ptx);
-    barSyncOp({ptxBuilder.newOperand(pred, "b"),
-               ptxBuilder.newOperand(smemObj.getBase(), "r")},
-              /*onlyAttachMLIRArgs=*/true);
+    auto &expectOp = *expectPtxBuilder.create(expectPtx);
+    expectOp({expectPtxBuilder.newOperand(leaderPred, "b"),
+              expectPtxBuilder.newOperand(leaderBarrierPtr, "r")},
+             /*onlyAttachMLIRArgs=*/true);
     auto voidTy = void_ty(op->getContext());
-    ptxBuilder.launch(rewriter, loc, voidTy);
+    expectPtxBuilder.launch(rewriter, loc, voidTy);
+
+    if (crossCluster) {
+      // Non-leader CTAs still contribute one arrival to the lead CTA barrier.
+      auto nonLeaderPred = b.and_(basePred, b.xor_(leaderPred, b.true_val()));
+      ::mlir::triton::PTXBuilder arrivePtxBuilder;
+      const std::string arrivePtx =
+          "@$0 mbarrier.arrive.shared::cluster.b64 _, [$1], 1;";
+      auto &arriveOp = *arrivePtxBuilder.create(arrivePtx);
+      arriveOp({arrivePtxBuilder.newOperand(nonLeaderPred, "b"),
+                arrivePtxBuilder.newOperand(leaderBarrierPtr, "r")},
+               /*onlyAttachMLIRArgs=*/true);
+      arrivePtxBuilder.launch(rewriter, loc, voidTy);
+    }
+
     rewriter.eraseOp(op);
     return success();
   }
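A minimal sketch of the predicate split above, with plain bools standing in for the LLVM i1 values: the leader CTA issues the expect_tx (which also counts as one arrival), and every other CTA issues a plain arrive on the leader's barrier, so exactly one of the two instructions fires per CTA.

#include <cassert>
#include <optional>

struct ExpectPredicates {
  bool leader;    // guards mbarrier.arrive.expect_tx on the lead barrier
  bool nonLeader; // guards the shared::cluster arrive from other CTAs
};

// basePred: thread 0 of the CTA, and-ed with the op's own predicate.
// leaderCTA: engaged only for cross-cluster barriers; true on the lead CTA.
ExpectPredicates splitPredicate(bool basePred, std::optional<bool> leaderCTA) {
  bool leaderPred = leaderCTA ? (basePred && *leaderCTA) : basePred;
  // nonLeaderPred = basePred & (leaderPred ^ true), as in the lowering above.
  return {leaderPred, basePred && !leaderPred};
}

int main() {
  assert(splitPredicate(true, true).leader);     // lead CTA: expect_tx only
  assert(!splitPredicate(true, true).nonLeader);
  assert(splitPredicate(true, false).nonLeader); // other CTAs: arrive only
  assert(!splitPredicate(true, false).leader);
}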
@@ -238,19 +258,9 @@ struct WaitBarrierOpConversion
     auto loc = op.getLoc();
     auto b = TritonLLVMOpBuilder(loc, rewriter);
     auto pred = adaptor.getPred();
-
-    auto kBlock = StringAttr::get(ctx, "block");
-    auto maskCGABroadcast =
-        toLinearLayout(barrierTy).getFreeVariableMasks().lookup(kBlock);
-    if (maskCGABroadcast) {
-      // If several CTAs cast to the same barrier, as when we do a TMA into a
-      // tcgen05.mma 2CTA, we send all the signals to the lead CTA, so even if
-      // this barrier is waiting for zero bytes, no one will arrive on it. As
-      // such, we predicate it out
-      auto ctaId = nvgpu::ClusterCTAIdOp::create(rewriter, loc);
-      auto ctaIdInGroup = b.and_(ctaId, b.i32_val(maskCGABroadcast));
-      pred = b.and_(pred, b.icmp_eq(ctaIdInGroup, b.i32_val(0)));
-    }
+    if (auto leaderPred =
+            LLVM::NVIDIA::getLeaderCTAPredicate(loc, rewriter, barrierTy))
+      pred = b.and_(pred, *leaderPred);
 
     bool predicated = pred && !matchPattern(pred, m_NonZero());
     std::string ptx;
@@ -323,29 +333,48 @@ struct ArriveBarrierOpConversion
   LogicalResult
   matchAndRewrite(triton::nvidia_gpu::ArriveBarrierOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto b = TritonLLVMOpBuilder(loc, rewriter);
+    auto barrierTy = op.getAlloc().getType();
+    auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
+        loc, adaptor.getAlloc(),
+        typeConverter->convertType(barrierTy.getElementType()), rewriter);
+
+    // Arrive has block-level semantics, so we must synchronize.
+    // Technically, this should be MemBar's job, but it can include TMEM
+    // accesses, which don't have a MemBar equivalent :/
+    ttg::BarrierOp::create(rewriter, loc, ttg::AddrSpace::Local);
+
+    Value id = getThreadId(rewriter, loc);
+    Value pred = b.icmp_eq(id, b.i32_val(0));
+    if (op.getPred())
+      pred = b.and_(pred, adaptor.getPred());
+
+    bool isCrossCluster =
+        LLVM::NVIDIA::getLeaderCTAPredicate(loc, rewriter, barrierTy)
+            .has_value();
+
+    Value barrierPtr = LLVM::NVIDIA::getLeaderAddress(
+        loc, rewriter, smemObj.getBase(), barrierTy);
     // TODO: Add phase result as needed.
     std::stringstream ptxAsm;
-    ptxAsm << "@$0 mbarrier.arrive.shared::cta.b64 _, [$1]";
+    ptxAsm << "@$0 mbarrier.arrive."
+           << (isCrossCluster ? "shared::cluster" : "shared::cta")
+           << ".b64 _, [$1]";
     if (op.getCount() > 1) {
       ptxAsm << ", " << op.getCount();
     }
     ptxAsm << ";";
 
-    TritonLLVMOpBuilder b(op.getLoc(), rewriter);
-    Value id = getThreadId(rewriter, op.getLoc());
-    Value pred = b.icmp_eq(id, b.i32_val(0));
-    if (op.getPred())
-      pred = b.and_(pred, adaptor.getPred());
-
     PTXBuilder ptxBuilder;
     SmallVector<PTXBuilder::Operand *, 2> operands = {
         ptxBuilder.newOperand(pred, "b"),
-        ptxBuilder.newOperand(adaptor.getAlloc(), "r")};
+        ptxBuilder.newOperand(barrierPtr, "r")};
 
     auto arriveOp = *ptxBuilder.create(ptxAsm.str());
     arriveOp(operands, /*onlyAttachMLIRArgs=*/true);
     auto voidTy = void_ty(getContext());
-    ptxBuilder.launch(rewriter, op.getLoc(), voidTy);
+    ptxBuilder.launch(rewriter, loc, voidTy);
 
     rewriter.eraseOp(op);
     return success();
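The PTX string assembled above is easy to check in isolation. A self-contained sketch of the same stringstream logic, assuming an isCrossCluster flag computed as in the pattern:

#include <cassert>
#include <sstream>
#include <string>

std::string arrivePtx(bool isCrossCluster, unsigned count) {
  std::stringstream ptxAsm;
  ptxAsm << "@$0 mbarrier.arrive."
         << (isCrossCluster ? "shared::cluster" : "shared::cta")
         << ".b64 _, [$1]";
  if (count > 1)
    ptxAsm << ", " << count; // the count operand is omitted when it is 1
  ptxAsm << ";";
  return ptxAsm.str();
}

int main() {
  assert(arrivePtx(false, 2) ==
         "@$0 mbarrier.arrive.shared::cta.b64 _, [$1], 2;");
  // Cluster-scope form checked by arrive_barrier_cluster_broadcast:
  assert(arrivePtx(true, 1) ==
         "@$0 mbarrier.arrive.shared::cluster.b64 _, [$1];");
}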

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 2 additions & 8 deletions
@@ -1379,14 +1379,8 @@ struct AsyncTMACopyGlobalToLocalOpConversion
     // out)
     bool clusterBarrier = barrierMask & ~maskCGABroadcast;
     if (clusterBarrier) {
-      // This part is to support TMA into tcgen05.mma 2CTA mostly, i.e.,
-      // barrierMask == 1
-      // Mask with ones on the bits where the CTA broadcasts.
-      // This is a trick from cutlass to implement a faster `mapa`.
-      uint32_t fullMask = ~(barrierMask << 24);
-      Value barrierInt = b.ptrtoint(i32_ty, barrierPtr);
-      barrierInt = b.and_(barrierInt, b.i32_val(fullMask));
-      barrierPtr = b.inttoptr(barrierPtr.getType(), barrierInt);
+      barrierPtr =
+          LLVM::NVIDIA::getLeaderAddress(loc, rewriter, barrierPtr, barrierTy);
     }
 
     // Don't set cta_group::1 as it doesn't exist pre-Blackwell

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp

Lines changed: 33 additions & 0 deletions
@@ -155,6 +155,39 @@ Value createTMAMulticastMask(Location loc, ConversionPatternRewriter &rewriter,
   return b.shl(b.i32_val(pattern), base);
 }
 
+static uint32_t getCGABroadcastMask(mlir::triton::gpu::MemDescType barrierTy) {
+  auto kBlock = StringAttr::get(barrierTy.getContext(), "block");
+  return toLinearLayout(barrierTy).getFreeVariableMasks().lookup(kBlock);
+}
+
+std::optional<Value>
+getLeaderCTAPredicate(Location loc, ConversionPatternRewriter &rewriter,
+                      mlir::triton::gpu::MemDescType barrierTy) {
+  auto b = TritonLLVMOpBuilder(loc, rewriter);
+  uint32_t maskCGABroadcast = getCGABroadcastMask(barrierTy);
+  if (!maskCGABroadcast)
+    return std::nullopt;
+
+  Value ctaId = nvgpu::ClusterCTAIdOp::create(rewriter, loc);
+  Value ctaIdInGroup = b.and_(ctaId, b.i32_val(maskCGABroadcast));
+  return std::optional<Value>(b.icmp_eq(ctaIdInGroup, b.i32_val(0)));
+}
+
+Value getLeaderAddress(Location loc, ConversionPatternRewriter &rewriter,
+                       Value barrierPtr,
+                       mlir::triton::gpu::MemDescType barrierTy) {
+  uint32_t barrierMask = getCGABroadcastMask(barrierTy);
+  if (!barrierMask)
+    return barrierPtr;
+
+  // Trick from cutlass to implement a faster `mapa` via a single and.
+  auto b = TritonLLVMOpBuilder(loc, rewriter);
+  uint32_t fullMask = ~(barrierMask << 24);
+  Value barrierInt = b.ptrtoint(i32_ty, barrierPtr);
+  barrierInt = b.and_(barrierInt, b.i32_val(fullMask));
+  return b.inttoptr(barrierPtr.getType(), barrierInt);
+}
+
 LogicalResult lowerLdStMatrix(
     Location loc, LinearLayout cvt, bool transpose,
     SmallVector<Value> &vals, // Input for stmatrix, output for ldmatrix
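The bit trick in getLeaderAddress can be illustrated on plain integers. A sketch under the assumption (borrowed from cutlass) that a shared::cluster address encodes the owning CTA's rank starting at bit 24, so clearing the broadcast bits of that rank yields the leader CTA's copy of the barrier without a mapa instruction:

#include <cassert>
#include <cstdint>

uint32_t leaderAddress(uint32_t barrierAddr, uint32_t broadcastMask) {
  if (!broadcastMask)
    return barrierAddr; // CTA-local barrier: nothing to remap
  // Zero the broadcast bits of the CTA rank embedded in the address.
  uint32_t fullMask = ~(broadcastMask << 24);
  return barrierAddr & fullMask;
}

int main() {
  // Two CTAs broadcasting (mask 0b1): CTA 1's view of a barrier at offset
  // 0x100 maps back to CTA 0's copy of it.
  uint32_t addrOnCTA1 = (1u << 24) | 0x100;
  assert(leaderAddress(addrOnCTA1, 0b1) == 0x100);
}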

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h

Lines changed: 13 additions & 0 deletions
@@ -2,6 +2,7 @@
 #define TRITON_CONVERSION_TRITONNVIDIAGPU_TO_LLVM_UTILITY_H
 
 #include <cstdint>
+#include <optional>
 
 #include "nvidia/include/TritonNVIDIAGPUToLLVM/PTXAsmFormat.h"
 
@@ -63,6 +64,18 @@
 // group
 Value createTMAMulticastMask(Location loc, ConversionPatternRewriter &rewriter,
                              uint16_t broadcastBits);
+
+// Returns the lead CTA predicate for this barrier layout when lowering through
+// cluster scope. Returns std::nullopt for CTA-local lowering.
+std::optional<Value>
+getLeaderCTAPredicate(Location loc, ConversionPatternRewriter &rewriter,
+                      mlir::triton::gpu::MemDescType barrierTy);
+
+// Returns the lead CTA barrier address for this layout. If there is no
+// cross-cluster lowering, returns barrierPtr unchanged.
+Value getLeaderAddress(Location loc, ConversionPatternRewriter &rewriter,
+                       Value barrierPtr,
+                       mlir::triton::gpu::MemDescType barrierTy);
 } // namespace NVIDIA
 } // namespace LLVM
 