
Commit 8b032cd

[AMD] relax padded layout heuristics to smaller block size (#9074)
leveraging wrap around due to padding, we can still get bank conflict free padded share layout when block size is smaller than 16KB. take Mx64xbf16, k contiguous, kWidth=8, mfma16x16 for example: (rX stands for row X), the minimal block size can be 32x64. padding here is set to 16 elements (32 bytes) to avoid bank conflicts we can pack r0,r4,r8,r12,r16,r20,r24,r28to compose a contiguous tile ``` r0[0+], r0[8+], r1[0+], r1[8+], r2[0+], r2[8+], r3[0+], r3[8+], r4[0+], r4[8+], r5[0+], r5[8+], r6[0+], r6[8+], r7[0+], r7[8+], r8[0+], r8[8+], ``` in LDS, the rows are arranged as below ``` r0, r4, r8, r12, r16, r20, r24, r28 pad, r1, r5, r9, r13, r17, r21, r25 r29, pad, r2, r6, r10, r14, r18, r22, r26, r30, pad, r3, r7, r11, r15, r19, r23, r27, r31 ```
1 parent 3248e08 commit 8b032cd

2 files changed

Lines changed: 124 additions & 97 deletions


test/TritonGPU/loop-pipeline-hip.mlir

Lines changed: 29 additions & 1 deletion
```diff
@@ -986,7 +986,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 
 // ASYNC-NOT: ttg.swizzled_shared
 // ASYNC: [[PADDED_ENC:#.*]] = #ttg.padded_shared
-// ASYNC-SAME{LITERAL}: {offset = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32], [0, 64], [32, 0], [16, 0], [1, 0], [2, 0], [4, 0], [8, 0], [64, 0]], block = []}
+// ASYNC-SAME{LITERAL}: {offset = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32], [0, 64], [16, 0], [32, 0], [1, 0], [2, 0], [4, 0], [8, 0], [64, 0]], block = []}
 // ASYNC-NOT: ttg.padded_shared
 // ASYNC-NOT: ttg.swizzled_shared
 
@@ -1139,4 +1139,32 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
   }
 }
 
+// -----
+
+// small block size 32x64
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [1, 4], instrShape = [16, 16, 32], isTransposed = true}>
+
+// ASYNC-NOT: ttg.swizzled_shared
+// ASYNC{LITERAL}: padded_shared<[512:+16] {offset = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32], [4, 0], [8, 0], [16, 0], [1, 0], [2, 0]]
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
+// ASYNC-LABEL: loop_padding_block_size_small
+  tt.func public @loop_padding_block_size_small(%arg0: i32, %arg1: tensor<32x64x!tt.ptr<f16>, #blocked> {tt.constancy = dense<1> : tensor<2xi32>, tt.contiguity = dense<[1, 8]> : tensor<2xi32>, tt.divisibility = dense<[1, 16]> : tensor<2xi32>}, %arg2: tensor<32x64x!tt.ptr<f16>, #mma>) {
+    %c1_i32 = arith.constant 1 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %cst = arith.constant dense<0.000000e+00> : tensor<32x64xf16, #mma>
+    %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>
+    %0 = scf.for %arg3 = %c0_i32 to %arg0 step %c1_i32 iter_args(%arg4 = %cst) -> (tensor<32x64xf16, #mma>) : i32 {
+      %1 = tt.load %arg1 : tensor<32x64x!tt.ptr<f16>, #blocked>
+      %2 = ttg.convert_layout %1 : tensor<32x64xf16, #blocked> -> tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
+      %3 = tt.dot %2, %cst_0, %arg4 : tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<32x64xf16, #mma>
+      scf.yield %3 : tensor<32x64xf16, #mma>
+    }
+    tt.store %arg2, %0 : tensor<32x64x!tt.ptr<f16>, #mma>
+    tt.return
+  }
+}
+
 // End of negative tests for padding on gfx950
```
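As a cross-check of the CHECK lines above, the sketch below re-derives the `[512:+16]` spec and the offset bases for the 32x64 case by mirroring the new heuristic in `Utility.cpp` (shown further down). This is a standalone re-derivation under our reading of the algorithm; `log2u` and the `main` driver are not Triton APIs. Its output matches the `ASYNC{LITERAL}` line up to formatting.

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

// Stand-in for llvm::Log2_32 (exact powers of two only).
static int log2u(unsigned x) { int n = 0; while (x >>= 1) ++n; return n; }

int main() {
  // 32x64 bf16, k-contiguous, kWidth = 8, mfma16x16, warp size 64.
  const unsigned contigDim = 64, nonContigDim = 32, kWidth = 8, warpSize = 64;
  const unsigned vecSize = 8;                       // favor dwordX4 loads
  const unsigned padding = kWidth * 2;              // 16 elems (mfma16, k-contig)
  const unsigned contigLanes = contigDim / vecSize; // 8 lanes cover one row
  const unsigned wrap = std::min(contigDim, 128u) / padding;   // 4
  const unsigned paddingInterval = warpSize * vecSize;         // 512 elements
  printf("needs >= %u rows\n", warpSize / contigLanes * wrap); // 32: fits

  std::vector<std::vector<int>> bases;
  for (int e = 0; e < log2u(contigDim); ++e)
    bases.push_back({1 << e, 0});                   // contiguous elements
  int rowBase = log2u(wrap);
  for (; bases.size() < (size_t)log2u(paddingInterval); ++rowBase)
    bases.push_back({0, 1 << rowBase});             // rows r4, r8, r16
  for (int r = 0; r < log2u(wrap); ++r)
    bases.push_back({0, 1 << r});                   // rows r1, r2
  for (; rowBase < log2u(nonContigDim); ++rowBase)
    bases.push_back({0, 1 << rowBase});             // none left for 32 rows

  printf("padded_shared<[%u:+%u]> offset =", paddingInterval, padding);
  for (auto &b : bases)
    printf(" [%d, %d]", b[1], b[0]);                // swapped to [row, col]
  printf("\n");
}
```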

third_party/amd/lib/TritonAMDGPUTransforms/Utility.cpp

Lines changed: 95 additions & 96 deletions
```diff
@@ -124,26 +124,28 @@ int deduceMinCountOnDefChain(Value defValue, Operation *consumerOp,
 // - Padding intervals must be multiples of 1024 bytes for 16-byte loads.
 // To avoid bank conflicts when reading tensors in MFMA layout, we stagger
 // continuous rows (non contig dimension) by adding padding that shifts their
-// start addresses to different shared memory banks. Generally it's enough to
-// pad 16 continous rows (see exception below for mfma32 kContig). Therefore, we
-// implement a linear mapping from logical tensor elements to shared memory
-// offsets that:
-// - Strides 16 consecutive rows by 1024 bytes in shared memory.
-// - Fills "holes" by rows which are a multiple of 16
-// For example, if each row is 256 bytes, four rows are required to fill the
-// hole. The resulting reordering of rows in logical order is:
-// [r0, r16, r32, r48, r1, row17, row33, row49, row2, row18, ...]
-// Corresponding byte offsets for these rows are:
-// [0, 256, 512, 768, 1024, ...]
-// This approach naturally generalizes to other row sizes. For example, with
-// 128-byte rows:
-// Logical row order: [r0, r16, r32, r48, r64, r80, r96, r112, r1, r17, ...]
-// Byte offsets: [0, 128, 256, 384, ..., 1024, ...]
-// Since padding is applied in groups of 16 rows, the total data size for this
-// layout must be at least 16 KB (16 * 1024 bytes).
+// start addresses to different shared memory banks.
+// Take Mx64xbf16, k-contiguous, kWidth=8, for example (rX stands for row X).
+// Padding here is set to 16 elements (32 bytes) to avoid bank conflicts.
+// We can pack r0,r4,r8,r12,r16,r20,r24,r28 to compose a contiguous tile:
+//   r0[0:8), r0[8:16),
+//   r1[0:8), r1[8:16),
+//   r2[0:8), r2[8:16),
+//   r3[0:8), r3[8:16),
+//   r4[0:8), r4[8:16),
+//   r5[0:8), r5[8:16),
+//   r6[0:8), r6[8:16),
+//   r7[0:8), r7[8:16),
+//   r8[0:8), r8[8:16),
+// When composing the padded layout, we first assemble the rows that are
+// contiguous. In LDS, the rows are arranged as below:
+//   r0,  r4,  r8,  r12, r16, r20, r24, r28
+//   pad, r1,  r5,  r9,  r13, r17, r21, r25
+//   r29, pad, r2,  r6,  r10, r14, r18, r22
+//   r26, r30, pad, r3 ...
 ttg::PaddedSharedEncodingAttr composePaddedLayoutForAsyncCopyCDNA4(
     ttg::DotOperandEncodingAttr dotOpEnc, ttg::TensorOrMemDesc srcTy,
-    ArrayRef<unsigned> sharedOrder, bool useAsyncCopy) {
+    ArrayRef<unsigned> sharedOrder, bool useAsyncCopy, unsigned warpSize) {
   auto *ctx = srcTy.getContext();
 
   // NYI: padded layouts for tt.load/local_write which is more flexible
```
```diff
@@ -165,21 +167,12 @@ ttg::PaddedSharedEncodingAttr composePaddedLayoutForAsyncCopyCDNA4(
 
   unsigned bitWidth = getIntOrFloatOrPtrBitWidth(srcTy.getElementType());
   unsigned elemByteWidth = std::max(bitWidth / 8u, 1u);
-  auto loadBytes = shape[0] * shape[1] * elemByteWidth;
-  if (loadBytes < 16384) {
-    return {};
-  }
 
   // NYI: dtypes != 16bit
   if (elemByteWidth != 2) {
     return {};
   }
 
-  // NYI: requires different stride factor since we stride by 16 rows
-  if (std::min(shape[0], shape[1]) < 16) {
-    return {};
-  }
-
   auto operandIdx = dotOpEnc.getOpIdx();
   auto kWidth = dotOpEnc.getKWidth();
   int kDimIndex = operandIdx == 0 ? 1 : 0;
```
```diff
@@ -203,87 +196,94 @@ ttg::PaddedSharedEncodingAttr composePaddedLayoutForAsyncCopyCDNA4(
 
   // Determine row(contig) size
   unsigned contigDim = isKContig ? kDim : nonKDim;
+  unsigned nonContigDim = isKContig ? nonKDim : kDim;
+
+  // Padding to avoid bank conflicts.
+  // For ds_read_b128, lanes access LDS in 4 pairs of 16 lanes: we have 64 banks
+  // and each lane loads 4 banks. These lane groups are:
+  // 1: 0-3, 12-15, 20-23, 24-27
+  // 2: 4-7, 8-11, 16-19, 28-31
+  // The upper half of the lanes follows the same pattern.
+  // For ds_read_b64, it splits consecutive lanes into 2 groups which access LDS
+  // one after another.
+  unsigned padding = 0;
+  if (isKContig) {
+    padding = mfmaNonKDim == 16 ? (kWidth * 2) : kWidth;
+  } else {
+    padding = mfmaNonKDim == 16 ? 16 : 32;
+  }
+  constexpr unsigned vecSize = 8; // in favor of dwordX4
+  unsigned contigLanes = contigDim / vecSize;
+  unsigned wrap = std::min(contigDim, 128u) / padding;
+  unsigned requiredDim = warpSize / contigLanes * wrap;
+  if (nonContigDim < requiredDim) {
+    return {};
+  }
 
-  // Clamp contigSize to 1024 bytes to have space for at least 16 rows per sub
-  // tile (16KB) and simply repeat the tile to the full tensor size.
-  contigDim = std::min(1024 / elemByteWidth, contigDim);
+  // Use a 16-row wrap if the block is large enough.
+  bool useBestWrap = false;
+  unsigned bestWrap = 16;
+  if (nonContigDim >= warpSize / contigLanes * bestWrap && bestWrap > wrap) {
+    useBestWrap = true;
+    wrap = bestWrap;
+  }
 
   // We create linear bases mapping from [contigDim, nonContigDim] -> offset,
-  // representing the row reordering as described above
   std::vector<std::vector<int>> bases;
+
   // Keep contigSize numbers of elments contiguous in shared memory
   for (int elemLog2 = 0; elemLog2 < llvm::Log2_32(contigDim); elemLog2++)
     bases.push_back({1 << elemLog2, 0});
 
-  // Add strided rows (by 16) to pad to 1024bytes
-  auto requiredNumBases = llvm::Log2_32(1024U / elemByteWidth);
-  for (int rowBase = llvm::Log2_32(16); bases.size() < requiredNumBases;
+  // Add strided rows which share the same start offset.
+  unsigned paddingInterval = warpSize * vecSize;
+  unsigned requiredNumBases = llvm::Log2_32(paddingInterval);
+  int rowBase = 0;
+  for (rowBase = llvm::Log2_32(wrap); bases.size() < requiredNumBases;
        rowBase++)
     bases.push_back({0, 1 << rowBase});
 
-  // Add rows 1..16 afterwards to complete the tile
-  for (int rowLog2 = 0; rowLog2 < llvm::Log2_32(16); rowLog2++)
+  // Add rows [0, wrap)
+  for (int rowLog2 = 0; rowLog2 < llvm::Log2_32(wrap); rowLog2++)
     bases.push_back({0, 1 << rowLog2});
 
-  // Compute required padding (in bytes) to avoid conflicts when accessing rows
-  unsigned paddingBytes = 0;
-
-  // To compute the required amount of padding to avoid bank conflicts we look
-  // at the number of contiguous bytes loaded for a single row this directly
-  // gives us the padding we require. Note for contigBytesPerLane == 16 we use a
-  // different mfma layout (wide) compared to contigBytesPerLane == 8 (narrow)
-  int contigBytesPerLane = kWidth * elemByteWidth;
-  bool useWideLayout = contigBytesPerLane == 16;
-  if (isKContig) {
-    // For wide layouts we will use ds_read_b128. Lanes access LDS
-    // (bank conflicts) in 4 pairs of 16 lanes since we have 64 banks and each
-    // lane loads 4 banks. These (lane)groups are:
-    // 1: 0-3, 12-15, 20-23, 24-27
-    // 2: 4-7, 8-11, 16-19, 28-31
-    // The upper half of the lanes follow the same pattern.
-    // For narrow layouts we will use ds_read_b64 which splits conseuctive
-    // lanes into 2 groups which access LDS one after another
-
-    if (mfmaNonKDim == 16) {
-      // For wide layouts lane groups read 32 contiguous bytes
-      // For narrow layouts lane groups load 8 contiguous bytes
-      paddingBytes = useWideLayout ? 32 : 8;
-    }
+  // Add remaining rows
+  for (; rowBase < llvm::Log2_32(nonContigDim); rowBase++)
+    bases.push_back({0, 1 << rowBase});
 
-    if (mfmaNonKDim == 32) {
-      // For mfma32 32 lanes read 32 continuous rows. So for narrow layouts we
-      // read 8 contiguous bytes and for wide layouts 16 bytes.
-      paddingBytes = useWideLayout ? 16 : 8;
-
-      // For narrow layouts we need to shift every 16th row to the other half of
-      // shared memory banks to read from all banks. For the wide layout we need
-      // to ensure every 16th rows start at the same bank so lane groups access
-      // different banks. This is done by swapping the bases representing offset
-      // 256 (64banks) for wide layouts or 128 (32banks) for narrow layouts with
-      // the base of the "16th" row which is after log2(contigDim) bases.
-      int offsetBytes = useWideLayout ? 256 : 128;
-      int offsetIndex = llvm::Log2_32(offsetBytes);
-      int row16Index = llvm::Log2_32(contigDim);
-      assert(row16Index < bases.size());
-      assert(offsetIndex < bases.size());
-      std::swap(bases[offsetIndex], bases[row16Index]);
-    }
-  } else {
-    if (mfmaNonKDim == 16) {
-      // For mfma16 lane groups read 32 contiguous bytes
-      paddingBytes = 32;
-      if (useWideLayout) {
-        // For for the wide layout lane groups wrap at row 8 so we have to
-        // exchange row4 and row8 to avoid conflicts (last two bases)
-        std::swap(bases[bases.size() - 1], bases[bases.size() - 2]);
-      }
-    } else if (mfmaNonKDim == 32) {
-      // For mfma32 lane groups read 64 contiguous bytes
-      paddingBytes = 64;
+  // Fixup for nonKContig and mfma16
+  if (!isKContig && mfmaNonKDim == 16) {
+    unsigned row4 = 0;
+    unsigned row8 = 0;
+    for (unsigned i = 0; i < bases.size(); i++) {
+      if (bases[i][1] == 8)
+        row8 = i;
+      if (bases[i][1] == 4)
+        row4 = i;
     }
+    assert(row4 != 0 && row8 != 0);
+    // Lane groups wrap at row8, so we have to exchange
+    // row4 and row8 to avoid bank conflicts.
+    std::swap(bases[row4], bases[row8]);
   }
 
-  assert(paddingBytes != 0);
+  // Fixup for kContig and mfma32 when reordered rows cannot fit in 64 banks.
+  if (isKContig && mfmaNonKDim == 32 && useBestWrap && kDim < 128) {
+    bool useWideLayout = kWidth == 8;
+
+    // For narrow layouts we need to shift every 16th row to the other half of
+    // shared memory banks to read from all banks. For the wide layout we need
+    // to ensure every 16th row starts at the same bank so lane groups access
+    // different banks. This is done by swapping the bases representing offset
+    // 256 (64 banks) for wide layouts or 128 (32 banks) for narrow layouts with
+    // the base of the "16th" row which is after log2(contigDim) bases.
+    int offsetBytes = useWideLayout ? 256 : 128;
+    int offsetIndex = llvm::Log2_32(offsetBytes);
+    int row16Index = llvm::Log2_32(contigDim);
+    assert(row16Index < bases.size());
+    assert(offsetIndex < bases.size());
+    std::swap(bases[offsetIndex], bases[row16Index]);
+  }
 
   // Swap bases to match srcTy dimension order
   if ((isKContig && kDimIndex == 1) || (!isKContig && kDimIndex == 0)) {
```
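To make the relaxed size requirement concrete, here is a minimal sketch of the check implied by the hunk above; `requiredRows` is a hypothetical standalone helper, not a Triton function. For the 32x64 bf16 test, 8 lanes cover a 64-element row and `wrap` is 4, so 32 rows (4 KB of data) suffice where the old heuristic demanded a 16 KB block.

```cpp
#include <algorithm>
#include <cstdio>

// Minimum rows (non-contiguous dim) the padded layout needs, per the
// heuristics above.
unsigned requiredRows(unsigned contigDim, unsigned mfmaNonKDim, unsigned kWidth,
                      bool isKContig, unsigned warpSize = 64) {
  unsigned padding = isKContig ? (mfmaNonKDim == 16 ? kWidth * 2 : kWidth)
                               : (mfmaNonKDim == 16 ? 16u : 32u);
  const unsigned vecSize = 8;                 // dwordX4-sized vectors
  unsigned contigLanes = contigDim / vecSize; // lanes covering one row
  unsigned wrap = std::min(contigDim, 128u) / padding;
  return warpSize / contigLanes * wrap;
}

int main() {
  // mfma16, k-contig, kWidth=8: padding 16 -> wrap 4 -> 32 rows (the new test).
  printf("%u\n", requiredRows(64, 16, 8, true)); // 32
  // mfma32, k-contig, kWidth=8: padding 8 -> wrap 8 -> 64 rows.
  printf("%u\n", requiredRows(64, 32, 8, true)); // 64
}
```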
```diff
@@ -300,10 +300,8 @@ ttg::PaddedSharedEncodingAttr composePaddedLayoutForAsyncCopyCDNA4(
   linearComponent = triton::gpu::combineCtaCgaWithShape(
       linearComponent, cgaLayout, srcTy.getShape());
 
-  unsigned paddingInterval = 1024 / elemByteWidth;
-  unsigned paddingInElems = paddingBytes / elemByteWidth;
-  return ttg::PaddedSharedEncodingAttr::get(
-      ctx, {{paddingInterval, paddingInElems}}, std::move(linearComponent));
+  return ttg::PaddedSharedEncodingAttr::get(ctx, {{paddingInterval, padding}},
+                                            std::move(linearComponent));
 }
 
 ttg::PaddedSharedEncodingAttr
```
```diff
@@ -313,8 +311,9 @@ composePaddedLayout(const tt::AMD::TargetInfo &targetInfo,
                     bool useAsyncCopy) {
   if (useAsyncCopy &&
       targetInfo.getISAFamily() == triton::AMD::ISAFamily::CDNA4) {
+    unsigned warpSize = targetInfo.getWarpSize();
     return composePaddedLayoutForAsyncCopyCDNA4(dotOpEnc, srcTy, sharedOrder,
-                                                useAsyncCopy);
+                                                useAsyncCopy, warpSize);
   }
   return {};
 }
```
