@@ -111,17 +111,28 @@ struct InitBarrierOpConversion
111111 ConversionPatternRewriter &rewriter) const override {
112112 Location loc = op->getLoc ();
113113 auto b = TritonLLVMOpBuilder (loc, rewriter);
114+ auto barrierTy = op.getAlloc ().getType ();
114115 auto smemObj = LLVM::getSharedMemoryObjectFromStruct (
115116 loc, adaptor.getAlloc (),
116- typeConverter->convertType (op.getAlloc ().getType ().getElementType ()),
117- rewriter);
117+ typeConverter->convertType (barrierTy.getElementType ()), rewriter);
118118
119119 // We use an elect predicate to tell ptxas that the operation is uniform,
120120 // which results in better codegen.
121121 Value pred = getElectWarp0OrThread0 (*targetInfo, b);
122+
123+ if (auto leaderPred =
124+ LLVM::NVIDIA::getLeaderCTAPredicate (loc, rewriter, barrierTy))
125+ pred = b.and_ (pred, *leaderPred);
126+
127+ auto numCTAs = triton::gpu::lookupNumCTAs (op);
128+ auto initCount = op.getCount ();
129+ // The lead barrier accounts for all arrives from CTAs that broadcast into
130+ // the same barrier.
131+ initCount *= numCTAs / barrierTy.getNumElements ();
132+
122133 ::mlir::triton::PTXBuilder ptxBuilder;
123134 const std::string ptx = " @$0 mbarrier.init.shared::cta.b64 [$1], " +
124- std::to_string (op. getCount () ) + " ;" ;
135+ std::to_string (initCount ) + " ;" ;
125136 auto &barSyncOp = *ptxBuilder.create (ptx);
126137 barSyncOp ({ptxBuilder.newOperand (pred, " b" ),
127138 ptxBuilder.newOperand (smemObj.getBase (), " r" )},
@@ -188,31 +199,40 @@ struct BarrierExpectConversion
188199 auto expectedBytes = op.getSize () * (numCTAs / barrierTy.getNumElements ());
189200
190201 auto id = getThreadId (rewriter, loc);
191- Value pred = b.icmp_eq (id, b.i32_val (0 ));
192- pred = b.and_ (pred, adaptor.getPred ());
193-
194- auto kBlock = StringAttr::get (op->getContext (), " block" );
195- auto maskCGABroadcast =
196- toLinearLayout (barrierTy).getFreeVariableMasks ().lookup (kBlock );
197- if (maskCGABroadcast) {
198- // If several CTAs cast to the same barrier, as when we do a TMA into a
199- // tcgen05.mma 2CTA, we just register the expect in the lead barrier, as
200- // it is the only one that will receive the mbarrier signals
201- auto ctaId = nvgpu::ClusterCTAIdOp::create (rewriter, loc);
202- auto ctaIdInGroup = b.and_ (ctaId, b.i32_val (maskCGABroadcast));
203- pred = b.and_ (pred, b.icmp_eq (ctaIdInGroup, b.i32_val (0 )));
204- }
205-
206- ::mlir::triton::PTXBuilder ptxBuilder;
207- const std::string ptx =
202+ Value basePred = b.icmp_eq (id, b.i32_val (0 ));
203+ basePred = b.and_ (basePred, adaptor.getPred ());
204+ auto leaderCTAPred =
205+ LLVM::NVIDIA::getLeaderCTAPredicate (loc, rewriter, barrierTy);
206+ bool crossCluster = leaderCTAPred.has_value ();
207+ Value leaderPred =
208+ leaderCTAPred ? b.and_ (basePred, *leaderCTAPred) : basePred;
209+ Value leaderBarrierPtr = LLVM::NVIDIA::getLeaderAddress (
210+ loc, rewriter, smemObj.getBase (), barrierTy);
211+
212+ ::mlir::triton::PTXBuilder expectPtxBuilder;
213+ const std::string expectPtx =
208214 " @$0 mbarrier.arrive.expect_tx.shared::cta.b64 _, [$1], " +
209215 std::to_string (expectedBytes) + " ;" ;
210- auto &barSyncOp = *ptxBuilder .create (ptx );
211- barSyncOp ({ptxBuilder .newOperand (pred , " b" ),
212- ptxBuilder .newOperand (smemObj. getBase () , " r" )},
213- /* onlyAttachMLIRArgs=*/ true );
216+ auto &expectOp = *expectPtxBuilder .create (expectPtx );
217+ expectOp ({expectPtxBuilder .newOperand (leaderPred , " b" ),
218+ expectPtxBuilder .newOperand (leaderBarrierPtr , " r" )},
219+ /* onlyAttachMLIRArgs=*/ true );
214220 auto voidTy = void_ty (op->getContext ());
215- ptxBuilder.launch (rewriter, loc, voidTy);
221+ expectPtxBuilder.launch (rewriter, loc, voidTy);
222+
223+ if (crossCluster) {
224+ // Non-leader CTAs still contribute one arrival to the lead CTA barrier.
225+ auto nonLeaderPred = b.and_ (basePred, b.xor_ (leaderPred, b.true_val ()));
226+ ::mlir::triton::PTXBuilder arrivePtxBuilder;
227+ const std::string arrivePtx =
228+ " @$0 mbarrier.arrive.shared::cluster.b64 _, [$1], 1;" ;
229+ auto &arriveOp = *arrivePtxBuilder.create (arrivePtx);
230+ arriveOp ({arrivePtxBuilder.newOperand (nonLeaderPred, " b" ),
231+ arrivePtxBuilder.newOperand (leaderBarrierPtr, " r" )},
232+ /* onlyAttachMLIRArgs=*/ true );
233+ arrivePtxBuilder.launch (rewriter, loc, voidTy);
234+ }
235+
216236 rewriter.eraseOp (op);
217237 return success ();
218238 }
@@ -238,19 +258,9 @@ struct WaitBarrierOpConversion
238258 auto loc = op.getLoc ();
239259 auto b = TritonLLVMOpBuilder (loc, rewriter);
240260 auto pred = adaptor.getPred ();
241-
242- auto kBlock = StringAttr::get (ctx, " block" );
243- auto maskCGABroadcast =
244- toLinearLayout (barrierTy).getFreeVariableMasks ().lookup (kBlock );
245- if (maskCGABroadcast) {
246- // If several CTAs cast to the same barrier, as when we do a TMA into a
247- // tcgen05.mma 2CTA, we send all the signals to the lead CTA, so even if
248- // this barrier is waiting for zero bytes, no one will arrive on it. As
249- // such, we predicate it out
250- auto ctaId = nvgpu::ClusterCTAIdOp::create (rewriter, loc);
251- auto ctaIdInGroup = b.and_ (ctaId, b.i32_val (maskCGABroadcast));
252- pred = b.and_ (pred, b.icmp_eq (ctaIdInGroup, b.i32_val (0 )));
253- }
261+ if (auto leaderPred =
262+ LLVM::NVIDIA::getLeaderCTAPredicate (loc, rewriter, barrierTy))
263+ pred = b.and_ (pred, *leaderPred);
254264
255265 bool predicated = pred && !matchPattern (pred, m_NonZero ());
256266 std::string ptx;
@@ -323,29 +333,48 @@ struct ArriveBarrierOpConversion
323333 LogicalResult
324334 matchAndRewrite (triton::nvidia_gpu::ArriveBarrierOp op, OpAdaptor adaptor,
325335 ConversionPatternRewriter &rewriter) const override {
336+ auto loc = op.getLoc ();
337+ auto b = TritonLLVMOpBuilder (loc, rewriter);
338+ auto barrierTy = op.getAlloc ().getType ();
339+ auto smemObj = LLVM::getSharedMemoryObjectFromStruct (
340+ loc, adaptor.getAlloc (),
341+ typeConverter->convertType (barrierTy.getElementType ()), rewriter);
342+
343+ // Arrive has block-level semantics, so we must synchronize.
344+ // Technically, this should be MemBar's job, but it can include TMEM
345+ // accesses, which don't have a MemBar equivalent :/
346+ ttg::BarrierOp::create (rewriter, loc, ttg::AddrSpace::Local);
347+
348+ Value id = getThreadId (rewriter, loc);
349+ Value pred = b.icmp_eq (id, b.i32_val (0 ));
350+ if (op.getPred ())
351+ pred = b.and_ (pred, adaptor.getPred ());
352+
353+ bool isCrossCluster =
354+ LLVM::NVIDIA::getLeaderCTAPredicate (loc, rewriter, barrierTy)
355+ .has_value ();
356+
357+ Value barrierPtr = LLVM::NVIDIA::getLeaderAddress (
358+ loc, rewriter, smemObj.getBase (), barrierTy);
326359 // TODO: Add phase result as needed.
327360 std::stringstream ptxAsm;
328- ptxAsm << " @$0 mbarrier.arrive.shared::cta.b64 _, [$1]" ;
361+ ptxAsm << " @$0 mbarrier.arrive."
362+ << (isCrossCluster ? " shared::cluster" : " shared::cta" )
363+ << " .b64 _, [$1]" ;
329364 if (op.getCount () > 1 ) {
330365 ptxAsm << " , " << op.getCount ();
331366 }
332367 ptxAsm << " ;" ;
333368
334- TritonLLVMOpBuilder b (op.getLoc (), rewriter);
335- Value id = getThreadId (rewriter, op.getLoc ());
336- Value pred = b.icmp_eq (id, b.i32_val (0 ));
337- if (op.getPred ())
338- pred = b.and_ (pred, adaptor.getPred ());
339-
340369 PTXBuilder ptxBuilder;
341370 SmallVector<PTXBuilder::Operand *, 2 > operands = {
342371 ptxBuilder.newOperand (pred, " b" ),
343- ptxBuilder.newOperand (adaptor. getAlloc () , " r" )};
372+ ptxBuilder.newOperand (barrierPtr , " r" )};
344373
345374 auto arriveOp = *ptxBuilder.create (ptxAsm.str ());
346375 arriveOp (operands, /* onlyAttachMLIRArgs=*/ true );
347376 auto voidTy = void_ty (getContext ());
348- ptxBuilder.launch (rewriter, op. getLoc () , voidTy);
377+ ptxBuilder.launch (rewriter, loc , voidTy);
349378
350379 rewriter.eraseOp (op);
351380 return success ();