44#include " triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
55#include " triton/Conversion/TritonGPUToLLVM/Utility.h"
66
7+ #include < optional>
8+
79#include " triton/Analysis/Allocation.h"
810#include " triton/Dialect/Triton/IR/Types.h"
911#include " triton/Dialect/Triton/IR/Utility.h"
@@ -45,23 +47,18 @@ struct ConvertLayoutOpConversion
4547 LinearLayout srcLayout = toLinearLayout (srcTy);
4648 LinearLayout dstLayout = toLinearLayout (dstTy);
4749
48- StringAttr kBlock = str_attr (" block" );
49- StringAttr kWarp = str_attr (" warp" );
50- StringAttr kLane = str_attr (" lane" );
51- StringAttr kRegister = str_attr (" register" );
50+ auto kBlock = str_attr (" block" );
51+ auto kWarp = str_attr (" warp" );
52+ auto kLane = str_attr (" lane" );
53+ auto kRegister = str_attr (" register" );
5254
5355 assert (to_vector (conversion.getInDimNames ()) ==
5456 to_vector (conversion.getOutDimNames ()));
5557 auto dims = conversion.getInDimNames ();
56- if (llvm::is_contained (dims, kBlock )) {
57- // Case 1: Transfer between values in different CTAs.
58- // This requires moving values through distributed shared memory.
59- return rewriter.notifyMatchFailure (
60- op, " NYI: Transfer between different CTAs" );
61- } else if (llvm::is_contained (dims, kWarp )) {
62- // Case 2: Transfer between values in the same CTA, in which case we move
63- // values through shared memory.
64- transferWithinBlockSwizzling (op, adaptor.getSrc (), rewriter);
58+ if (llvm::is_contained (dims, kBlock ) || llvm::is_contained (dims, kWarp )) {
59+ // Transfer between values in the same CTA, or across CTAs. We move values
60+ // through (distributed) shared memory.
61+ transferSwizzlingLocalMem (op, adaptor.getSrc (), rewriter);
6562 return success ();
6663 } else if (llvm::is_contained (dims, kLane )) {
6764 // Case 3. Transfer between values in the same warp, in which case we try
@@ -70,7 +67,7 @@ struct ConvertLayoutOpConversion
7067 if (cvtNeedsWarpShuffle (srcTy, dstTy))
7168 return transferWithinWarp (op, adaptor, rewriter);
7269
73- transferWithinBlockSwizzling (op, adaptor.getSrc (), rewriter);
70+ transferSwizzlingLocalMem (op, adaptor.getSrc (), rewriter);
7471 return success ();
7572 } else if (llvm::is_contained (dims, kRegister )) {
7673 // Case 4. Transfer between values in the same thread, in which case we
@@ -90,7 +87,7 @@ struct ConvertLayoutOpConversion
9087 ConversionPatternRewriter &rewriter) const {
9188 MLIRContext *ctx = op.getContext ();
9289 auto loc = op.getLoc ();
93- StringAttr kRegister = str_attr (" register" );
90+ auto kRegister = str_attr (" register" );
9491 assert (!cvtNeedsSharedMemory (op.getSrc ().getType (), op.getType ()));
9592
9693 auto srcTy = op.getSrc ().getType ();
@@ -107,7 +104,7 @@ struct ConvertLayoutOpConversion
107104 return success ();
108105 }
109106
110- SmallVector<Value> transferWithinBlockSwizzlingImpl (
107+ SmallVector<Value> transferSwizzlingLocalMemImpl (
111108 Location loc, ConversionPatternRewriter &rewriter,
112109 const LinearLayout &srcLayout, const LinearLayout &dstLayout,
113110 ArrayRef<Value> inVals, Type llvmElemTy, Value smemBase) const {
@@ -123,8 +120,8 @@ struct ConvertLayoutOpConversion
123120 return b.ptrtoint (llvmElemTyPtr, v).getResult ();
124121 }));
125122 auto outVals =
126- transferWithinBlockSwizzlingImpl (loc, rewriter, srcLayout, dstLayout,
127- newInVals, llvmElemTyPtr, smemBase);
123+ transferSwizzlingLocalMemImpl (loc, rewriter, srcLayout, dstLayout,
124+ newInVals, llvmElemTyPtr, smemBase);
128125 for (auto &v : outVals) {
129126 v = b.inttoptr (llvmElemTy, v);
130127 }
@@ -137,7 +134,7 @@ struct ConvertLayoutOpConversion
137134 auto i8ElemTy = i8_ty;
138135 auto newInVals = llvm::to_vector (llvm::map_range (
139136 inVals, [&](Value v) { return b.zext (i8ElemTy, v).getResult (); }));
140- auto outVals = transferWithinBlockSwizzlingImpl (
137+ auto outVals = transferSwizzlingLocalMemImpl (
141138 loc, rewriter, srcLayout, dstLayout, newInVals, i8ElemTy, smemBase);
142139 for (auto &v : outVals) {
143140 v = b.trunc (llvmElemTy, v);
@@ -150,15 +147,15 @@ struct ConvertLayoutOpConversion
150147 if (!removeBroadcastSrc.isIdentity ()) {
151148 auto prmtSrc = removeBroadcastSrc.apply (srcLayout);
152149 auto newInVals = removeBroadcastSrc.apply (inVals);
153- return transferWithinBlockSwizzlingImpl (loc, rewriter, prmtSrc, dstLayout,
154- newInVals, llvmElemTy, smemBase);
150+ return transferSwizzlingLocalMemImpl (loc, rewriter, prmtSrc, dstLayout,
151+ newInVals, llvmElemTy, smemBase);
155152 }
156153
157154 // Remove broadcasting in dst
158155 auto removeBroadcastDst = actionRemoveBroadcastedRegs (dstLayout);
159156 if (!removeBroadcastDst.isIdentity ()) {
160157 auto prmtDst = removeBroadcastDst.apply (dstLayout);
161- auto outVals = transferWithinBlockSwizzlingImpl (
158+ auto outVals = transferSwizzlingLocalMemImpl (
162159 loc, rewriter, srcLayout, prmtDst, inVals, llvmElemTy, smemBase);
163160 return broadcastAs (outVals, dstLayout);
164161 }
@@ -170,6 +167,8 @@ struct ConvertLayoutOpConversion
170167
171168 // Extract reps from smem
172169 auto kReg = str_attr (" register" );
170+ auto kWarp = str_attr (" warp" );
171+ auto kBlock = str_attr (" block" );
173172 auto kReps = str_attr (" reps" );
174173 auto nReps = smem.getInDimSize (kReps );
175174 auto reps = LinearLayout::identity1D (nReps, kReg , kReps );
@@ -191,8 +190,11 @@ struct ConvertLayoutOpConversion
191190 auto storeCvt = *divideRight (totalStoreCvt, reps);
192191 auto loadCvt = *divideRight (totalLoadCvt, reps);
193192 auto kOffset = str_attr (" offset" );
194- storeCvt = storeCvt.reshapeOuts ({{kOffset , storeCvt.getTotalOutDimSize ()}});
195- loadCvt = loadCvt.reshapeOuts ({{kOffset , loadCvt.getTotalOutDimSize ()}});
193+ auto nBlock = storeCvt.getInDimSize (kBlock );
194+ storeCvt = storeCvt.reshapeOuts (
195+ {{kOffset , storeCvt.getTotalOutDimSize () / nBlock}, {kBlock , nBlock}});
196+ loadCvt = loadCvt.reshapeOuts (
197+ {{kOffset , loadCvt.getTotalOutDimSize () / nBlock}, {kBlock , nBlock}});
196198
197199 auto tileSize = storeCvt.getInDimSize (kReg );
198200
@@ -201,28 +203,30 @@ struct ConvertLayoutOpConversion
201203 auto affineOffset = b.i32_val (0 );
202204 auto maskSpanAffineOffset = 0 ;
203205
204- bool isWarpSync = mlir::isCvtWarpSync (srcLayout, dstLayout);
205- for (int i = 0 ; i < nReps; ++i) {
206- if (i > 0 ) {
207- if (isWarpSync) {
208- targetInfo.warpSync (loc, rewriter);
209- } else {
210- targetInfo.barrier (loc, rewriter, triton::gpu::AddrSpace::Local);
211- }
206+ bool isWarpSync = mlir::isCvtDimSync (srcLayout, dstLayout, kWarp );
207+ bool isBlockSync = mlir::isCvtDimSync (srcLayout, dstLayout, kBlock );
208+ auto emitBarrier = [&]() {
209+ if (isWarpSync) {
210+ targetInfo.warpSync (loc, rewriter);
211+ } else if (isBlockSync) {
212+ targetInfo.barrier (loc, rewriter, triton::gpu::AddrSpace::Local);
213+ } else {
214+ targetInfo.clusterBarrier (loc, rewriter);
212215 }
216+ };
217+
218+ for (int i = 0 ; i < nReps; ++i) {
219+ if (i > 0 )
220+ emitBarrier ();
213221 auto tileInVals =
214222 ArrayRef<Value>(permutedInVals).slice (i * tileSize, tileSize);
215223 // Store
216224 lowerLdStShared (loc, ctx, storeCvt, tileInVals, llvmElemTy, smemBase,
217225 /* paddingShifts=*/ {}, affineOffset, maskSpanAffineOffset,
218226 rewriter, targetInfo);
219- if (isWarpSync) {
220- targetInfo.warpSync (loc, rewriter);
221- } else {
222- targetInfo.barrier (loc, rewriter, triton::gpu::AddrSpace::Local);
223- }
227+ emitBarrier ();
224228 // Load
225- SmallVector<Value> tileOutVals = lowerLdStShared (
229+ auto tileOutVals = lowerLdStShared (
226230 loc, ctx, loadCvt, {}, llvmElemTy, smemBase, /* paddingShifts=*/ {},
227231 affineOffset, maskSpanAffineOffset, rewriter, targetInfo);
228232 llvm::append_range (outVals, tileOutVals);
@@ -233,30 +237,21 @@ struct ConvertLayoutOpConversion
233237 return outVals;
234238 }
235239
236- void transferWithinBlockSwizzling (ConvertLayoutOp op, Value src,
237- ConversionPatternRewriter &rewriter) const {
240+ void transferSwizzlingLocalMem (ConvertLayoutOp op, Value src,
241+ ConversionPatternRewriter &rewriter) const {
238242 auto loc = op.getLoc ();
239243 auto *ctx = op.getContext ();
240244 auto srcTy = op.getSrc ().getType ();
241245 auto dstTy = op.getType ();
242246
243- // Remove the kBlock dimension from the layout as it's the identity in the
244- // cvt
245247 auto srcLayout = toLinearLayout (srcTy);
246248 auto dstLayout = toLinearLayout (dstTy);
247- auto kReg = str_attr (" register" );
248- auto kLane = str_attr (" lane" );
249- auto kWarp = str_attr (" warp" );
250- srcLayout = srcLayout.sublayout ({kReg , kLane , kWarp },
251- to_vector (srcLayout.getOutDimNames ()));
252- dstLayout = dstLayout.sublayout ({kReg , kLane , kWarp },
253- to_vector (dstLayout.getOutDimNames ()));
254249
255250 auto llvmElemTy = getTypeConverter ()->convertType (srcTy.getElementType ());
256251 auto smemBase =
257252 LLVM::getSharedMemoryBase (loc, rewriter, targetInfo, op.getOperation ());
258253 auto inVals = unpackLLElements (loc, src, rewriter);
259- auto outVals = transferWithinBlockSwizzlingImpl (
254+ auto outVals = transferSwizzlingLocalMemImpl (
260255 loc, rewriter, srcLayout, dstLayout, inVals, llvmElemTy, smemBase);
261256
262257 Value result =
@@ -273,8 +268,8 @@ struct ConvertLayoutOpConversion
273268 auto b = TritonLLVMOpBuilder (loc, rewriter);
274269 auto srcTy = op.getSrc ().getType ();
275270 auto dstTy = op.getType ();
276- StringAttr kReg = str_attr (" register" );
277- StringAttr kLane = str_attr (" lane" );
271+ auto kReg = str_attr (" register" );
272+ auto kLane = str_attr (" lane" );
278273 auto elemTy = getTypeConverter ()->convertType (srcTy.getElementType ());
279274 int bitwidth = getIntOrFloatOrPtrBitWidth (elemTy);
280275
@@ -431,8 +426,8 @@ struct ConvertLayoutOpConversion
431426 ArrayRef<TranspositionInfo> mixedTranspositions) const {
432427 auto *ctx = rewriter.getContext ();
433428 auto b = TritonLLVMOpBuilder (loc, rewriter);
434- StringAttr kReg = str_attr (" register" );
435- StringAttr kLane = str_attr (" lane" );
429+ auto kReg = str_attr (" register" );
430+ auto kLane = str_attr (" lane" );
436431
437432 SmallVector<Value> vals (inVals.begin (), inVals.end ());
438433 int m = mixedTranspositions.size ();
0 commit comments