@@ -124,26 +124,28 @@ int deduceMinCountOnDefChain(Value defValue, Operation *consumerOp,
 // - Padding intervals must be multiples of 1024 bytes for 16-byte loads.
 // To avoid bank conflicts when reading tensors in MFMA layout, we stagger
 // continuous rows (non contig dimension) by adding padding that shifts their
-// start addresses to different shared memory banks. Generally it's enough to
-// pad 16 continous rows (see exception below for mfma32 kContig). Therefore, we
-// implement a linear mapping from logical tensor elements to shared memory
-// offsets that:
-// - Strides 16 consecutive rows by 1024 bytes in shared memory.
-// - Fills "holes" by rows which are a multiple of 16
-// For example, if each row is 256 bytes, four rows are required to fill the
-// hole. The resulting reordering of rows in logical order is:
-// [r0, r16, r32, r48, r1, row17, row33, row49, row2, row18, ...]
-// Corresponding byte offsets for these rows are:
-// [0, 256, 512, 768, 1024, ...]
-// This approach naturally generalizes to other row sizes. For example, with
-// 128-byte rows:
-// Logical row order: [r0, r16, r32, r48, r64, r80, r96, r112, r1, r17, ...]
-// Byte offsets: [0, 128, 256, 384, ..., 1024, ...]
-// Since padding is applied in groups of 16 rows, the total data size for this
-// layout must be at least 16 KB (16 * 1024 bytes).
+// start addresses to different shared memory banks.
+// Take Mx64xbf16, K-contiguous, kWidth=8, for example (rX stands for row X).
+// Padding here is set to 16 elements (32 bytes) to avoid bank conflicts.
+// We can pack r0,r4,r8,r12,r16,r20,r24,r28 to compose a contiguous tile:
+//   r0[0:8), r0[8:16),
+//   r1[0:8), r1[8:16),
+//   r2[0:8), r2[8:16),
+//   r3[0:8), r3[8:16),
+//   r4[0:8), r4[8:16),
+//   r5[0:8), r5[8:16),
+//   r6[0:8), r6[8:16),
+//   r7[0:8), r7[8:16),
+//   r8[0:8), r8[8:16),
+// When composing the padded layout, we first assemble the rows that are
+// contiguous. In LDS, the rows are arranged as below:
+//   r0,  r4,  r8,  r12, r16, r20, r24, r28
+//   pad, r1,  r5,  r9,  r13, r17, r21, r25
+//   r29, pad, r2,  r6,  r10, r14, r18, r22
+//   r26, r30, pad, r3, ...
 ttg::PaddedSharedEncodingAttr composePaddedLayoutForAsyncCopyCDNA4(
     ttg::DotOperandEncodingAttr dotOpEnc, ttg::TensorOrMemDesc srcTy,
-    ArrayRef<unsigned> sharedOrder, bool useAsyncCopy) {
+    ArrayRef<unsigned> sharedOrder, bool useAsyncCopy, unsigned warpSize) {
   auto *ctx = srcTy.getContext();
 
   // NYI: padded layouts for tt.load/local_write which is more flexible
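To put concrete numbers on the staggering described in the new comment, here is a minimal standalone sketch (not part of the patch). It assumes the usual padded-layout rule, paddedOffset = offset + (offset / interval) * pad, with 64 LDS banks of 4 bytes each on CDNA4:

```cpp
#include <cstdio>

int main() {
  // Mx64xbf16 example from the comment: a 16-element (32-byte) pad is
  // inserted after every 512-element (1024-byte) interval.
  const unsigned interval = 512; // elements per padding interval
  const unsigned pad = 16;       // pad in elements
  const unsigned elemBytes = 2;  // bf16
  for (unsigned g = 0; g < 8; ++g) {
    unsigned linear = g * interval;                // unpadded element offset
    unsigned padded = linear + g * pad;            // offset after padding
    unsigned bank = (padded * elemBytes / 4) % 64; // 64 banks, 4 bytes each
    std::printf("interval %u starts at element %u, bank %u\n", g, padded, bank);
  }
}
```

Each 512-element group starts 8 banks after the previous one (banks 0, 8, 16, ..., 56), so the 8-row groups hit staggered bank ranges instead of all starting at bank 0.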
@@ -165,21 +167,12 @@ ttg::PaddedSharedEncodingAttr composePaddedLayoutForAsyncCopyCDNA4(
 
   unsigned bitWidth = getIntOrFloatOrPtrBitWidth(srcTy.getElementType());
   unsigned elemByteWidth = std::max(bitWidth / 8u, 1u);
-  auto loadBytes = shape[0] * shape[1] * elemByteWidth;
-  if (loadBytes < 16384) {
-    return {};
-  }
 
   // NYI: dtypes != 16bit
   if (elemByteWidth != 2) {
     return {};
   }
 
-  // NYI: requires different stride factor since we stride by 16 rows
-  if (std::min(shape[0], shape[1]) < 16) {
-    return {};
-  }
-
   auto operandIdx = dotOpEnc.getOpIdx();
   auto kWidth = dotOpEnc.getKWidth();
   int kDimIndex = operandIdx == 0 ? 1 : 0;
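For reference, the `kDimIndex` convention above follows the standard `tt.dot` operand shapes: operand A (opIdx 0) is M x K, so K is dimension 1, while operand B (opIdx 1) is K x N, so K is dimension 0. A tiny standalone check with hypothetical shapes:

```cpp
#include <array>
#include <cassert>

int main() {
  auto kDimIndexFor = [](unsigned opIdx) { return opIdx == 0 ? 1 : 0; };
  std::array<int, 2> aShape = {128, 64}; // operand A: M x K
  std::array<int, 2> bShape = {64, 256}; // operand B: K x N
  // Both operands must agree on K, which kDimIndex picks out.
  assert(aShape[kDimIndexFor(0)] == bShape[kDimIndexFor(1)]);
}
```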
@@ -203,87 +196,94 @@ ttg::PaddedSharedEncodingAttr composePaddedLayoutForAsyncCopyCDNA4(
 
   // Determine row (contig) size
   unsigned contigDim = isKContig ? kDim : nonKDim;
+  unsigned nonContigDim = isKContig ? nonKDim : kDim;
+
+  // Padding to avoid bank conflicts.
+  // For ds_read_b128, lanes access LDS in 4 pairs of 16 lanes. We have 64
+  // banks and each lane loads 4 banks. These lane groups are:
+  //   1: 0-3, 12-15, 20-23, 24-27
+  //   2: 4-7, 8-11, 16-19, 28-31
+  // The upper half of the lanes follows the same pattern.
+  // For ds_read_b64, it splits consecutive lanes into 2 groups which access
+  // LDS one after another.
+  unsigned padding = 0;
+  if (isKContig) {
+    padding = mfmaNonKDim == 16 ? (kWidth * 2) : kWidth;
+  } else {
+    padding = mfmaNonKDim == 16 ? 16 : 32;
+  }
+  constexpr unsigned vecSize = 8; // in favor of dwordX4
+  unsigned contigLanes = contigDim / vecSize;
+  unsigned wrap = std::min(contigDim, 128u) / padding;
+  unsigned requiredDim = warpSize / contigLanes * wrap;
+  if (nonContigDim < requiredDim) {
+    return {};
+  }
 
-  // Clamp contigSize to 1024 bytes to have space for at least 16 rows per sub
-  // tile (16KB) and simply repeat the tile to the full tensor size.
-  contigDim = std::min(1024 / elemByteWidth, contigDim);
+  // Use a 16-row wrap if the block is large enough
+  bool useBestWrap = false;
+  unsigned bestWrap = 16;
+  if (nonContigDim >= warpSize / contigLanes * bestWrap && bestWrap > wrap) {
+    useBestWrap = true;
+    wrap = bestWrap;
+  }
 
   // We create linear bases mapping from [contigDim, nonContigDim] -> offset,
-  // representing the row reordering as described above
   std::vector<std::vector<int>> bases;
+
   // Keep contigSize numbers of elements contiguous in shared memory
   for (int elemLog2 = 0; elemLog2 < llvm::Log2_32(contigDim); elemLog2++)
     bases.push_back({1 << elemLog2, 0});
 
-  // Add strided rows (by 16) to pad to 1024bytes
-  auto requiredNumBases = llvm::Log2_32(1024U / elemByteWidth);
-  for (int rowBase = llvm::Log2_32(16); bases.size() < requiredNumBases;
+  // Add strided rows which have the same start offset
+  unsigned paddingInterval = warpSize * vecSize;
+  unsigned requiredNumBases = llvm::Log2_32(paddingInterval);
+  int rowBase = 0;
+  for (rowBase = llvm::Log2_32(wrap); bases.size() < requiredNumBases;
        rowBase++)
     bases.push_back({0, 1 << rowBase});
 
-  // Add rows 1..16 afterwards to complete the tile
-  for (int rowLog2 = 0; rowLog2 < llvm::Log2_32(16); rowLog2++)
+  // Add rows [0, wrap)
+  for (int rowLog2 = 0; rowLog2 < llvm::Log2_32(wrap); rowLog2++)
     bases.push_back({0, 1 << rowLog2});
 
-  // Compute required padding (in bytes) to avoid conflicts when accessing rows
-  unsigned paddingBytes = 0;
-
-  // To compute the required amount of padding to avoid bank conflicts we look
-  // at the number of contiguous bytes loaded for a single row this directly
-  // gives us the padding we require. Note for contigBytesPerLane == 16 we use a
-  // different mfma layout (wide) compared to contigBytesPerLane == 8 (narrow)
-  int contigBytesPerLane = kWidth * elemByteWidth;
-  bool useWideLayout = contigBytesPerLane == 16;
-  if (isKContig) {
-    // For wide layouts we will use ds_read_b128. Lanes access LDS
-    // (bank conflicts) in 4 pairs of 16 lanes since we have 64 banks and each
-    // lane loads 4 banks. These (lane)groups are:
-    // 1: 0-3, 12-15, 20-23, 24-27
-    // 2: 4-7, 8-11, 16-19, 28-31
-    // The upper half of the lanes follow the same pattern.
-    // For narrow layouts we will use ds_read_b64 which splits conseuctive
-    // lanes into 2 groups which access LDS one after another
-
-    if (mfmaNonKDim == 16) {
-      // For wide layouts lane groups read 32 contiguous bytes
-      // For narrow layouts lane groups load 8 contiguous bytes
-      paddingBytes = useWideLayout ? 32 : 8;
-    }
+  // Add remaining rows
+  for (; rowBase < llvm::Log2_32(nonContigDim); rowBase++)
+    bases.push_back({0, 1 << rowBase});
 
-    if (mfmaNonKDim == 32) {
-      // For mfma32 32 lanes read 32 continuous rows. So for narrow layouts we
-      // read 8 contiguous bytes and for wide layouts 16 bytes.
-      paddingBytes = useWideLayout ? 16 : 8;
-
-      // For narrow layouts we need to shift every 16th row to the other half of
-      // shared memory banks to read from all banks. For the wide layout we need
-      // to ensure every 16th rows start at the same bank so lane groups access
-      // different banks. This is done by swapping the bases representing offset
-      // 256 (64banks) for wide layouts or 128 (32banks) for narrow layouts with
-      // the base of the "16th" row which is after log2(contigDim) bases.
-      int offsetBytes = useWideLayout ? 256 : 128;
-      int offsetIndex = llvm::Log2_32(offsetBytes);
-      int row16Index = llvm::Log2_32(contigDim);
-      assert(row16Index < bases.size());
-      assert(offsetIndex < bases.size());
-      std::swap(bases[offsetIndex], bases[row16Index]);
-    }
-  } else {
-    if (mfmaNonKDim == 16) {
-      // For mfma16 lane groups read 32 contiguous bytes
-      paddingBytes = 32;
-      if (useWideLayout) {
-        // For for the wide layout lane groups wrap at row 8 so we have to
-        // exchange row4 and row8 to avoid conflicts (last two bases)
-        std::swap(bases[bases.size() - 1], bases[bases.size() - 2]);
-      }
-    } else if (mfmaNonKDim == 32) {
-      // For mfma32 lane groups read 64 contiguous bytes
-      paddingBytes = 64;
+  // Fixup for nonKContig and mfma16
+  if (!isKContig && mfmaNonKDim == 16) {
+    unsigned row4 = 0;
+    unsigned row8 = 0;
+    for (unsigned i = 0; i < bases.size(); i++) {
+      if (bases[i][1] == 8)
+        row8 = i;
+      if (bases[i][1] == 4)
+        row4 = i;
     }
+    assert(row4 != 0 && row8 != 0);
+    // Lane groups wrap at row8, so we have to exchange
+    // row4 and row8 to avoid bank conflict
+    std::swap(bases[row4], bases[row8]);
   }
 
-  assert(paddingBytes != 0);
+  // Fixup for kContig and mfma32 when reordered rows cannot fit in 64 banks
+  if (isKContig && mfmaNonKDim == 32 && useBestWrap && kDim < 128) {
+    bool useWideLayout = kWidth == 8;
+
+    // For narrow layouts we need to shift every 16th row to the other half of
+    // shared memory banks to read from all banks. For the wide layout we need
+    // to ensure every 16th row starts at the same bank so lane groups access
+    // different banks. This is done by swapping the bases representing offset
+    // 256 (64 banks) for wide layouts or 128 (32 banks) for narrow layouts
+    // with the base of the "16th" row which is after log2(contigDim) bases.
+    int offsetBytes = useWideLayout ? 256 : 128;
+    int offsetIndex = llvm::Log2_32(offsetBytes);
+    int row16Index = llvm::Log2_32(contigDim);
+    assert(row16Index < bases.size());
+    assert(offsetIndex < bases.size());
+    std::swap(bases[offsetIndex], bases[row16Index]);
+  }
 
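To make the resulting row order tangible, the following standalone sketch (separate from the patch, with the running example's values hard-coded: contigDim = 64, nonContigDim = 64, wrap = 4, paddingInterval = 512) rebuilds the bases the same way the loops above do and decodes which logical row each 64-element chunk holds. Neither fixup fires for the kContig/mfma16 case, so no swap is applied:

```cpp
#include <cstdio>
#include <vector>

// Local stand-in for llvm::Log2_32 so the sketch runs without LLVM.
static unsigned log2u(unsigned x) {
  unsigned l = 0;
  while (x >>= 1)
    ++l;
  return l;
}

int main() {
  const unsigned contigDim = 64, nonContigDim = 64, wrap = 4;
  const unsigned paddingInterval = 512; // warpSize(64) * vecSize(8)
  std::vector<std::vector<int>> bases;
  for (unsigned e = 0; e < log2u(contigDim); e++)
    bases.push_back({1 << e, 0}); // 64 contiguous elements per row
  unsigned rowBase = log2u(wrap); // strided rows 4, 8, 16: the stride-4 pack
  for (; bases.size() < log2u(paddingInterval); rowBase++)
    bases.push_back({0, 1 << rowBase});
  for (unsigned r = 0; r < log2u(wrap); r++) // rows 1, 2
    bases.push_back({0, 1 << r});
  for (; rowBase < log2u(nonContigDim); rowBase++) // remaining: row 32
    bases.push_back({0, 1 << rowBase});
  // Walk the first two 512-element intervals and print which logical row
  // starts at each multiple of 64 elements.
  for (unsigned off = 0; off < 1024; off += contigDim) {
    unsigned row = 0;
    for (unsigned b = 0; b < bases.size(); b++)
      if (off & (1u << b))
        row += bases[b][1];
    std::printf("offset %4u -> r%u\n", off, row);
  }
}
```

This prints r0, r4, ..., r28 for the first interval and r1, r5, ..., r29 for the second, matching the LDS arrangement sketched in the leading comment (the pads themselves come from the PaddedSharedEncodingAttr, not from these bases).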
   // Swap bases to match srcTy dimension order
   if ((isKContig && kDimIndex == 1) || (!isKContig && kDimIndex == 0)) {
@@ -300,10 +300,8 @@ ttg::PaddedSharedEncodingAttr composePaddedLayoutForAsyncCopyCDNA4(
   linearComponent = triton::gpu::combineCtaCgaWithShape(
       linearComponent, cgaLayout, srcTy.getShape());
 
-  unsigned paddingInterval = 1024 / elemByteWidth;
-  unsigned paddingInElems = paddingBytes / elemByteWidth;
-  return ttg::PaddedSharedEncodingAttr::get(
-      ctx, {{paddingInterval, paddingInElems}}, std::move(linearComponent));
+  return ttg::PaddedSharedEncodingAttr::get(ctx, {{paddingInterval, padding}},
+                                            std::move(linearComponent));
 }
 
 ttg::PaddedSharedEncodingAttr
@@ -313,8 +311,9 @@ composePaddedLayout(const tt::AMD::TargetInfo &targetInfo,
                     bool useAsyncCopy) {
   if (useAsyncCopy &&
       targetInfo.getISAFamily() == triton::AMD::ISAFamily::CDNA4) {
+    unsigned warpSize = targetInfo.getWarpSize();
     return composePaddedLayoutForAsyncCopyCDNA4(dotOpEnc, srcTy, sharedOrder,
-                                                useAsyncCopy);
+                                                useAsyncCopy, warpSize);
   }
   return {};
 }
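As a final back-of-envelope check on the attribute built above: for the running example the pair {paddingInterval, padding} comes out to {512, 16} elements. A hedged sketch of what that means in bytes and banks (assuming 4-byte banks, 64 banks, bf16 elements):

```cpp
#include <cassert>

int main() {
  const unsigned warpSize = 64, vecSize = 8, elemBytes = 2;
  const unsigned intervalElems = warpSize * vecSize; // 512 elements
  const unsigned padElems = 16;                      // mfma16, kContig, kWidth=8
  assert(intervalElems * elemBytes == 1024);         // 1 KiB of data per interval
  assert(padElems * elemBytes == 32);                // 32 bytes of pad
  assert((padElems * elemBytes / 4) % 64 == 8);      // each interval shifts 8 banks
}
```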