Skip to content

Commit 8f34e7f

Browse files
committed
[AMD][TDM] Add lit tests for warp_bases and rename to "partial TDM copy"
Add verifier negative tests (wrong size, non-contiguous prefix, greedy mismatch) and lowering tests (predication logic, partitioned layout instruction count) for the warp_bases attribute. Rename "warp specialization" to "partial TDM copy" in all TDM warp_bases-related comments and docs to better describe the mechanism.
1 parent a6f4e3a commit 8f34e7f

7 files changed

Lines changed: 104 additions & 12 deletions

File tree

python/triton/experimental/gluon/language/amd/gfx1250/tdm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def make_tensor_descriptor(base: ttgl.tensor, shape: List[ttgl.constexpr | ttgl.
144144

145145

146146
def _validate_warp_bases(warp_bases, block_shape, num_warps):
147-
"""Validate warp_bases for TDM warp specialization.
147+
"""Validate warp_bases for partial TDM copy.
148148
149149
warp_bases must be log2(num_warps) entries where the non-zero entries form
150150
a contiguous prefix matching the greedy distribution for block_shape over
@@ -214,7 +214,7 @@ def async_load(src: tensor_descriptor, offsets: List[ttgl.constexpr | ttgl.tenso
214214
dest (shared_memory_descriptor): the shared memory destination to store the loaded data.
215215
pred (int, optional): Predicate to enable or disable the load. Defaults to 1.
216216
mbarrier (shared_memory_descriptor, optional): The barrier object to signal "arrive" on.
217-
warp_bases (List[List[int]], optional): Per-bit warp-to-offset mapping for TDM warp specialization.
217+
warp_bases (List[List[int]], optional): Per-bit warp-to-offset mapping for partial TDM copy.
218218
Each entry maps one bit of warpId to an element offset in the tensor coordinate space.
219219
A zero basis means that bit contributes no offset (duplicate warp, gets pred=0).
220220
"""

test/Conversion/amd/tritongpu_tdm_to_llvm.mlir

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,3 +234,55 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
234234
tt.return
235235
}
236236
}
237+
238+
// -----
239+
240+
// Partial TDM copy: 4 active warps out of 8, verify predication logic
241+
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
242+
#smem = #ttg.shared_memory
243+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} {
244+
// CHECK-LABEL: tdm_load_warp_bases_predication
245+
tt.func public @tdm_load_warp_bases_predication(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
246+
%c_shape = arith.constant 256 : i32
247+
%c_stride0 = arith.constant 256 : i64
248+
%c_stride1 = arith.constant 1 : i64
249+
%c_offset = arith.constant 0 : i32
250+
%c_pred = arith.constant 1 : i32
251+
%0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%c_stride0, %c_stride1] : <f16>, <256x64xf16, #shared>
252+
%1 = ttg.local_alloc : () -> !ttg.memdesc<256x64xf16, #shared, #smem, mutable>
253+
// warp_bases for 4 active warps: pred = user_pred AND (warpId < 4)
254+
// CHECK-DAG: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
255+
// CHECK: %[[IS_ACTIVE:.*]] = llvm.icmp "ult" %{{.*}}, %[[C4]] : i32
256+
// CHECK: %[[LAYOUT_PRED:.*]] = llvm.select %[[IS_ACTIVE]], %{{.*}}, %{{.*}} : i1, i32
257+
// CHECK: llvm.and %{{.*}}, %[[LAYOUT_PRED]] : i32
258+
// CHECK: "llvm.amdgcn.tensor.load.to.lds"
259+
%2 = amdg.async_tdm_copy_global_to_local %0[%c_offset, %c_offset] into %1, pred = %c_pred {warp_bases = array<i64: 64, 0, 128, 0, 0, 0>} : !tt.tensordesc<256x64xf16, #shared> -> !ttg.memdesc<256x64xf16, #shared, #smem, mutable>
260+
tt.return
261+
}
262+
}
263+
264+
// -----
265+
266+
// Partial TDM copy with partitioned layout: effectiveWarps controls TDM instruction count.
267+
// Without warp_bases (4 warps), all 4 logical pieces fit in 1 instruction.
268+
// With warp_bases for 2 active warps, gcd(2,4)=2 → ceil(4/2)=2 instructions.
269+
#shared_inner = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
270+
#partitioned = #ttg.partitioned_shared<{numPartitions = 2, numGroups = 2, partitionDim = 0, partitionLayout = #shared_inner}>
271+
#smem = #ttg.shared_memory
272+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
273+
// CHECK-LABEL: tdm_load_warp_bases_partitioned
274+
tt.func public @tdm_load_warp_bases_partitioned(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
275+
%c_shape = arith.constant 256 : i32
276+
%c_stride0 = arith.constant 256 : i64
277+
%c_stride1 = arith.constant 1 : i64
278+
%c_offset = arith.constant 0 : i32
279+
%c_pred = arith.constant 1 : i32
280+
%0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%c_stride0, %c_stride1] : <f16>, <128x16xf16, #partitioned>
281+
%1 = ttg.local_alloc : () -> !ttg.memdesc<128x16xf16, #partitioned, #smem, mutable>
282+
// 2 active warps on partitioned layout → 2 TDM instructions
283+
// CHECK: llvm.icmp "ult"
284+
// CHECK-COUNT-2: "llvm.amdgcn.tensor.load.to.lds"
285+
%2 = amdg.async_tdm_copy_global_to_local %0[%c_offset, %c_offset] into %1, pred = %c_pred {warp_bases = array<i64: 64, 0, 0, 0>} : !tt.tensordesc<128x16xf16, #partitioned> -> !ttg.memdesc<128x16xf16, #partitioned, #smem, mutable>
286+
tt.return
287+
}
288+
}

test/TritonGPU/amd/invalid.mlir

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,3 +282,43 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
282282
tt.return
283283
}
284284
}
285+
286+
// -----
287+
288+
// warp_bases validation tests
289+
#shared_wb = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
290+
#smem_wb = #ttg.shared_memory
291+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} {
292+
tt.func @warp_bases_wrong_size(
293+
%tensorDesc: !tt.tensordesc<256x64xf16>,
294+
%memDesc: !ttg.memdesc<256x64xf16, #shared_wb, #smem_wb, mutable>,
295+
%pred: i32
296+
) {
297+
%c0 = arith.constant 0 : i32
298+
// expected-error @+1 {{warp_bases must have log2(num_warps) * ndim = 6 elements, got 4}}
299+
%0 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0, %c0] into %memDesc, pred = %pred {warp_bases = array<i64: 64, 0, 128, 0>} : !tt.tensordesc<256x64xf16> -> !ttg.memdesc<256x64xf16, #shared_wb, #smem_wb, mutable>
300+
tt.return
301+
}
302+
303+
tt.func @warp_bases_non_contiguous_prefix(
304+
%tensorDesc: !tt.tensordesc<256x64xf16>,
305+
%memDesc: !ttg.memdesc<256x64xf16, #shared_wb, #smem_wb, mutable>,
306+
%pred: i32
307+
) {
308+
%c0 = arith.constant 0 : i32
309+
// expected-error @+1 {{warp_bases non-zero entries must form a contiguous prefix; found non-zero basis at bit 1 after a zero basis}}
310+
%0 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0, %c0] into %memDesc, pred = %pred {warp_bases = array<i64: 0, 0, 64, 0, 0, 0>} : !tt.tensordesc<256x64xf16> -> !ttg.memdesc<256x64xf16, #shared_wb, #smem_wb, mutable>
311+
tt.return
312+
}
313+
314+
tt.func @warp_bases_greedy_mismatch(
315+
%tensorDesc: !tt.tensordesc<256x64xf16>,
316+
%memDesc: !ttg.memdesc<256x64xf16, #shared_wb, #smem_wb, mutable>,
317+
%pred: i32
318+
) {
319+
%c0 = arith.constant 0 : i32
320+
// expected-error @+1 {{warp_bases mismatch at bit 0 dim 0: expected 64 but got 0; non-zero bases must match the greedy distribution for block_shape over active_warps=4}}
321+
%0 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0, %c0] into %memDesc, pred = %pred {warp_bases = array<i64: 0, 32, 0, 64, 0, 0>} : !tt.tensordesc<256x64xf16> -> !ttg.memdesc<256x64xf16, #shared_wb, #smem_wb, mutable>
322+
tt.return
323+
}
324+
}

third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -751,7 +751,7 @@ def AsyncTDMCopyGlobalToLocalOp : TT_AMDGPU_Op<"async_tdm_copy_global_to_local",
751751
The operation can also take an optional 64bit LDS barrier address, in which case
752752
it sends an "LDS atomic arrive" to signal its completion.
753753

754-
`warp_bases` is an optional attribute for TDM warp specialization.
754+
`warp_bases` is an optional attribute for partial TDM copy.
755755
Each entry maps one bit of warpId to an element offset in the tensor
756756
coordinate space. A `[0, ..., 0]` basis means that bit of warpId
757757
contributes no offset (degenerate / duplicate warp). Duplicate warps

third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -534,7 +534,7 @@ swapOutDimSemantics(const triton::LinearLayout &layout, StringAttr dimA,
534534
// Fill TDM descriptor for regular load/store operations (1D-5D tensors).
535535
// activeWarps: number of warps that actually issue TDM copies (power of two,
536536
// <= numWarps). Warps with warpId >= activeWarps get pred=0 (hardware no-op).
537-
// A value of 0 means all warps are active (no warp specialization).
537+
// A value of 0 means all warps are active (no partial TDM copy).
538538
void fillTDMDescriptor(
539539
RewriterBase &rewriter, Location loc,
540540
const LLVMTypeConverter *typeConverter, Type elementType,
@@ -584,7 +584,7 @@ void fillTDMDescriptor(
584584
: std::nullopt,
585585
numDims);
586586

587-
// When warp specialization is active, the per-warp block shape differs from
587+
// When partial TDM copy is active, the per-warp block shape differs from
588588
// what createTDMDescriptor encoded (which used numWarps). Re-encode the
589589
// correct per-warp tile dimensions based on warpsPerCTA (from activeWarps).
590590
if (activeWarps > 0) {
@@ -737,7 +737,7 @@ void fillTDMDescriptor(
737737
Value globalAddr = b.ptrtoint(i64_ty, srcPtr);
738738
Value ldsAddr = b.ptrtoint(i32_ty, dstPtr);
739739

740-
// Combine user predicate with layout predicate for warp specialization.
740+
// Combine user predicate with layout predicate for partial TDM copy.
741741
// Duplicate warps (warpId >= activeWarps) get pred=0 (hardware no-op).
742742
if (activeWarps > 0 && activeWarps < numWarps) {
743743
Value isActive = b.icmp_ult(warpId, b.i32_val(activeWarps));
@@ -1105,7 +1105,7 @@ void emitTDMLoadStore(RewriterBase &rewriter, Location loc,
11051105
activeWarps = 1 << activeCount;
11061106
}
11071107

1108-
// When warp specialization is active, compute the warp distribution based
1108+
// When partial TDM copy is active, compute the warp distribution based
11091109
// on activeWarps instead of numWarps.
11101110
int effectiveWarps = (activeWarps > 0) ? activeWarps : numWarps;
11111111

third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ TDMDescriptor createTDMDescriptor(RewriterBase &rewriter, Location loc,
4444
// For partitioned shared memory, dstPtrs contains multiple base pointers and
4545
// the correct one is selected based on sharedLayout's partition dimension.
4646
// activeWarps: number of warps that actually issue TDM copies (power of two,
47-
// <= numWarps). 0 means all warps are active (no warp specialization).
47+
// <= numWarps). 0 means all warps are active (no partial TDM copy).
4848
void fillTDMDescriptor(
4949
RewriterBase &rewriter, Location loc,
5050
const LLVMTypeConverter *typeConverter, Type elementType,

third_party/amd/python/examples/gluon/f16_gemm_warp_pipeline_gfx1250.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def gemm_tdm_pipelined_warp_pipelined_kernel(a_ptr, b_ptr, c_ptr, #
100100

101101

102102
# ---------------------------------------------------------------------------
103-
# TDM warp-specialized variant: only a subset of warps issue TDM copies.
103+
# Partial TDM copy variant: only a subset of warps issue TDM copies.
104104
# Duplicate warps get pred=0 (hardware no-op), freeing TDM bandwidth.
105105
# ---------------------------------------------------------------------------
106106

@@ -194,7 +194,7 @@ def gemm_tdm_specialized_pipelined_warp_pipelined_kernel(a_ptr, b_ptr, c_ptr, #
194194
# ---------------------------------------------------------------------------
195195

196196
def _compute_tdm_warp_bases(block_shape, num_warps, active_warps):
197-
"""Compute warp_bases for TDM specialization with the given active warp count.
197+
"""Compute warp_bases for partial TDM copy with the given active warp count.
198198
199199
Returns a tuple of tuples suitable for passing as a constexpr.
200200
"""
@@ -311,7 +311,7 @@ def test_runtime_gemm_tdm_specialized_pipelined(BLOCK_M, BLOCK_N, BLOCK_K, NUM_B
311311
num_warps = 8
312312
WARP_BASES = [(0, 1), (1, 0), (2, 0)]
313313

314-
# 4-warp TDM specialization: warps 4-7 duplicate 0-3 (pred=0, hardware no-op)
314+
# 4-warp partial TDM copy: warps 4-7 duplicate 0-3 (pred=0, hardware no-op)
315315
tdm_warp_bases = _compute_tdm_warp_bases([BLOCK_M, BLOCK_K], num_warps, 4)
316316

317317
warp_bases = tuple(WARP_BASES)
@@ -348,7 +348,7 @@ def test_runtime_gemm_tdm_specialized_pipelined(BLOCK_M, BLOCK_N, BLOCK_K, NUM_B
348348
parser.add_argument("-K", type=int, default=1024, help='problem K size')
349349
parser.add_argument("--num-buffers", type=int, choices=[2, 3, 4], default=3, help='num shared memory buffers')
350350
parser.add_argument("--4warp-tdm", action="store_true", dest="four_warp_tdm",
351-
help="Use 4-warp TDM specialization (warps 4-7 skip TDM copies)")
351+
help="Use 4-warp partial TDM copy (warps 4-7 skip TDM copies)")
352352
parser.add_argument("--dump", action="store_true", help="Print out result/golden tensors")
353353
args = parser.parse_args()
354354

0 commit comments

Comments (0)