Skip to content

Commit 2868f7a

Browse files
authored
[AMD] Fix scale layouts for batched WMMA scaled (#9545)
Fix the linear layout for the WMMA scale when it has a batch dimension, and add tests for batched WMMA scaled, in which warps are distributed along the batch dimension to run WMMA in parallel.
1 parent 7b2682d commit 2868f7a

2 files changed

Lines changed: 98 additions & 10 deletions

File tree

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1414,12 +1414,7 @@ LinearLayout chooseScaledWmmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
14141414
CGAEncodingAttr cgaLayout) {
14151415
using basisT = std::vector<std::vector<int32_t>>;
14161416
unsigned rank = dotOperandShape.size();
1417-
SmallVector<int32_t> order;
1418-
if (rank == 3) {
1419-
order = {1, 0, 2};
1420-
} else {
1421-
order = {1, 0};
1422-
}
1417+
bool hasBatchDim = rank == 3;
14231418
auto outDimNames = standardOutDimNames(ctx, rank);
14241419

14251420
StringAttr kRegister = StringAttr::get(ctx, "register");
@@ -1433,8 +1428,8 @@ LinearLayout chooseScaledWmmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
14331428
// - B: [K, N]
14341429
// - aScale: [M, K / 32 or 16]
14351430
// - bScale: [N, K / 32 or 16]
1436-
auto dimK = outDimNames[order[0]];
1437-
auto dimNonK = outDimNames[order[1]];
1431+
auto dimK = outDimNames[rank - 1];
1432+
auto dimNonK = outDimNames[rank - 2];
14381433

14391434
// Each lane holds kWidth=4 consecutive values along the K dim.
14401435
// The first 16 lanes are distributed along the nonK dim.
@@ -1445,13 +1440,23 @@ LinearLayout chooseScaledWmmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
14451440
LinearLayout::identity1D(16, kLane, dimNonK) *
14461441
LinearLayout::zeros1D(2, kLane, dimNonK);
14471442

1448-
unsigned mnDim = dotOperandIdx == 0 ? rank - 2 : rank - 1;
1449-
14501443
// If the shape along the K dim is larger than kWidth, repeat this
14511444
// pattern to fill the K dim.
14521445
tileLayout *= LinearLayout::identity1D(kSize / scaleKWidth, kRegister, dimK);
14531446

1447+
if (hasBatchDim) {
1448+
tileLayout *= LinearLayout::identity1D(1, kRegister, outDimNames[0]);
1449+
tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[0]);
1450+
}
1451+
14541452
if (dotOperandIdx == 1) {
1453+
// ctaLayout comes from the dot operand. For B in scaled dot,
1454+
// - the operand is ordered as [K, N]
1455+
// - the scale is ordered as [N, K / 32 or 16].
1456+
// Swap the last two dims of ctaLayout to match the tileLayout
1457+
SmallVector<int32_t> order = {1, 0};
1458+
if (hasBatchDim)
1459+
order = {0, 2, 1};
14551460
ctaLayout = transposeLinearLayout(ctaLayout, order);
14561461
}
14571462

