Commit 88f8f7d
[TLX][AMD] Align TDM descriptor encoding with destination memdesc
TLX kernels emit `amdgpu.async_tdm_copy_global_to_local` directly, bypassing `tt.descriptor_load`. `AssignDescriptorMemoryLayouts` doesn't see TDM ops, so the descriptor settled on the default fallback encoding while the alloc's destination memdesc carried whatever `TLXInsertRequireLayout` picked (e.g. WMMA-tuned `composePaddedLayoutWMMA`). The TDM hardware lowering reads stride from the descriptor and writes into the alloc — the mismatch caused out-of-bounds LDS writes (e.g. 128x128x64 matmul on gfx1250).

Add `alignTDMDescriptorEncodings` to AMD `OptimizeDescriptorEncoding`: walk every TDM copy, read its destination memdesc encoding, and rewrite the descriptor's `TensorDescType` to carry the same encoding. Routes the encoding through `updateEncodingForShape` so order/CGA fields stay consistent with the descriptor's block shape. Errors out if two TDM copies share a descriptor with conflicting destination encodings.

With the descriptor side now kept in sync, restore dot-aware encoding selection in `anchorTDMRequireLayout`: when `DotConsumerBackward` finds a `tt.dot` consumer, use `composePaddedLayoutWMMA` against the buffer memdesc (which already has CGA layout — the descriptor block type is still un-encoded at this stage). Falls back to the descriptor-shape default for non-dot consumers.

Lit tests updated to expect the WMMA-tuned encodings on dot paths (`[128:+8]` for opIdx=0, `[128:+16]` for opIdx=1 transposed); added positive and conflict-error tests for `alignTDMDescriptorEncodings`.

Patch from @Hardcode84.

Made-with: Cursor
1 parent ac501ff commit 88f8f7d

5 files changed

Lines changed: 191 additions & 40 deletions

File tree

test/TLX/insert-require-layout-tdm.mlir

Lines changed: 31 additions & 19 deletions
@@ -7,9 +7,13 @@
 // `tlx-propagate-layout` can rewrite the source `local_alloc` to a
 // descriptor-compatible padded encoding. When the buffer is consumed by
 // a `local_load -> tt.dot` chain the WMMA-tuned padded layout is used
-// (`composePaddedLayout`); otherwise the descriptor-shape-only fallback
-// is used (`buildDefaultTDMDescriptorEncoding`). The dot-path walk in
-// the same pass skips TDM-fed buffers so the two anchors don't conflict.
+// (`composePaddedLayoutWMMA`); otherwise the descriptor-shape-only
+// fallback is used (`buildDefaultTDMDescriptorEncoding`). The
+// downstream AMD `OptimizeDescriptorEncoding` pass propagates the
+// chosen encoding back to the descriptor's `TensorDescType` so the
+// hardware lowering and the alloc agree.
+// The dot-path walk in the same pass skips TDM-fed buffers so the two
+// anchors don't conflict.
 
 // =============================================================================
 // 1. Smallest case: TDM copy with no consumer. Default fallback fires.
