44#include " triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
55#include " triton/Conversion/TritonGPUToLLVM/Utility.h"
66
7+ #include < optional>
8+
79#include " triton/Analysis/Allocation.h"
810#include " triton/Dialect/Triton/IR/Types.h"
911#include " triton/Dialect/Triton/IR/Utility.h"
@@ -45,26 +47,20 @@ struct ConvertLayoutOpConversion
4547 LinearLayout srcLayout = toLinearLayout (srcTy);
4648 LinearLayout dstLayout = toLinearLayout (dstTy);
4749
48- StringAttr kBlock = str_attr (" block" );
49- StringAttr kWarp = str_attr (" warp" );
50- StringAttr kLane = str_attr (" lane" );
51- StringAttr kRegister = str_attr (" register" );
50+ auto kBlock = str_attr (" block" );
51+ auto kWarp = str_attr (" warp" );
52+ auto kLane = str_attr (" lane" );
53+ auto kRegister = str_attr (" register" );
5254
5355 auto dims = conversion.getInDimNames ();
5456 bool alwaysUseWarpShuffle = cvtAlwaysUseWarpShuffle (op);
55- assert (!alwaysUseWarpShuffle || (!llvm::is_contained (dims, kBlock ) &&
56- !llvm::is_contained (dims, kWarp )));
5757 assert (to_vector (conversion.getInDimNames ()) ==
5858 to_vector (conversion.getOutDimNames ()));
59- if (llvm::is_contained (dims, kBlock )) {
60- // Case 1: Transfer between values in different CTAs.
61- // This requires moving values through distributed shared memory.
62- return rewriter.notifyMatchFailure (
63- op, " NYI: Transfer between different CTAs" );
64- } else if (llvm::is_contained (dims, kWarp )) {
65- // Case 2: Transfer between values in the same CTA, in which case we move
66- // values through shared memory.
67- transferWithinBlockSwizzling (op, adaptor.getSrc (), rewriter);
59+ if (llvm::is_contained (dims, kBlock ) || llvm::is_contained (dims, kWarp )) {
60+ assert (!alwaysUseWarpShuffle);
61+ // Transfer between values in the same CTA, or across CTAs. We move values
62+ // through (distributed) shared memory.
63+ transferSwizzlingLocalMem (op, adaptor.getSrc (), rewriter);
6864 return success ();
6965 } else if (llvm::is_contained (dims, kLane )) {
7066 // Case 3. Transfer between values in the same warp, in which case we try
@@ -73,7 +69,7 @@ struct ConvertLayoutOpConversion
7369 if (cvtNeedsWarpShuffle (srcTy, dstTy) || alwaysUseWarpShuffle)
7470 return transferWithinWarp (op, adaptor, rewriter);
7571
76- transferWithinBlockSwizzling (op, adaptor.getSrc (), rewriter);
72+ transferSwizzlingLocalMem (op, adaptor.getSrc (), rewriter);
7773 return success ();
7874 } else if (llvm::is_contained (dims, kRegister )) {
7975 // Case 4. Transfer between values in the same thread, in which case we
@@ -93,7 +89,7 @@ struct ConvertLayoutOpConversion
9389 ConversionPatternRewriter &rewriter) const {
9490 MLIRContext *ctx = op.getContext ();
9591 auto loc = op.getLoc ();
96- StringAttr kRegister = str_attr (" register" );
92+ auto kRegister = str_attr (" register" );
9793 assert (!cvtNeedsSharedMemory (op.getSrc ().getType (), op.getType ()));
9894
9995 auto srcTy = op.getSrc ().getType ();
@@ -110,7 +106,7 @@ struct ConvertLayoutOpConversion
110106 return success ();
111107 }
112108
113- SmallVector<Value> transferWithinBlockSwizzlingImpl (
109+ SmallVector<Value> transferSwizzlingLocalMemImpl (
114110 Location loc, ConversionPatternRewriter &rewriter,
115111 const LinearLayout &srcLayout, const LinearLayout &dstLayout,
116112 ArrayRef<Value> inVals, Type llvmElemTy, Value smemBase) const {
@@ -126,8 +122,8 @@ struct ConvertLayoutOpConversion
126122 return b.ptrtoint (llvmElemTyPtr, v).getResult ();
127123 }));
128124 auto outVals =
129- transferWithinBlockSwizzlingImpl (loc, rewriter, srcLayout, dstLayout,
130- newInVals, llvmElemTyPtr, smemBase);
125+ transferSwizzlingLocalMemImpl (loc, rewriter, srcLayout, dstLayout,
126+ newInVals, llvmElemTyPtr, smemBase);
131127 for (auto &v : outVals) {
132128 v = b.inttoptr (llvmElemTy, v);
133129 }
@@ -140,7 +136,7 @@ struct ConvertLayoutOpConversion
140136 auto i8ElemTy = i8_ty;
141137 auto newInVals = llvm::to_vector (llvm::map_range (
142138 inVals, [&](Value v) { return b.zext (i8ElemTy, v).getResult (); }));
143- auto outVals = transferWithinBlockSwizzlingImpl (
139+ auto outVals = transferSwizzlingLocalMemImpl (
144140 loc, rewriter, srcLayout, dstLayout, newInVals, i8ElemTy, smemBase);
145141 for (auto &v : outVals) {
146142 v = b.trunc (llvmElemTy, v);
@@ -153,15 +149,15 @@ struct ConvertLayoutOpConversion
153149 if (!removeBroadcastSrc.isIdentity ()) {
154150 auto prmtSrc = removeBroadcastSrc.apply (srcLayout);
155151 auto newInVals = removeBroadcastSrc.apply (inVals);
156- return transferWithinBlockSwizzlingImpl (loc, rewriter, prmtSrc, dstLayout,
157- newInVals, llvmElemTy, smemBase);
152+ return transferSwizzlingLocalMemImpl (loc, rewriter, prmtSrc, dstLayout,
153+ newInVals, llvmElemTy, smemBase);
158154 }
159155
160156 // Remove broadcasting in dst
161157 auto removeBroadcastDst = actionRemoveBroadcastedRegs (dstLayout);
162158 if (!removeBroadcastDst.isIdentity ()) {
163159 auto prmtDst = removeBroadcastDst.apply (dstLayout);
164- auto outVals = transferWithinBlockSwizzlingImpl (
160+ auto outVals = transferSwizzlingLocalMemImpl (
165161 loc, rewriter, srcLayout, prmtDst, inVals, llvmElemTy, smemBase);
166162 return broadcastAs (outVals, dstLayout);
167163 }
@@ -173,6 +169,8 @@ struct ConvertLayoutOpConversion
173169
174170 // Extract reps from smem
175171 auto kReg = str_attr (" register" );
172+ auto kWarp = str_attr (" warp" );
173+ auto kBlock = str_attr (" block" );
176174 auto kReps = str_attr (" reps" );
177175 auto nReps = smem.getInDimSize (kReps );
178176 auto reps = LinearLayout::identity1D (nReps, kReg , kReps );
@@ -194,8 +192,11 @@ struct ConvertLayoutOpConversion
194192 auto storeCvt = *divideRight (totalStoreCvt, reps);
195193 auto loadCvt = *divideRight (totalLoadCvt, reps);
196194 auto kOffset = str_attr (" offset" );
197- storeCvt = storeCvt.reshapeOuts ({{kOffset , storeCvt.getTotalOutDimSize ()}});
198- loadCvt = loadCvt.reshapeOuts ({{kOffset , loadCvt.getTotalOutDimSize ()}});
195+ auto nBlock = storeCvt.getInDimSize (kBlock );
196+ storeCvt = storeCvt.reshapeOuts (
197+ {{kOffset , storeCvt.getTotalOutDimSize () / nBlock}, {kBlock , nBlock}});
198+ loadCvt = loadCvt.reshapeOuts (
199+ {{kOffset , loadCvt.getTotalOutDimSize () / nBlock}, {kBlock , nBlock}});
199200
200201 auto tileSize = storeCvt.getInDimSize (kReg );
201202
@@ -204,28 +205,30 @@ struct ConvertLayoutOpConversion
204205 auto affineOffset = b.i32_val (0 );
205206 auto maskSpanAffineOffset = 0 ;
206207
207- bool isWarpSync = mlir::isCvtWarpSync (srcLayout, dstLayout);
208- for (int i = 0 ; i < nReps; ++i) {
209- if (i > 0 ) {
210- if (isWarpSync) {
211- targetInfo.warpSync (loc, rewriter);
212- } else {
213- targetInfo.barrier (loc, rewriter, triton::gpu::AddrSpace::Local);
214- }
208+ bool isWarpSync = mlir::isCvtDimSync (srcLayout, dstLayout, kWarp );
209+ bool isBlockSync = mlir::isCvtDimSync (srcLayout, dstLayout, kBlock );
210+ auto emitBarrier = [&]() {
211+ if (isWarpSync) {
212+ targetInfo.warpSync (loc, rewriter);
213+ } else if (isBlockSync) {
214+ targetInfo.barrier (loc, rewriter, triton::gpu::AddrSpace::Local);
215+ } else {
216+ targetInfo.clusterBarrier (loc, rewriter);
215217 }
218+ };
219+
220+ for (int i = 0 ; i < nReps; ++i) {
221+ if (i > 0 )
222+ emitBarrier ();
216223 auto tileInVals =
217224 ArrayRef<Value>(permutedInVals).slice (i * tileSize, tileSize);
218225 // Store
219226 lowerLdStShared (loc, ctx, storeCvt, tileInVals, llvmElemTy, smemBase,
220227 /* paddingShifts=*/ {}, affineOffset, maskSpanAffineOffset,
221228 rewriter, targetInfo);
222- if (isWarpSync) {
223- targetInfo.warpSync (loc, rewriter);
224- } else {
225- targetInfo.barrier (loc, rewriter, triton::gpu::AddrSpace::Local);
226- }
229+ emitBarrier ();
227230 // Load
228- SmallVector<Value> tileOutVals = lowerLdStShared (
231+ auto tileOutVals = lowerLdStShared (
229232 loc, ctx, loadCvt, {}, llvmElemTy, smemBase, /* paddingShifts=*/ {},
230233 affineOffset, maskSpanAffineOffset, rewriter, targetInfo);
231234 llvm::append_range (outVals, tileOutVals);
@@ -236,30 +239,21 @@ struct ConvertLayoutOpConversion
236239 return outVals;
237240 }
238241
239- void transferWithinBlockSwizzling (ConvertLayoutOp op, Value src,
240- ConversionPatternRewriter &rewriter) const {
242+ void transferSwizzlingLocalMem (ConvertLayoutOp op, Value src,
243+ ConversionPatternRewriter &rewriter) const {
241244 auto loc = op.getLoc ();
242245 auto *ctx = op.getContext ();
243246 auto srcTy = op.getSrc ().getType ();
244247 auto dstTy = op.getType ();
245248
246- // Remove the kBlock dimension from the layout as it's the identity in the
247- // cvt
248249 auto srcLayout = toLinearLayout (srcTy);
249250 auto dstLayout = toLinearLayout (dstTy);
250- auto kReg = str_attr (" register" );
251- auto kLane = str_attr (" lane" );
252- auto kWarp = str_attr (" warp" );
253- srcLayout = srcLayout.sublayout ({kReg , kLane , kWarp },
254- to_vector (srcLayout.getOutDimNames ()));
255- dstLayout = dstLayout.sublayout ({kReg , kLane , kWarp },
256- to_vector (dstLayout.getOutDimNames ()));
257251
258252 auto llvmElemTy = getTypeConverter ()->convertType (srcTy.getElementType ());
259253 auto smemBase =
260254 LLVM::getSharedMemoryBase (loc, rewriter, targetInfo, op.getOperation ());
261255 auto inVals = unpackLLElements (loc, src, rewriter);
262- auto outVals = transferWithinBlockSwizzlingImpl (
256+ auto outVals = transferSwizzlingLocalMemImpl (
263257 loc, rewriter, srcLayout, dstLayout, inVals, llvmElemTy, smemBase);
264258
265259 Value result =
@@ -276,8 +270,8 @@ struct ConvertLayoutOpConversion
276270 auto b = TritonLLVMOpBuilder (loc, rewriter);
277271 auto srcTy = op.getSrc ().getType ();
278272 auto dstTy = op.getType ();
279- StringAttr kReg = str_attr (" register" );
280- StringAttr kLane = str_attr (" lane" );
273+ auto kReg = str_attr (" register" );
274+ auto kLane = str_attr (" lane" );
281275 auto elemTy = getTypeConverter ()->convertType (srcTy.getElementType ());
282276 int bitwidth = getIntOrFloatOrPtrBitWidth (elemTy);
283277
@@ -434,8 +428,8 @@ struct ConvertLayoutOpConversion
434428 ArrayRef<TranspositionInfo> mixedTranspositions) const {
435429 auto *ctx = rewriter.getContext ();
436430 auto b = TritonLLVMOpBuilder (loc, rewriter);
437- StringAttr kReg = str_attr (" register" );
438- StringAttr kLane = str_attr (" lane" );
431+ auto kReg = str_attr (" register" );
432+ auto kLane = str_attr (" lane" );
439433
440434 SmallVector<Value> vals (inVals.begin (), inVals.end ());
441435 int m = mixedTranspositions.size ();
0 commit comments