Commit a0e3e78

[AMD][BACKEND] Enable bf16 dot2 in AMD backend pass (#6600)
This enables basic usage of the bf16 dot2 instruction on the CDNA4 architecture; the instruction computes a 32-bit sum of 16-bit multiplications (sketched below).
1 parent: 6fedb78

3 files changed: 31 additions & 3 deletions
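For context, a minimal Python sketch (not from the commit) of the arithmetic a dot2-style instruction performs: two 16-bit elements per operand are widened, multiplied pairwise, and accumulated into a 32-bit result. bf16 is emulated here by truncating the f32 bit pattern; hardware rounding details may differ.

import struct

def to_bf16(x):
    # Emulate bf16 by keeping only the top 16 bits of the f32
    # encoding (truncation; hardware rounding may differ).
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    (y,) = struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))
    return y

def fdot2_bf16(a, b, c):
    # One dot2 step: two bf16 multiplies accumulated into a
    # 32-bit addend (Python floats stand in for f32 here).
    return to_bf16(a[0]) * to_bf16(b[0]) + to_bf16(a[1]) * to_bf16(b[1]) + c

print(fdot2_bf16((1.5, 2.0), (0.25, -1.0), 10.0))  # -> 8.375

Widening bf16 to f32 is exact (bf16 is a truncated f32), so the pairwise products are exact in f32 and rounding happens in the accumulation.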

python/test/unit/language/test_core.py

Lines changed: 21 additions & 1 deletion
@@ -3604,9 +3604,17 @@ def get_test_dot_double_rate_cases():
             (16, 16, 32, 4, False, False, 'None', 'ieee', 'bfloat16', 'float32', 1, None)]


+def get_test_dot_vdot2_cases():
+    if not is_hip_cdna():
+        return []
+    return [(4, 32, 32, 4, False, False, 'None', 'ieee', 'float16', 'float32', 1, None),
+            (4, 32, 32, 4, False, False, 'None', 'ieee', 'bfloat16', 'float32', 1, None)]
+
+
 @pytest.mark.interpreter
 @pytest.mark.parametrize(
     "M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size",
+    get_test_dot_vdot2_cases() + \
     get_test_dot_double_rate_cases() + \
     get_test_dot_base_cases() + \
     get_test_dot_mixed_sizes_cases() + \
@@ -3799,8 +3807,20 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
     else:
         # added atol, to loose precision for float16xfloat16->float32 case
         np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
-    if not is_cuda():
+
+    if not (is_cuda() or is_hip_cdna()):
         return
+
+    if is_hip_cdna():
+        if M != 4:
+            return
+        amdgcn = pgm.asm['amdgcn']
+        if in_dtype == 'float16':
+            assert 'v_dot2c_f32_f16' in amdgcn
+        elif (in_dtype == 'bfloat16') and is_hip_cdna4():
+            assert 'v_dot2c_f32_bf16' in amdgcn
+        return
+
     # make sure ld/st are vectorized
     ptx = pgm.asm['ptx']
     if (K > 16 or N > 16 or M > 16) and (M * N // (num_warps * 32) >= 4):
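Outside the test suite, the same assembly check can be reproduced directly; a hedged sketch, assuming a ROCm build of Triton on a CDNA4 GPU (the kernel and names below are illustrative, with shapes mirroring the new vdot2 cases):

import torch
import triton
import triton.language as tl

@triton.jit
def small_dot(X, Y, Z, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    rm = tl.arange(0, M)
    rn = tl.arange(0, N)
    rk = tl.arange(0, K)
    x = tl.load(X + rm[:, None] * K + rk[None, :])
    y = tl.load(Y + rk[:, None] * N + rn[None, :])
    tl.store(Z + rm[:, None] * N + rn[None, :], tl.dot(x, y))

x = torch.randn(4, 32, dtype=torch.bfloat16, device='cuda')
y = torch.randn(32, 32, dtype=torch.bfloat16, device='cuda')
z = torch.empty(4, 32, dtype=torch.float32, device='cuda')
pgm = small_dot[(1, )](x, y, z, M=4, N=32, K=32)
# With this commit, the small bf16 dot should lower through v_dot2 on CDNA4:
assert 'v_dot2c_f32_bf16' in pgm.asm['amdgcn']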

third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/FMA.cpp

Lines changed: 3 additions & 2 deletions
@@ -36,10 +36,11 @@ class AMDFMAVectorMultiplier : public FMAVectorMultiplier {
     bool dotAvailable = AMD::supportsVDot(arch);
     auto b = TritonLLVMOpBuilder(loc, rewriter);
     if (dotAvailable) {
-      if (aElemTy.isF16() && dElemTy.isF32()) {
+      if ((aElemTy.isF16() || aElemTy.isBF16()) && dElemTy.isF32()) {
         chosenOp.vectorSize = 2;
         chosenOp.outElemTy = f32_ty;
-        chosenOp.intrinsicName = "llvm.amdgcn.fdot2";
+        chosenOp.intrinsicName = aElemTy.isF16() ? "llvm.amdgcn.fdot2"
+                                                 : "llvm.amdgcn.fdot2.f32.bf16";
         chosenOp.additionalArgs = {b.false_val()};
         return chosenOp;
       }
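In Python pseudocode, the selection above reads as follows (a paraphrase for skimming, not Triton API; the trailing false appears to correspond to the intrinsics' final i1 clamp operand in LLVM's AMDGPU intrinsic definitions):

def choose_fdot2(a_elem_ty):
    # Paraphrase of the FMA.cpp branch: both element types share a
    # vector size of 2 and an f32 result; only the intrinsic differs.
    assert a_elem_ty in ('f16', 'bf16')
    return {
        'vectorSize': 2,
        'outElemTy': 'f32',
        'intrinsicName': ('llvm.amdgcn.fdot2' if a_elem_ty == 'f16'
                          else 'llvm.amdgcn.fdot2.f32.bf16'),
        'additionalArgs': [False],  # trailing clamp bit (assumption)
    }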

third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp

Lines changed: 7 additions & 0 deletions
@@ -1103,6 +1103,13 @@ class AccelerateBlocked : public OpRewritePattern<DotOp> {
       return true;
     }

+    // CDNA4 has Bf16 v_dot2
+    if (AMD::deduceISAFamily(arch) == ISAFamily::CDNA4 &&
+        dotTypes.a.isBF16() && dotTypes.b.isBF16() && dotTypes.c.isF32() &&
+        dotTypes.d.isF32() && k % 2 == 0) {
+      return true;
+    }
+
     // TODO: enable this condition, when fp32 -> fp16 cast works correctly
     // Consider this case as non legal, despite this case is covered by fp16
     // FMA. Because v_dot expected to give both better performance and
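In words: the new clause legalizes the v_dot2 path only on CDNA4, only for bf16 x bf16 with an f32 accumulator and result, and only for even K, since each dot2 consumes K-dimension elements in pairs. A Python paraphrase of the predicate (illustrative; the authoritative check is the C++ above):

def bf16_vdot2_legal(isa_family, a_ty, b_ty, c_ty, d_ty, k):
    # Mirrors the CDNA4 bf16 v_dot2 legality check added above.
    return (isa_family == 'CDNA4'
            and a_ty == b_ty == 'bf16'
            and c_ty == d_ty == 'f32'
            and k % 2 == 0)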