@@ -62,9 +66,12 @@ module attributes {tlx.has_explicit_local_mem_access = true, "ttg.num-ctas" = 1
 
 // -----
 // =============================================================================
-// 3. TDM copy feeding tt.dot operand A (opIdx=0). Dot-aware encoding fires.
-// composePaddedLayoutWMMA, non-transposed (order=[1,0], opIdx=0):
-// padInterval = block_shape[order[0]] = 128, padAmount = 128/16 = 8.
+// 3. TDM copy feeding tt.dot operand A (opIdx=0). The WMMA-tuned padded
+// encoding from `composePaddedLayoutWMMA` is selected:
+// non-transposed (order[0]=1, 1-opIdx=1), padAmount=128/16=8, padInterval
+// = max(innerDim=32, bankWrapInterval=128) = 128 -> `[128:+8]`.
+// `OptimizeDescriptorEncoding` propagates the same encoding back to the
+// descriptor type so the hardware lowering and the alloc agree.
 // =============================================================================
 
 // CHECK-DAG: #{{.*}} = #ttg.padded_shared<[128:+8] {order = [1, 0], shape = [128, 32]}>
@@ -90,10 +97,12 @@ module attributes {tlx.has_explicit_local_mem_access = true, "ttg.num-ctas" = 1
 
 // -----
 // =============================================================================
-// 4. TDM copy feeding tt.dot operand B (opIdx=1). Dot-aware encoding fires.
-// composePaddedLayoutWMMA, transposed (order=[1,0], opIdx=1):
-// padInterval = block_shape[order[0]] = 128.
-// padAmount = 2 * ldsParams->instBitWidth / typeBits = 2 * 128 / 16 = 16.
+// 4. TDM copy feeding tt.dot operand B (opIdx=1). WMMA-tuned encoding from
+// `composePaddedLayoutWMMA`:
+// transposed (order[0]=1, 1-opIdx=0), padAmount = 2*instBitWidth/elemBits
+// = 2*128/16 = 16 (gfx1250 LDS-trans for fp16 has instBitWidth=128),
+// padInterval = max(innerDim=128, bankWrapInterval=128) = 128
+// -> `[128:+16]`.
 // =============================================================================
 
 // CHECK-DAG: #{{.*}} = #ttg.padded_shared<[128:+16] {order = [1, 0], shape = [32, 128]}>
@@ -119,13 +128,14 @@ module attributes {tlx.has_explicit_local_mem_access = true, "ttg.num-ctas" = 1
 
 // -----
 // =============================================================================
-// 5. Conflicting dot consumers on the same TDM-fed buffer fall back to default.
-// Two local_loads from the same buffer with different opIdx
-// -> findDotConsumer returns nullopt -> default encoding [32:+8].
+// 5. Conflicting dot consumers on the same TDM-fed buffer.
+// `DotConsumerBackward` widens to `Conflict` (opIdx=0 and opIdx=1 disagree),
+// so `findDotConsumer` returns nullopt and the anchor falls back to the
+// descriptor-shape-only default `[32:+8]` instead of either WMMA-tuned variant.
 // =============================================================================
 
 // CHECK-DAG: #{{.*}} = #ttg.padded_shared<[32:+8] {order = [1, 0], shape = [128, 32]}>
-// CHECK-NOT: #{{.*}} = #ttg.padded_shared<[128
+// CHECK-NOT: #ttg.padded_shared<[128
 
 #mma = #ttg.amd_wmma<{version = 3, isTranspose = true, ctaLayout = {warp = [[0, 1], [1, 0]]}, instrShape = [16, 16, 32]}>
 #shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
@@ -204,6 +214,7 @@ module attributes {tlx.has_explicit_local_mem_access = true, "ttg.num-ctas" = 1
 // the isFedByTDM check the dot-path walk would insert a swizzled-shared
 // require_layout that conflicts with the TDM padded encoding. With the
 // check, only the TDM padded anchor is inserted; no swizzled anchor.
+// Encoding is the WMMA-tuned `[128:+8]` (opIdx=0, [128,32] fp16).
 // =============================================================================
 
 // CHECK-DAG: #{{.*}} = #ttg.padded_shared<[128:+8] {order = [1, 0], shape = [128, 32]}>
@@ -389,10 +400,10 @@ module attributes {tlx.has_explicit_local_mem_access = true, "ttg.num-ctas" = 1
 // same alloc). `findDotConsumer` must walk *up* to the alloc and back
 // *down* to find the load — a downstream-only walk from the TDM op's
 // buffer would miss it and silently fall back to the default encoding.
+// With WMMA-tuned encoding propagation, the anchor uses `[128:+8]`.
 // =============================================================================
 
 // CHECK-DAG: #{{.*}} = #ttg.padded_shared<[128:+8] {order = [1, 0], shape = [128, 32]}>
-// CHECK-NOT: #{{.*}} = #ttg.padded_shared<[32:+8]
 
 #mma = #ttg.amd_wmma<{version = 3, isTranspose = true, ctaLayout = {warp = [[0, 1], [1, 0]]}, instrShape = [16, 16, 32]}>
 #shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
@@ -419,11 +430,10 @@ module attributes {tlx.has_explicit_local_mem_access = true, "ttg.num-ctas" = 1
 // proper sparse backward dataflow (DotConsumerBackward) handles the
 // iter-arg via SparseBackwardDataFlowAnalysis's region-branch support;
 // the previous hand-rolled walk would have stopped at the iter-arg
-// boundary and missed the dot consumer.
+// boundary and missed the dot consumer. WMMA-tuned encoding `[128:+8]`.
 // =============================================================================
 
 // CHECK-DAG: #{{.*}} = #ttg.padded_shared<[128:+8] {order = [1, 0], shape = [128, 32]}>
-// CHECK-NOT: #{{.*}} = #ttg.padded_shared<[32:+8]
 
 #mma = #ttg.amd_wmma<{version = 3, isTranspose = true, ctaLayout = {warp = [[0, 1], [1, 0]]}, instrShape = [16, 16, 32]}>
 #shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
@@ -454,8 +464,10 @@ module attributes {tlx.has_explicit_local_mem_access = true, "ttg.num-ctas" = 1
 // -----
 // =============================================================================
 // 17. End-to-end GEMM-shaped pattern: A and B descriptors, two TDM copies,
-// dot consumer. Both TDM ops anchor with the WMMA-tuned padded encoding;
-// no swizzled-shared anchors from the dot-path walk on TDM-fed buffers.
+// dot consumer. Both TDM ops anchor with the WMMA-tuned padded encoding
+// (A: opIdx=0 non-transposed -> `[128:+8]`; B: opIdx=1 transposed ->
+// `[128:+16]`); no swizzled-shared anchors from the dot-path walk on
+// TDM-fed buffers.
 // =============================================================================
 
 // CHECK-DAG: #{{.*}} = #ttg.padded_shared<[128:+8] {order = [1, 0], shape = [128, 32]}>
Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+// RUN: triton-opt -split-input-file --tritonamdgpu-optimize-descriptor-encoding --verify-diagnostics %s
+
+// Test that `alignTDMDescriptorEncodings` rejects two TDM copies on the same
+// descriptor that disagree on the destination memdesc encoding. There's no
+// principled way to pick one encoding over the other, and silently keeping
+// the default would re-introduce the OOB mismatch the pass is meant to
+// prevent.
+
+#shared_a = #ttg.padded_shared<[128:+8] {order = [1, 0], shape = [128, 32]}>
+#shared_b = #ttg.padded_shared<[32:+8] {order = [1, 0], shape = [128, 32]}>
+#smem = #ttg.shared_memory
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @tdm_conflicting_destination_encodings(%desc: !tt.tensordesc<128x32xf16>, %m: i32, %k: i32, %p: i32) {
+    %c0 = arith.constant 0 : i32
+    %alloc_a = ttg.local_alloc : () -> !ttg.memdesc<1x128x32xf16, #shared_a, #smem, mutable>
+    %alloc_b = ttg.local_alloc : () -> !ttg.memdesc<1x128x32xf16, #shared_b, #smem, mutable>
+    %buf_a = ttg.memdesc_index %alloc_a[%c0] : !ttg.memdesc<1x128x32xf16, #shared_a, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared_a, #smem, mutable>
+    %buf_b = ttg.memdesc_index %alloc_b[%c0] : !ttg.memdesc<1x128x32xf16, #shared_b, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared_b, #smem, mutable>
+    %tok_a = amdgpu.async_tdm_copy_global_to_local %desc[%m, %k] into %buf_a, pred = %p : !tt.tensordesc<128x32xf16> -> !ttg.memdesc<128x32xf16, #shared_a, #smem, mutable>
+    // expected-error @+1 {{TDM copies using the same descriptor require conflicting destination layouts}}
+    %tok_b = amdgpu.async_tdm_copy_global_to_local %desc[%m, %k] into %buf_b, pred = %p : !tt.tensordesc<128x32xf16> -> !ttg.memdesc<128x32xf16, #shared_b, #smem, mutable>
+    tt.return
+  }
+}

test/TritonGPU/amd/amd-optimize-descriptor-encoding.mlir

Lines changed: 51 additions & 0 deletions
@@ -193,3 +193,54 @@ tt.func public @descriptor_fallback(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: i32,
   tt.return
 }
 }
+
+// -----
+// =============================================================================
+// alignTDMDescriptorEncodings: TLX-emitted `amdgpu.async_tdm_copy_global_to_local`
+// ops are not seen by `AssignDescriptorMemoryLayouts`, so the descriptor would
+// otherwise keep the default fallback encoding while the destination memdesc
+// carries the TLX-picked (e.g. WMMA-tuned) encoding. The alignment pass copies
+// the destination memdesc encoding back to the descriptor's `TensorDescType`
+// so the hardware lowering and the alloc agree.
+// =============================================================================
+
+#shared = #ttg.padded_shared<[128:+8] {order = [1, 0], shape = [128, 32]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-DAG: #[[$PADDED_TDM:.*]] = #ttg.padded_shared<[128:+8] {order = [1, 0], shape = [128, 32]}>
+  // CHECK-LABEL: @tdm_descriptor_arg_aligns_to_alloc
+  // CHECK-SAME: %[[DESC:.*]]: !tt.tensordesc<128x32xf16, #[[$PADDED_TDM]]>
+  tt.func public @tdm_descriptor_arg_aligns_to_alloc(%desc: !tt.tensordesc<128x32xf16>, %m: i32, %k: i32, %p: i32) {
+    %c0 = arith.constant 0 : i32
+    %alloc = ttg.local_alloc : () -> !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable>
+    %buf = ttg.memdesc_index %alloc[%c0] : !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable>
+    // CHECK: amdgpu.async_tdm_copy_global_to_local %[[DESC]][{{.*}}] into {{.*}} : !tt.tensordesc<128x32xf16, #[[$PADDED_TDM]]>
+    %tok = amdgpu.async_tdm_copy_global_to_local %desc[%m, %k] into %buf, pred = %p : !tt.tensordesc<128x32xf16> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable>
+    tt.return
+  }
+}
+
+// -----
+// =============================================================================
+// alignTDMDescriptorEncodings: descriptor created by `tt.make_tensor_descriptor`
+// (op-result, not function arg) is updated in-place. The local `make_tensor_descriptor`
+// op's result type is rewritten and the downstream TDM op picks up the new desc type.
+// =============================================================================
+
+#shared = #ttg.padded_shared<[128:+16] {order = [1, 0], shape = [32, 128]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-DAG: #[[$PADDED_TDM:.*]] = #ttg.padded_shared<[128:+16] {order = [1, 0], shape = [32, 128]}>
+  // CHECK-LABEL: @tdm_local_descriptor_aligns_to_alloc
+  tt.func public @tdm_local_descriptor_aligns_to_alloc(%ptr: !tt.ptr<f16>, %sz0: i32, %sz1: i32, %s0: i64, %k: i32, %n: i32, %p: i32) {
+    %c0 = arith.constant 0 : i32
+    %c1_i64 = arith.constant 1 : i64
+    // CHECK: tt.make_tensor_descriptor {{.*}} : !tt.ptr<f16>, !tt.tensordesc<32x128xf16, #[[$PADDED_TDM]]>
+    %desc = tt.make_tensor_descriptor %ptr, [%sz0, %sz1], [%s0, %c1_i64] : !tt.ptr<f16>, !tt.tensordesc<32x128xf16>
+    %alloc = ttg.local_alloc : () -> !ttg.memdesc<1x32x128xf16, #shared, #smem, mutable>
+    %buf = ttg.memdesc_index %alloc[%c0] : !ttg.memdesc<1x32x128xf16, #shared, #smem, mutable> -> !ttg.memdesc<32x128xf16, #shared, #smem, mutable>
+    // CHECK: amdgpu.async_tdm_copy_global_to_local {{.*}} : !tt.tensordesc<32x128xf16, #[[$PADDED_TDM]]>
+    %tok = amdgpu.async_tdm_copy_global_to_local %desc[%k, %n] into %buf, pred = %p : !tt.tensordesc<32x128xf16> -> !ttg.memdesc<32x128xf16, #shared, #smem, mutable>
+    tt.return
+  }
+}

third_party/amd/lib/TritonAMDGPUTransforms/OptimizeDescriptorEncoding.cpp

Lines changed: 60 additions & 0 deletions
@@ -123,6 +123,61 @@ static void computeDesiredEncodingAttr(mlir::ModuleOp &m) {
   }
 }
 
+// TLX kernels emit `amdgpu.async_tdm_copy_global_to_local` directly, bypassing
+// `tt.descriptor_load`. The destination memdesc carries the encoding chosen by
+// TLX (e.g. WMMA-tuned `composePaddedLayout` when feeding `tt.dot`). Without
+// any propagation, the descriptor's `TensorDescType` keeps the fallback
+// encoding from `AssignDescriptorMemoryLayouts`, while the alloc gets the
+// TLX-picked encoding. The TDM hardware lowering in `LoadStoreOpToLLVM` reads
+// stride from the descriptor type but writes into the alloc — a stride
+// mismatch causes out-of-bounds LDS writes.
+//
+// This pass copies the destination memdesc encoding back to the descriptor
+// type so the two sides agree by construction. If multiple TDM copies share a
+// descriptor with conflicting destination encodings, we error out (no good
+// way to pick one over the other; TLX kernels currently never hit this).
+static LogicalResult alignTDMDescriptorEncodings(mlir::ModuleOp &m) {
+  llvm::DenseMap<Value, Attribute> descToEncoding;
+  WalkResult result =
+      m.walk(
+          [&](tt::amdgpu::AsyncTDMCopyGlobalToLocalOp copy) {
+            auto memDescTy = cast<ttg::MemDescType>(copy.getResult().getType());
+            Attribute encoding = memDescTy.getEncoding();
+            Value desc = copy.getDesc();
+
+            auto [it, inserted] = descToEncoding.try_emplace(desc, encoding);
+            if (!inserted && it->second != encoding) {
+              copy.emitError()
+                  << "TDM copies using the same descriptor require conflicting "
+                     "destination layouts";
+              return WalkResult::interrupt();
+            }
+            return WalkResult::advance();
+          });
+  if (result.wasInterrupted())
+    return failure();
+
+  for (auto [desc, encoding] : descToEncoding) {
+    auto descTy = cast<tt::TensorDescType>(desc.getType());
+    auto blockTy = descTy.getBlockType();
+    // Adjust order/CGA fields of paddedEncoding/swizzled/nvmma to the
+    // descriptor's block shape so a future rank-reducing TDM doesn't desync.
+    auto sharedEnc = cast<ttg::SharedEncodingTrait>(encoding);
+    Attribute fittedEnc =
+        ttg::updateEncodingForShape(desc.getDefiningOp(), sharedEnc, blockTy);
+    desc.setType(tt::TensorDescType::get(blockTy.getShape(),
+                                         blockTy.getElementType(), fittedEnc));
+  }
+
+  auto ctx = m.getContext();
+  for (auto func : m.getOps<tt::FuncOp>()) {
+    SmallVector<Type> argTypes(func.getBlocks().front().getArgumentTypes());
+    SmallVector<Type> resultTypes(func.getResultTypes());
+    func.setFunctionType(FunctionType::get(ctx, argTypes, resultTypes));
+  }
+  return success();
+}
+
 class AMDGPUAssignDescriptorMemoryLayouts
     : public ttg::AssignDescriptorMemoryLayouts {
 public:
@@ -184,6 +239,11 @@ class TritonAMDGPUOptimizeDescriptorEncodingPass
     AMDGPUAssignDescriptorMemoryLayouts assignMemoryLayouts;
     assignMemoryLayouts.assignMemoryLayouts(m);
 
+    if (failed(alignTDMDescriptorEncodings(m))) {
+      signalPassFailure();
+      return;
+    }
+
     // Remove temporary discardable attributes used during encoding assignment
     for (auto f : m.getOps<tt::FuncOp>()) {
       f.walk([](tt::DescriptorLoadOp load) {

third_party/tlx/dialect/lib/Transforms/InsertRequireLayout.cpp

Lines changed: 24 additions & 21 deletions
@@ -300,10 +300,10 @@ static void applyRequireLayout(ttg::SwizzledSharedEncodingAttr encoding,
     return;
 
   // Defer to the TDM anchor for buffers fed by `amdgpu.async_tdm_*`. The
-  // TDM walk uses the WMMA-tuned padded encoding from
-  // `composePaddedLayout` (which is correct for both the TDM op and the
-  // local_load -> dot path); inserting a sibling swizzled anchor here
-  // would conflict with that constraint and widen the lattice to unknown.
+  // TDM walk picks a padded encoding that's compatible with the descriptor
+  // (and dot-aware when applicable); inserting a sibling swizzled anchor
+  // here would conflict with that constraint and widen the lattice to
+  // unknown.
   if (isFedByTDM(loadMemDesc))
     return;
 
@@ -570,27 +570,30 @@
 
   auto cgaLayout = ttg::CGAEncodingAttr::get1CTALayout(buf.getContext(), rank);
 
-  // First try the dot-operand-aware path: when the buffer is consumed by a
-  // `local_load -> tt.dot` chain, the WMMA-tuned padded encoding from
-  // `composePaddedLayout` is required for the local_load lowering to
-  // satisfy the dot's operand encoding constraints. Otherwise fall back
-  // to the descriptor-shape-only default.
+  // Prefer the WMMA-tuned padded encoding when the buffer feeds a
+  // `tt.dot`: `composePaddedLayout` picks intervals/paddings to avoid bank
+  // conflicts on the `local_load -> tt.dot` lowering. Fall back to the
+  // descriptor-shape-only default for non-dot consumers.
+  //
+  // Using a dot-tuned encoding here is safe because the AMD
+  // `OptimizeDescriptorEncoding` pass walks TDM copies and propagates this
+  // encoding back to the descriptor's `TensorDescType`, so the hardware
+  // (which reads stride from the descriptor) and the alloc (which uses
+  // this encoding to size the LDS region) agree by construction.
   Attribute encoding;
-  if (auto dotInfo = findDotConsumer(buf, solver)) {
-    auto modOp = tdmOp->getParentOfType<ModuleOp>();
-    auto archAttr = mlir::getAMDArch(modOp);
-    if (archAttr) {
-      triton::AMD::TargetInfo targetInfo(archAttr->str());
-      auto srcTy = cast<ttg::TensorOrMemDesc>(bufType);
-      if (auto padded = composePaddedLayout(targetInfo, dotInfo->opIdx,
-                                            dotInfo->kWidth, srcTy, order))
-        encoding = padded;
-    }
+  if (auto info = findDotConsumer(buf, solver)) {
+    auto archStr = getAMDArch(tdmOp->getParentOfType<ModuleOp>());
+    auto targetInfo = tt::AMD::TargetInfo(archStr.value_or("").str());
+    // Use bufType (MemDescType) instead of the descriptor's block type:
+    // bufType carries the alloc's CGA layout, while the descriptor type
+    // is still un-encoded at this point (OptimizeDescriptorEncoding runs
+    // later and is what propagates the encoding back to the descriptor).
+    encoding = composePaddedLayout(targetInfo, info->opIdx, info->kWidth,
+                                   cast<ttg::TensorOrMemDesc>(bufType), order);
   }
-  if (!encoding) {
+  if (!encoding)
     encoding = buildDefaultTDMDescriptorEncoding(buf.getContext(), shape, order,
                                                  cgaLayout, elementType);
-  }
   if (!encoding)
     return;
 