third_party/amd/python/test/test_gluon_gfx1250.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -887,6 +887,89 @@ def kernel(c_ptr, a_ptr, a_scale_ptr, b_ptr, b_scale_ptr, #
887887
torch.testing.assert_close(c.cpu(), c_torch, atol=1e-5, rtol=2e-5)
888888

889889

@pytest.mark.skipif(not is_hip_gfx1250(), reason="Requires GFX1250")
@pytest.mark.parametrize("B", [4])
@pytest.mark.parametrize("M, N, K", get_test_mxfp_block_mnk())
@pytest.mark.parametrize("a_type, b_type", get_test_mxfp_variants())
def test_amd_wmma_scaled_batched(B, M, N, K, a_type, b_type):
    """Batched (rank-3) WMMA scaled dot on GFX1250.

    With num_warps=4 and warp_bases [[1, 0, 0], [2, 0, 0]] (both bases set
    only the first coordinate), all four warps are distributed along the
    batch dimension, so the B batches run WMMA in parallel.
    """

    @gluon.constexpr_function
    def _slice_layout(layout, indices):
        # Slice `layout` along each dim in `indices`, applied right-to-left,
        # yielding the 1-D layout for a single axis of a rank-3 tensor.
        for i in reversed(indices):
            layout = ttgl.SliceLayout(i, layout)
        return layout

    @gluon.jit
    def _offsets(dim0, dim1, dim2, layout):
        # Row-major flattened offsets for a (dim0, dim1, dim2) tensor; each
        # arange carries `layout` sliced down to its own axis.
        return ttgl.arange(0, dim0, layout=_slice_layout(layout, [1, 2]))[:, None, None] * (dim1 * dim2) + \
            ttgl.arange(0, dim1, layout=_slice_layout(layout, [0, 2]))[None, :, None] * dim2 + \
            ttgl.arange(0, dim2, layout=_slice_layout(layout, [0, 1]))[None, None, :]

    @gluon.jit
    def kernel(c_ptr, a_ptr, a_scale_ptr, b_ptr, b_scale_ptr, #
               a_type: ttgl.constexpr, b_type: ttgl.constexpr, #
               BLOCK_B: ttgl.constexpr, BLOCK_M: ttgl.constexpr, #
               BLOCK_N: ttgl.constexpr, BLOCK_K: ttgl.constexpr):
        # e2m1 operands are stored packed (two fp4 values per element), so the
        # stored K extent is halved when indexing the packed buffer.
        DIV_FACTOR_A: ttgl.constexpr = 2 if a_type == "e2m1" else 1
        DIV_FACTOR_B: ttgl.constexpr = 2 if b_type == "e2m1" else 1

        # Rank-3 warp bases: both bases address dim 0 (batch), placing the
        # 4 warps along the batch dimension.
        warp_bases: ttgl.constexpr = [[1, 0, 0], [2, 0, 0]]
        wmma_layout: ttgl.constexpr = \
            ttgl.amd.AMDWMMALayout(3, True, warp_bases, instr_shape=[16, 16, 128], rank=3)
        # Packed variant with half the K instr extent, used for packed e2m1
        # operands -- presumably matching the fp4 packing above; confirm
        # against the non-batched test.
        wmma_layout_packed: ttgl.constexpr = \
            ttgl.amd.AMDWMMALayout(3, True, warp_bases, instr_shape=[16, 16, 64], rank=3)
        a_layout: ttgl.constexpr = \
            ttgl.DotOperandLayout(0, wmma_layout_packed if a_type == "e2m1" else wmma_layout, 16)
        b_layout: ttgl.constexpr = \
            ttgl.DotOperandLayout(1, wmma_layout_packed if b_type == "e2m1" else wmma_layout, 16)
        # One scale per 32-element K block, hence BLOCK_K // 32 along K.
        a_scale_layout: ttgl.constexpr = \
            get_wmma_scale_layout(a_layout, [BLOCK_B, BLOCK_M, BLOCK_K // 32])
        b_scale_layout: ttgl.constexpr = \
            get_wmma_scale_layout(b_layout, [BLOCK_B, BLOCK_N, BLOCK_K // 32])

        # A is indexed [B, M, K/pack]; B is indexed [B, K/pack, N].
        a_offs = _offsets(BLOCK_B, BLOCK_M, BLOCK_K // DIV_FACTOR_A, a_layout)
        a = ttgl.load(a_ptr + a_offs)
        b_offs = _offsets(BLOCK_B, BLOCK_K // DIV_FACTOR_B, BLOCK_N, b_layout)
        b = ttgl.load(b_ptr + b_offs)

        # Both scales are indexed [B, nonK, K // 32]; the host side permutes
        # b_scale to match this ordering.
        a_scale_offs = _offsets(BLOCK_B, BLOCK_M, BLOCK_K // 32, a_scale_layout)
        a_scale = ttgl.load(a_scale_ptr + a_scale_offs)
        b_scale_offs = _offsets(BLOCK_B, BLOCK_N, BLOCK_K // 32, b_scale_layout)
        b_scale = ttgl.load(b_scale_ptr + b_scale_offs)

        zero = ttgl.zeros([BLOCK_B, BLOCK_M, BLOCK_N], dtype=ttgl.float32, layout=wmma_layout)
        c = ttgl.amd.gfx1250.wmma_scaled(a, a_scale, a_type, b, b_scale, b_type, zero)
        c = c.to(c_ptr.dtype.element_ty)

        c_offs = _offsets(BLOCK_B, BLOCK_M, BLOCK_N, wmma_layout)
        ttgl.store(c_ptr + c_offs, c)

    torch.manual_seed(42)
    # Build B independent batches of packed operands/scales and their
    # unpacked reference counterparts, then stack along a new batch dim.
    a, a_ref = zip(*[create_mxfp_operand(0, M, K, a_type) for _ in range(B)])
    b, b_ref = zip(*[create_mxfp_operand(1, K, N, b_type) for _ in range(B)])
    a_scale, a_scale_ref = zip(*[create_mxfp_scale(0, M, K) for _ in range(B)])
    b_scale, b_scale_ref = zip(*[create_mxfp_scale(1, K, N) for _ in range(B)])

    a = torch.stack(a, dim=0)
    b = torch.stack(b, dim=0)
    a_scale = torch.stack(a_scale, dim=0)
    # The kernel indexes b_scale as [B, N, K // 32] (see b_scale_offs), so
    # swap the last two dims of the stacked tensor to match -- assumes
    # create_mxfp_scale(1, ...) returns the transposed orientation per batch;
    # confirm against the non-batched test.
    b_scale = torch.stack(b_scale, dim=0).permute(0, 2, 1).contiguous()

    a_ref = torch.stack(a_ref, dim=0)
    b_ref = torch.stack(b_ref, dim=0)
    a_scale_ref = torch.stack(a_scale_ref, dim=0)
    b_scale_ref = torch.stack(b_scale_ref, dim=0)

    a, a_scale = a.cuda(), a_scale.cuda()
    b, b_scale = b.cuda(), b_scale.cuda()

    # Single program instance: the whole (B, M, N) problem is one block; the
    # 4 warps cover the batch dimension.
    c = torch.zeros((B, M, N), dtype=torch.float32).cuda()
    kernel[(1, )](c, a, a_scale, b, b_scale, a_type, b_type, B, M, N, K, num_warps=4)

    # Reference: batched matmul of the descaled operands.
    c_torch = (a_ref * a_scale_ref) @ (b_ref * b_scale_ref)
    torch.testing.assert_close(c.cpu(), c_torch, atol=1e-5, rtol=2e-5)
890973
@pytest.mark.skipif(not is_hip_gfx1250(), reason="Requires GFX1250")
891974
@pytest.mark.parametrize("M, N, K", [(16, 16, 128), (32, 32, 128), (32, 32, 256), (32, 32, 512), (64, 64, 128),
892975
(128, 128, 256)])

0 commit comments

Comments (0)