Skip to content

Commit 278e956

Browse files
authored
[AMD][GLUON] host TDM descriptor support for 1D-5D on gfx1250 (triton-lang#8977)
Currently host TDM descriptors only support 2D tiles. This PR adds 1D-5D support to host descriptors, bringing them to parity with on-device descriptor creation. It also moves some code shared between the driver and the compiler into a header, `TDMCommon.h`, for the warp/block distribution calculations. SGPR preload is also disabled on gfx1250.
1 parent 5a8358c commit 278e956

File tree

7 files changed

+193
-60
lines changed

7 files changed

+193
-60
lines changed

python/triton/experimental/gluon/amd/gfx1250.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@ class TensorDescriptor:
1717

1818
def __post_init__(self):
1919
ndim = len(self.shape)
20-
# TODO: support 1D-5D tensor descriptors
21-
assert ndim == 2, f"Expected 2 dimensions but got {ndim} dimensions"
20+
assert 1 <= ndim <= 5, f"Expected 1-5 dimensions but got {ndim} dimensions"
2221
assert len(self.strides) == ndim, f"Expected {ndim} strides but got {len(self.strides)}"
2322
assert len(self.block_shape) == ndim, \
2423
f"Expected block_shape to have {ndim} dimensions but got {len(self.strides)}"

third_party/amd/backend/compiler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,8 @@ def make_llir(src, metadata, options):
394394
# Hint the compiler that we'd like the firmware to set the kernel arguments
395395
# to user SGPRs so that the kernel does not need to s_load its arguments
396396
# from memory.
397-
amd.set_all_fn_arg_inreg(fns[0])
397+
if options.arch != "gfx1250":
398+
amd.set_all_fn_arg_inreg(fns[0])
398399

399400
if knobs.compilation.enable_asan:
400401
default_libdir = Path(__file__).parent / 'lib'

third_party/amd/backend/driver.c

Lines changed: 90 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
#include <stdio.h>
99
#include <stdlib.h>
1010

11+
// Include shared TDM utilities
12+
#include "TDMCommon.h"
13+
1114
typedef struct {
1215
uint32_t group0_0;
1316
uint32_t group0_1;
@@ -21,6 +24,14 @@ typedef struct {
2124
uint32_t group1_5;
2225
uint32_t group1_6;
2326
uint32_t group1_7;
27+
uint32_t group2_0;
28+
uint32_t group2_1;
29+
uint32_t group2_2;
30+
uint32_t group2_3;
31+
uint32_t group3_0;
32+
uint32_t group3_1;
33+
uint32_t group3_2;
34+
uint32_t group3_3;
2435
} TDMDescriptor;
2536

2637
typedef struct {
@@ -54,36 +65,39 @@ static PyTypeObject PyTDMDescriptorType = {
5465
.tp_dealloc = (destructor)PyTDMDescriptor_dealloc,
5566
};
5667

57-
// TODO: Both host-side and device-side TDM descriptor follow the same encoding
58-
// format. Consider to add a common utility to remove duplicate code.
68+
// Encodes a TDM descriptor. Supports 1D-5D tensors.
69+
// Uses the same encoding format as createTDMDescriptor in TDMUtility.cpp.
5970
static bool encodeTDMDescriptor(TDMDescriptor *desc, int elementBitWidth,
6071
uint32_t *blockSize, int numWarps,
6172
int padInterval, int padAmount, uint32_t *shape,
6273
uint32_t *strides, uint64_t globalAddress,
6374
int rank) {
64-
// NYI: TDM > 2D cases
65-
if (rank != 2)
75+
if (rank < 1 || rank > 5)
6676
return false;
6777

68-
// Get warp distribution
69-
uint32_t numWarpsDim0 = numWarps;
70-
for (; numWarpsDim0 > blockSize[0]; numWarpsDim0 /= 2)
71-
;
72-
uint32_t numWarpsDim1 = numWarps / numWarpsDim0;
73-
if (!(numWarpsDim0 > 0 && blockSize[1] % numWarpsDim1 == 0))
74-
return false;
78+
memset(desc, 0, sizeof(TDMDescriptor));
7579

76-
uint32_t blockSize0 = (blockSize[0] + numWarpsDim0 - 1) / numWarpsDim0;
77-
uint32_t blockSize1 = (blockSize[1] + numWarpsDim1 - 1) / numWarpsDim1;
80+
// Convert to int64_t for shared function and get adjusted block sizes
81+
int64_t blockShape64[5], adjustedBlockSize64[5];
82+
for (int i = 0; i < rank; ++i)
83+
blockShape64[i] = blockSize[i];
84+
tdmGetAdjustedBlockShape(blockShape64, rank, numWarps, adjustedBlockSize64);
85+
86+
// Convert back to uint32_t
87+
uint32_t adjustedBlockSize[5];
88+
for (int i = 0; i < rank; ++i)
89+
adjustedBlockSize[i] = (uint32_t)adjustedBlockSize64[i];
7890

7991
// group0 (128 bits / 4 dwords) effective bit encoding:
92+
// [1:0]: pred (to be filled later)
93+
// [63:32]: lds address (to be filled later)
8094
// [120:64]: global address
8195
// [127:126]: type - currently always set to 0x2
8296
desc->group0_2 = (uint32_t)(globalAddress & 0xFFFFFFFF);
83-
desc->group0_3 = (uint32_t)((globalAddress >> 32) & 0x01FFFFFF);
84-
desc->group0_3 |= (0x1 << 31);
97+
desc->group0_3 = (uint32_t)((globalAddress >> 32) & 0x7FFFFFFF) | (0x1 << 31);
8598

8699
// group1 (256 bits / 8 dwords) effective bit encoding:
100+
// [15:0]: multicast mask
87101
// [17:16]: data size - log2(element size in bytes)
88102
// [20]: enable padding
89103
// [24:22]: pad interval - log2(pad interval in dwords) - 1
@@ -92,26 +106,72 @@ static bool encodeTDMDescriptor(TDMDescriptor *desc, int elementBitWidth,
92106
// [111:80]: tensor shape dim outer
93107
// [127:112]: block shape dim inner
94108
// [143:128]: block shape dim outer
109+
// [159:144]: tile_dim2
95110
// [207:160]: tensor stride dim outer (we only use 32 bits)
111+
// [255:208]: tensor stride dim 2 (48 bits)
96112
int elementSizeInBytes = elementBitWidth / 8;
97-
int dataSize = log2(elementSizeInBytes);
98-
desc->group1_0 = (dataSize << 16);
113+
int dataSize = (int)log2(elementSizeInBytes);
99114
int dwordSize = 32;
100115
int padIntervalInDwords = padInterval * elementBitWidth / dwordSize;
101116
int padAmountInDwords = padAmount * elementBitWidth / dwordSize;
117+
118+
desc->group1_0 = (dataSize << 16);
102119
if (padIntervalInDwords > 0 && padAmountInDwords > 0) {
103-
int log2PadInterval = log2(padIntervalInDwords);
120+
int log2PadInterval = (int)log2(padIntervalInDwords);
104121
desc->group1_0 |= (1 << 20);
105122
desc->group1_0 |= ((log2PadInterval - 1) << 22);
106123
desc->group1_0 |= ((padAmountInDwords - 1) << 25);
107124
}
108-
desc->group1_1 = (shape[1] << 16);
109-
desc->group1_2 = (shape[1] >> 16);
110-
desc->group1_2 |= (shape[0] << 16);
111-
desc->group1_3 = (shape[0] >> 16);
112-
desc->group1_3 |= (blockSize1 << 16);
113-
desc->group1_4 = (blockSize0 & 0xFFFF);
114-
desc->group1_5 = strides[0];
125+
126+
// Encode tensor shapes (48-bit encoding, indices from end: rank-1 is inner)
127+
desc->group1_1 = (shape[rank - 1] << 16);
128+
desc->group1_2 = (shape[rank - 1] >> 16);
129+
130+
if (rank >= 2) {
131+
desc->group1_2 |= (shape[rank - 2] << 16);
132+
desc->group1_3 = (shape[rank - 2] >> 16);
133+
}
134+
135+
// Block shapes
136+
desc->group1_3 |= (adjustedBlockSize[rank - 1] << 16);
137+
if (rank >= 2)
138+
desc->group1_4 = (adjustedBlockSize[rank - 2] & 0xFFFF);
139+
if (rank >= 3)
140+
desc->group1_4 |= (adjustedBlockSize[rank - 3] << 16);
141+
142+
// Strides
143+
if (rank >= 2)
144+
desc->group1_5 = strides[rank - 2];
145+
if (rank >= 3) {
146+
desc->group1_6 = (strides[rank - 3] << 16);
147+
desc->group1_7 = (strides[rank - 3] >> 16);
148+
}
149+
150+
// group2 (128 bits / 4 dwords) for 3D-5D tensors:
151+
// [31:0]: tensor_dim2 (3rd dimension from end)
152+
// [63:32]: tensor_dim3 (4th dimension from end)
153+
// [111:64]: tensor_dim2_stride (48 bits, we use 32 bits)
154+
// [127:112]: tile_dim3
155+
if (rank >= 3) {
156+
desc->group2_0 = shape[rank - 3];
157+
if (rank >= 4) {
158+
desc->group2_1 = shape[rank - 4];
159+
desc->group2_2 = strides[rank - 4];
160+
desc->group2_3 = (adjustedBlockSize[rank - 4] << 16);
161+
}
162+
}
163+
164+
// group3 (128 bits / 4 dwords) for 4D-5D tensors:
165+
// [47:0]: tensor_dim3_stride (48 bits, we use 32 bits)
166+
// [79:48]: tensor_dim4 (5th dimension from end)
167+
// [95:80]: tile_dim4
168+
// [127:96]: reserved
169+
if (rank == 5) {
170+
desc->group3_0 = strides[rank - 5];
171+
desc->group3_1 = (shape[rank - 5] << 16);
172+
desc->group3_2 = (shape[rank - 5] >> 16);
173+
desc->group3_2 |= (adjustedBlockSize[rank - 5] << 16);
174+
}
115175

116176
return true;
117177
}
@@ -388,16 +448,16 @@ static PyObject *createTDMDescriptor(PyObject *self, PyObject *args) {
388448
PyObject *shapeFast = NULL;
389449
PyObject *stridesFast = NULL;
390450

391-
uint32_t blockSizeInt[2];
392-
uint32_t shapeInt[2];
393-
uint32_t stridesInt[2];
451+
uint32_t blockSizeInt[5];
452+
uint32_t shapeInt[5];
453+
uint32_t stridesInt[5];
394454

395455
blockSizeFast = PySequence_Fast(blockSize, "blockSize must be a sequence");
396456
if (!blockSizeFast)
397457
goto cleanup;
398458
int rank = PySequence_Fast_GET_SIZE(blockSizeFast);
399-
if (rank != 2) {
400-
PyErr_SetString(PyExc_RuntimeError, "rank must be 2");
459+
if (rank == 0 || rank > 5) {
460+
PyErr_SetString(PyExc_RuntimeError, "rank must be between 1 and 5");
401461
goto cleanup;
402462
}
403463

third_party/amd/backend/driver.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,14 @@ def format_of(ty):
368368
uint32_t group1_5;
369369
uint32_t group1_6;
370370
uint32_t group1_7;
371+
uint32_t group2_0;
372+
uint32_t group2_1;
373+
uint32_t group2_2;
374+
uint32_t group2_3;
375+
uint32_t group3_0;
376+
uint32_t group3_1;
377+
uint32_t group3_2;
378+
uint32_t group3_3;
371379
}} TDMDescriptor;
372380
373381
typedef struct {{
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#ifndef TRITON_THIRD_PARTY_AMD_BACKEND_INCLUDE_TDMCOMMON_H
2+
#define TRITON_THIRD_PARTY_AMD_BACKEND_INCLUDE_TDMCOMMON_H
3+
4+
//===----------------------------------------------------------------------===//
5+
// C-compatible TDM utilities shared between host-side (driver.c) and
6+
// device-side (TDMUtility.cpp) code.
7+
//
8+
// This is intentionally kept header-only to avoid introducing
9+
// dependencies between the compiler and runtime components.
10+
//===----------------------------------------------------------------------===//
11+
12+
#include <stdint.h>
13+
14+
// Compute warp distribution across dimensions.
15+
// Distributes warps starting from the first dimension, assigning as many
16+
// warps as possible without exceeding the block shape.
17+
static inline void tdmGetWarpDistribution(const int64_t *blockShape,
18+
int numDims, int numWarps,
19+
int *warpsOut) {
20+
for (int i = 0; i < numDims; ++i)
21+
warpsOut[i] = 1;
22+
23+
int remainingWarps = numWarps;
24+
for (int i = 0; i < numDims && remainingWarps > 1; ++i) {
25+
while (remainingWarps > 1 && warpsOut[i] * 2 <= blockShape[i]) {
26+
warpsOut[i] *= 2;
27+
remainingWarps /= 2;
28+
}
29+
}
30+
31+
if (remainingWarps > 1)
32+
warpsOut[numDims - 1] *= remainingWarps;
33+
}
34+
35+
// Compute per-warp block sizes after distributing warps.
36+
// Only adjusts first 2 dimensions; higher dimensions remain unchanged.
37+
static inline void tdmGetAdjustedBlockShape(const int64_t *blockShape,
38+
int numDims, int numWarps,
39+
int64_t *adjustedOut) {
40+
int warps[5];
41+
tdmGetWarpDistribution(blockShape, numDims, numWarps, warps);
42+
43+
if (numDims >= 2) {
44+
adjustedOut[0] = (blockShape[0] + warps[0] - 1) / warps[0];
45+
adjustedOut[1] = (blockShape[1] + warps[1] - 1) / warps[1];
46+
} else {
47+
adjustedOut[0] = (blockShape[0] + numWarps - 1) / numWarps;
48+
}
49+
50+
// Higher dimensions are not divided by warps
51+
for (int i = 2; i < numDims; ++i)
52+
adjustedOut[i] = blockShape[i];
53+
}
54+
55+
#endif // TRITON_THIRD_PARTY_AMD_BACKEND_INCLUDE_TDMCOMMON_H

third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
#include "triton/Tools/LayoutUtils.h"
44
#include <optional>
55

6+
// Include shared C-compatible TDM utilities
7+
#include "../../backend/include/TDMCommon.h"
8+
69
namespace mlir::LLVM::AMD {
710
namespace {
811

@@ -54,30 +57,16 @@ decodeTDMDescriptor(RewriterBase &rewriter, Location loc,
5457
return {srcPtr, tensorShape, tensorStride};
5558
}
5659

60+
// C++ wrapper for the shared tdmGetWarpDistribution function
5761
SmallVector<int> getWarpDistribution(ArrayRef<int64_t> blockShape,
5862
int numWarps) {
59-
SmallVector<int> warps(blockShape.size(), 1);
60-
int remainingWarps = numWarps;
61-
62-
// Distribute warps across dimensions, starting from the first dimension
63-
for (size_t i = 0; i < blockShape.size() && remainingWarps > 1; ++i) {
64-
// Try to assign as many warps as possible to this dimension
65-
// without exceeding the block shape
66-
while (remainingWarps > 1 && warps[i] * 2 <= blockShape[i]) {
67-
warps[i] *= 2;
68-
remainingWarps /= 2;
69-
}
70-
}
71-
72-
// If there are still remaining warps, assign them to the last dimension
73-
// This ensures we use all available warps
74-
if (remainingWarps > 1) {
75-
warps[blockShape.size() - 1] *= remainingWarps;
76-
}
63+
int numDims = blockShape.size();
64+
SmallVector<int> warps(numDims);
65+
tdmGetWarpDistribution(blockShape.data(), numDims, numWarps, warps.data());
7766

7867
// Verify the distribution is valid
7968
int totalWarps = 1;
80-
for (size_t i = 0; i < warps.size(); ++i) {
69+
for (int i = 0; i < numDims; ++i) {
8170
totalWarps *= warps[i];
8271
assert(blockShape[i] % warps[i] == 0 &&
8372
"Block shape must be divisible by warp distribution");

third_party/amd/python/test/test_gluon_gfx1250.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1220,8 +1220,8 @@ def test_runtime_tensor_fill(M, N, BLOCK_M, BLOCK_N, NUM_BUFFERS):
12201220

12211221

12221222
@gluon.jit
1223-
def tensor_descriptor_load_store_nd_kernel(out_ptr, a_ptr, shape, strides, BLOCK_SHAPE, out_shape, out_strides,
1224-
SHARED_LAYOUT: ttgl.constexpr):
1223+
def tensor_descriptor_load_store_nd_kernel_device_tdm(out_ptr, a_ptr, shape, strides, BLOCK_SHAPE, out_shape,
1224+
out_strides, SHARED_LAYOUT: ttgl.constexpr):
12251225
ndim: ttgl.constexpr = len(BLOCK_SHAPE)
12261226
desc = ttgl.amd.gfx1250.tdm.make_tensor_descriptor(base=a_ptr, shape=shape, strides=strides,
12271227
block_shape=BLOCK_SHAPE, layout=SHARED_LAYOUT)
@@ -1238,10 +1238,23 @@ def tensor_descriptor_load_store_nd_kernel(out_ptr, a_ptr, shape, strides, BLOCK
12381238
ttgl.amd.gfx1250.tdm.async_wait(0)
12391239

12401240

1241+
@gluon.jit
1242+
def tensor_descriptor_load_store_nd_kernel_host_tdm(out_desc, inp_desc):
1243+
ndim: ttgl.constexpr = len(inp_desc.block_shape)
1244+
offs = (0, ) * ndim
1245+
block_shared = ttgl.allocate_shared_memory(inp_desc.dtype, shape=inp_desc.block_shape, layout=inp_desc.layout)
1246+
ttgl.amd.gfx1250.tdm.async_load(inp_desc, offs, block_shared)
1247+
ttgl.amd.gfx1250.tdm.async_wait(0)
1248+
1249+
ttgl.amd.gfx1250.tdm.async_store(out_desc, offs, block_shared)
1250+
ttgl.amd.gfx1250.tdm.async_wait(0)
1251+
1252+
12411253
@pytest.mark.parametrize("ndim", [1, 2, 3, 4, 5])
12421254
@pytest.mark.parametrize("INNER_BLOCK", [4, 8, 16, 32, 64, 128])
12431255
@pytest.mark.parametrize("dtype_str", sorted(set(dtypes_with_bfloat16) - {"int64", "uint64", "float64"}))
1244-
def test_tensor_descriptor_load_store_nd(dtype_str, ndim, INNER_BLOCK):
1256+
@pytest.mark.parametrize("TDM_TYPE", ["DEVICE_TDM", "HOST_TDM"])
1257+
def test_tensor_descriptor_load_store_nd(dtype_str, ndim, INNER_BLOCK, TDM_TYPE):
12451258
SHARED_LAYOUT: ttgl.constexpr = ttgl.SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1,
12461259
order=[ndim - 1 - i for i in range(ndim)])
12471260

@@ -1263,9 +1276,16 @@ def test_tensor_descriptor_load_store_nd(dtype_str, ndim, INNER_BLOCK):
12631276
inp = inp.cuda()
12641277
out = out.cuda()
12651278

1266-
constexpr_block_shape = tuple(ttgl.constexpr(v) for v in BLOCK_SHAPE)
1267-
k = tensor_descriptor_load_store_nd_kernel[(1, )](out, inp, inp.shape, inp.stride(), constexpr_block_shape,
1268-
out.shape, out.stride(), SHARED_LAYOUT)
1279+
if TDM_TYPE == "DEVICE_TDM":
1280+
constexpr_block_shape = tuple(ttgl.constexpr(v) for v in BLOCK_SHAPE)
1281+
k = tensor_descriptor_load_store_nd_kernel_device_tdm[(1, )](out, inp, inp.shape,
1282+
inp.stride(), constexpr_block_shape, out.shape,
1283+
out.stride(), SHARED_LAYOUT)
1284+
else:
1285+
assert TDM_TYPE == "HOST_TDM"
1286+
inp_desc = gluon.amd.gfx1250.TensorDescriptor.from_tensor(inp, list(BLOCK_SHAPE), layout=SHARED_LAYOUT)
1287+
out_desc = gluon.amd.gfx1250.TensorDescriptor.from_tensor(out, list(BLOCK_SHAPE), layout=SHARED_LAYOUT)
1288+
k = tensor_descriptor_load_store_nd_kernel_host_tdm[(1, )](out_desc, inp_desc)
12691289

12701290
amdgcn = k.asm["amdgcn"]
12711291
for pattern in ("tensor_load_to_lds", "tensor_store_from_lds", "s_wait_tensorcnt 0x0"):
@@ -1305,8 +1325,9 @@ def test_tensor_descriptor_load_store_invalid_blocksize():
13051325

13061326
# Expect compilation to fail due to block size exceeding maximum
13071327
try:
1308-
tensor_descriptor_load_store_nd_kernel[(1, )](out, inp, inp.shape, inp.stride(), constexpr_block_shape,
1309-
out.shape, out.stride(), SHARED_LAYOUT)
1328+
tensor_descriptor_load_store_nd_kernel_device_tdm[(1, )](out, inp, inp.shape,
1329+
inp.stride(), constexpr_block_shape, out.shape,
1330+
out.stride(), SHARED_LAYOUT)
13101331
pytest.fail(
13111332
f"Expected compilation to fail for block size {INNER_BLOCK} (2^17) > 65536 (2^16), but it succeeded")
13121333
except Exception as e:

0 commit comments

Comments (0)