
Commit 4804627

[AMD][gfx1250] test_dot fix for small K (#9358)
1 parent c1e2aed

2 files changed: 38 additions & 2 deletions


test/TritonGPU/amd/accelerate-amd-matmul-wmma-gfx1250.mlir

Lines changed: 25 additions & 0 deletions
@@ -351,3 +351,28 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
     tt.return
   }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0] }>
+#op0 = #ttg.dot_op<{opIdx = 0, parent = #blocked}>
+#op1 = #ttg.dot_op<{opIdx = 1, parent = #blocked}>
+
+// CHECK{LITERAL}: #mma = #ttg.amd_wmma<{version = 3, isTranspose = true, ctaLayout = {warp = [[0, 1], [1, 0]]}, instrShape = [16, 16, 32]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @wmma_dot_f16_f32_smallk(
+      %arg0: tensor<32x8x!tt.ptr<f16>, #op0>,
+      %arg1: tensor<8x32x!tt.ptr<f16>, #op1>,
+      %arg2: tensor<32x32x!tt.ptr<f32>, #blocked>
+  ) {
+    %a = tt.load %arg0 : tensor<32x8x!tt.ptr<f16>, #op0>
+    %b = tt.load %arg1 : tensor<8x32x!tt.ptr<f16>, #op1>
+    %c = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked>
+    // CHECK: %[[OPND0:.*]] = ttg.convert_layout {{.*}} : tensor<32x8xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<32x8xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+    // CHECK: %[[OPND1:.*]] = ttg.convert_layout {{.*}} : tensor<8x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<8x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
+    // CHECK: tt.dot %[[OPND0]], %[[OPND1]], %{{.*}} : tensor<32x8xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<8x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<32x32xf32, #mma>
+    %res = tt.dot %a, %b, %c : tensor<32x8xf16, #op0> * tensor<8x32xf16, #op1> -> tensor<32x32xf32, #blocked>
+    tt.store %arg2, %res : tensor<32x32x!tt.ptr<f32>, #blocked>
+    tt.return
+  }
+}
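
What the new case exercises: the 32x8 * 8x32 f16 dot has K = 8, smaller than the kDim = 32 of the gfx1250 WMMA v3 intrinsic (the last entry of instrShape = [16, 16, 32]), so the pass is expected to tag both dot operands with kWidth = K / 2 = 4 instead of the usual 8. The file's RUN header is not part of this hunk; a typical RUN line for a test like this would look roughly as follows, though the exact pass options are an assumption:

    // RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul | FileCheck %s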

third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp

Lines changed: 13 additions & 2 deletions
@@ -1549,6 +1549,11 @@ class BlockedToWMMA : public OpRewritePattern<tt::DotOp> {
     if (operandTypes.empty())
       return failure();
 
+    auto kDimTensor = aShape.back();
+    if (kDimTensor == 1) {
+      return rewriter.notifyMatchFailure(dotOp,
+                                         "Skipping WMMA for dot op with K=1");
+    }
     // check shape
     FailureOr<WmmaIntrinsic> wmmaInstr =
         chooseWmmaInstruction(dotOp, operandTypes, wmmaVersion);
@@ -1590,8 +1595,14 @@
     auto newAcc =
         convertAndCastTensor(rewriter, oldAcc, wmmaEnc, operandTypes[2]);
 
-    // kWidth is always 8 for WMMA v3, and equals to kBase for WMMA v1/2
-    auto kWidth = wmmaVersion == 3 ? 8 : kBase;
+    auto kWidth = 0;
+    // Adjust kWidth to kDimTensor / 2 when kDimTensor < kDim.
+    if (kDimTensor < kDim) {
+      kWidth = kDimTensor / 2;
+    } else {
+      // kWidth is always 8 for WMMA v3 and equals kBase for WMMA v1/2.
+      kWidth = wmmaVersion == 3 ? 8 : kBase;
+    }
     auto newAType = RankedTensorType::get(
         aShape, operandTypes[0],
         ttg::DotOperandEncodingAttr::get(ctx, 0, wmmaEnc, kWidth));
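
Taken together, the two hunks first bail out of the WMMA rewrite when K = 1 (where the new kDimTensor / 2 rule would round down to an invalid kWidth of 0), and otherwise shrink kWidth whenever the tensor's K extent is smaller than the chosen intrinsic's kDim. Below is a minimal standalone sketch of that selection logic, checked against the values from the new test; chooseKWidth is a hypothetical name for illustration, and the kBase argument is a placeholder that the small-K branch never reads:

    #include <cassert>
    #include <cstdint>

    // kDimTensor: K extent of the dot operands; kDim: K extent of the chosen
    // WMMA intrinsic (last entry of its instrShape); kBase: per-thread K
    // elements of the intrinsic.
    int chooseKWidth(std::int64_t kDimTensor, std::int64_t kDim,
                     int wmmaVersion, int kBase) {
      if (kDimTensor == 1)
        return -1; // the pattern skips WMMA entirely for K = 1
      // Small-K case: the tensor's K is below the intrinsic's K.
      if (kDimTensor < kDim)
        return static_cast<int>(kDimTensor / 2);
      // Regular case: kWidth is always 8 for WMMA v3 and equals kBase for v1/2.
      return wmmaVersion == 3 ? 8 : kBase;
    }

    int main() {
      // Values from the new MLIR test: a 32x8 * 8x32 f16 dot on gfx1250, where
      // the v3 intrinsic has instrShape = [16, 16, 32], i.e. kDim = 32.
      assert(chooseKWidth(/*kDimTensor=*/8, /*kDim=*/32, /*wmmaVersion=*/3,
                          /*kBase=*/8) == 4); // matches kWidth = 4 in the CHECKs
      return 0;
    }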
