44#include " triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
55#include " triton/Conversion/TritonGPUToLLVM/Utility.h"
66
7+ #include < optional>
8+
79#include " triton/Analysis/Allocation.h"
810#include " triton/Dialect/Triton/IR/Types.h"
911#include " triton/Dialect/Triton/IR/Utility.h"
@@ -53,15 +55,10 @@ struct ConvertLayoutOpConversion
5355 assert (to_vector (conversion.getInDimNames ()) ==
5456 to_vector (conversion.getOutDimNames ()));
5557 auto dims = conversion.getInDimNames ();
56- if (llvm::is_contained (dims, kBlock )) {
57- // Case 1: Transfer between values in different CTAs.
58- // This requires moving values through distributed shared memory.
59- return rewriter.notifyMatchFailure (
60- op, " NYI: Transfer between different CTAs" );
61- } else if (llvm::is_contained (dims, kWarp )) {
62- // Case 2: Transfer between values in the same CTA, in which case we move
63- // values through shared memory.
64- transferWithinBlockSwizzling (op, adaptor.getSrc (), rewriter);
58+ if (llvm::is_contained (dims, kBlock ) || llvm::is_contained (dims, kWarp )) {
59+ // Transfer between values in the same CTA, or across CTAs. We move values
60+ // through (distributed) shared memory.
61+ transferSwizzlingLocalMem (op, adaptor.getSrc (), rewriter);
6562 return success ();
6663 } else if (llvm::is_contained (dims, kLane )) {
6764 // Case 3. Transfer between values in the same warp, in which case we try
@@ -70,7 +67,7 @@ struct ConvertLayoutOpConversion
7067 if (cvtNeedsWarpShuffle (srcTy, dstTy))
7168 return transferWithinWarp (op, adaptor, rewriter);
7269
73- transferWithinBlockSwizzling (op, adaptor.getSrc (), rewriter);
70+ transferSwizzlingLocalMem (op, adaptor.getSrc (), rewriter);
7471 return success ();
7572 } else if (llvm::is_contained (dims, kRegister )) {
7673 // Case 4. Transfer between values in the same thread, in which case we
@@ -107,7 +104,7 @@ struct ConvertLayoutOpConversion
107104 return success ();
108105 }
109106
110- SmallVector<Value> transferWithinBlockSwizzlingImpl (
107+ SmallVector<Value> transferSwizzlingLocalMemImpl (
111108 Location loc, ConversionPatternRewriter &rewriter,
112109 const LinearLayout &srcLayout, const LinearLayout &dstLayout,
113110 ArrayRef<Value> inVals, Type llvmElemTy, Value smemBase) const {
@@ -123,8 +120,8 @@ struct ConvertLayoutOpConversion
123120 return b.ptrtoint (llvmElemTyPtr, v).getResult ();
124121 }));
125122 auto outVals =
126- transferWithinBlockSwizzlingImpl (loc, rewriter, srcLayout, dstLayout,
127- newInVals, llvmElemTyPtr, smemBase);
123+ transferSwizzlingLocalMemImpl (loc, rewriter, srcLayout, dstLayout,
124+ newInVals, llvmElemTyPtr, smemBase);
128125 for (auto &v : outVals) {
129126 v = b.inttoptr (llvmElemTy, v);
130127 }
@@ -137,7 +134,7 @@ struct ConvertLayoutOpConversion
137134 auto i8ElemTy = i8_ty;
138135 auto newInVals = llvm::to_vector (llvm::map_range (
139136 inVals, [&](Value v) { return b.zext (i8ElemTy, v).getResult (); }));
140- auto outVals = transferWithinBlockSwizzlingImpl (
137+ auto outVals = transferSwizzlingLocalMemImpl (
141138 loc, rewriter, srcLayout, dstLayout, newInVals, i8ElemTy, smemBase);
142139 for (auto &v : outVals) {
143140 v = b.trunc (llvmElemTy, v);
@@ -150,26 +147,29 @@ struct ConvertLayoutOpConversion
150147 if (!removeBroadcastSrc.isIdentity ()) {
151148 auto prmtSrc = removeBroadcastSrc.apply (srcLayout);
152149 auto newInVals = removeBroadcastSrc.apply (inVals);
153- return transferWithinBlockSwizzlingImpl (loc, rewriter, prmtSrc, dstLayout,
154- newInVals, llvmElemTy, smemBase);
150+ return transferSwizzlingLocalMemImpl (loc, rewriter, prmtSrc, dstLayout,
151+ newInVals, llvmElemTy, smemBase);
155152 }
156153
157154 // Remove broadcasting in dst
158155 auto removeBroadcastDst = actionRemoveBroadcastedRegs (dstLayout);
159156 if (!removeBroadcastDst.isIdentity ()) {
160157 auto prmtDst = removeBroadcastDst.apply (dstLayout);
161- auto outVals = transferWithinBlockSwizzlingImpl (
158+ auto outVals = transferSwizzlingLocalMemImpl (
162159 loc, rewriter, srcLayout, prmtDst, inVals, llvmElemTy, smemBase);
163160 return broadcastAs (outVals, dstLayout);
164161 }
165162
166163 // At this point we have a type that's at least 8-bit
167164 // and we don't have broadcasting in the registers
165+
168166 auto bitwidth = llvmElemTy.getIntOrFloatBitWidth ();
169167 auto smem = optimalSwizzlingLdSt (srcLayout, dstLayout, bitwidth);
170168
171169 // Extract reps from smem
172170 auto kReg = str_attr (" register" );
171+ auto kWarp = StringAttr::get (ctx, " warp" );
172+ auto kBlock = StringAttr::get (ctx, " block" );
173173 auto kReps = str_attr (" reps" );
174174 auto nReps = smem.getInDimSize (kReps );
175175 auto reps = LinearLayout::identity1D (nReps, kReg , kReps );
@@ -191,8 +191,11 @@ struct ConvertLayoutOpConversion
191191 auto storeCvt = *divideRight (totalStoreCvt, reps);
192192 auto loadCvt = *divideRight (totalLoadCvt, reps);
193193 auto kOffset = str_attr (" offset" );
194- storeCvt = storeCvt.reshapeOuts ({{kOffset , storeCvt.getTotalOutDimSize ()}});
195- loadCvt = loadCvt.reshapeOuts ({{kOffset , loadCvt.getTotalOutDimSize ()}});
194+ auto nBlock = storeCvt.getInDimSize (kBlock );
195+ storeCvt = storeCvt.reshapeOuts (
196+ {{kOffset , storeCvt.getTotalOutDimSize () / nBlock}, {kBlock , nBlock}});
197+ loadCvt = loadCvt.reshapeOuts (
198+ {{kOffset , loadCvt.getTotalOutDimSize () / nBlock}, {kBlock , nBlock}});
196199
197200 auto tileSize = storeCvt.getInDimSize (kReg );
198201
@@ -201,28 +204,30 @@ struct ConvertLayoutOpConversion
201204 auto affineOffset = b.i32_val (0 );
202205 auto maskSpanAffineOffset = 0 ;
203206
204- bool isWarpSync = mlir::isCvtWarpSync (srcLayout, dstLayout);
205- for (int i = 0 ; i < nReps; ++i) {
206- if (i > 0 ) {
207- if (isWarpSync) {
208- targetInfo.warpSync (loc, rewriter);
209- } else {
210- targetInfo.barrier (loc, rewriter, triton::gpu::AddrSpace::Local);
211- }
207+ bool isWarpSync = mlir::isCvtDimSync (srcLayout, dstLayout, kWarp );
208+ bool isBlockSync = mlir::isCvtDimSync (srcLayout, dstLayout, kBlock );
209+ auto emitBarrier = [&]() {
210+ if (isWarpSync) {
211+ targetInfo.warpSync (loc, rewriter);
212+ } else if (isBlockSync) {
213+ targetInfo.barrier (loc, rewriter, triton::gpu::AddrSpace::Local);
214+ } else {
215+ targetInfo.clusterBarrier (loc, rewriter);
212216 }
217+ };
218+
219+ for (int i = 0 ; i < nReps; ++i) {
220+ if (i > 0 )
221+ emitBarrier ();
213222 auto tileInVals =
214223 ArrayRef<Value>(permutedInVals).slice (i * tileSize, tileSize);
215224 // Store
216225 lowerLdStShared (loc, ctx, storeCvt, tileInVals, llvmElemTy, smemBase,
217226 /* paddingShifts=*/ {}, affineOffset, maskSpanAffineOffset,
218227 rewriter, targetInfo);
219- if (isWarpSync) {
220- targetInfo.warpSync (loc, rewriter);
221- } else {
222- targetInfo.barrier (loc, rewriter, triton::gpu::AddrSpace::Local);
223- }
228+ emitBarrier ();
224229 // Load
225- SmallVector<Value> tileOutVals = lowerLdStShared (
230+ auto tileOutVals = lowerLdStShared (
226231 loc, ctx, loadCvt, {}, llvmElemTy, smemBase, /* paddingShifts=*/ {},
227232 affineOffset, maskSpanAffineOffset, rewriter, targetInfo);
228233 llvm::append_range (outVals, tileOutVals);
@@ -233,30 +238,21 @@ struct ConvertLayoutOpConversion
233238 return outVals;
234239 }
235240
236- void transferWithinBlockSwizzling (ConvertLayoutOp op, Value src,
237- ConversionPatternRewriter &rewriter) const {
241+ void transferSwizzlingLocalMem (ConvertLayoutOp op, Value src,
242+ ConversionPatternRewriter &rewriter) const {
238243 auto loc = op.getLoc ();
239244 auto *ctx = op.getContext ();
240245 auto srcTy = op.getSrc ().getType ();
241246 auto dstTy = op.getType ();
242247
243- // Remove the kBlock dimension from the layout as it's the identity in the
244- // cvt
245248 auto srcLayout = toLinearLayout (srcTy);
246249 auto dstLayout = toLinearLayout (dstTy);
247- auto kReg = str_attr (" register" );
248- auto kLane = str_attr (" lane" );
249- auto kWarp = str_attr (" warp" );
250- srcLayout = srcLayout.sublayout ({kReg , kLane , kWarp },
251- to_vector (srcLayout.getOutDimNames ()));
252- dstLayout = dstLayout.sublayout ({kReg , kLane , kWarp },
253- to_vector (dstLayout.getOutDimNames ()));
254250
255251 auto llvmElemTy = getTypeConverter ()->convertType (srcTy.getElementType ());
256252 auto smemBase =
257253 LLVM::getSharedMemoryBase (loc, rewriter, targetInfo, op.getOperation ());
258254 auto inVals = unpackLLElements (loc, src, rewriter);
259- auto outVals = transferWithinBlockSwizzlingImpl (
255+ auto outVals = transferSwizzlingLocalMemImpl (
260256 loc, rewriter, srcLayout, dstLayout, inVals, llvmElemTy, smemBase);
261257
262258 Value result =
0 commit comments