Skip to content

Commit af02d05

Browse files
authored
[AMD][NFC][BACKEND] Use ROCDL ops for (cluster) async load (#9410)
Replaces the raw `llvm.amdgcn.*` intrinsic calls used for (cluster) async loads to LDS with the corresponding ROCDL dialect ops (`rocdl.global.load.async.to.lds.b32/b64/b128` and `rocdl.cluster.load.async.to.lds.*`). No functional change.
1 parent 54bed51 commit af02d05

2 files changed

Lines changed: 50 additions & 30 deletions

File tree

test/Conversion/amd/async_ops_to_llvm_gfx1250.mlir

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
1010
// We need the splat to allow the AxisAnalysis to work during lowering
1111
%1 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<32x32x!tt.ptr<f32>, #blocked>
1212
// Each thread needs to load 8 elements and we load 1 (sizePerThread) per global.load.lds
13-
// CHECK-COUNT-8: llvm.amdgcn.global.load.async.to.lds.b32
14-
// CHECK-NOT: llvm.amdgcn.global.load.async.to.lds
13+
// CHECK-COUNT-8: rocdl.global.load.async.to.lds.b32
14+
// CHECK-NOT: rocdl.global.load.async.to.lds
1515
%2 = ttg.async_copy_global_to_local %1, %arg2 : tensor<32x32x!tt.ptr<f32>, #blocked> -> <32x32xf32, #shared, #smem, mutable>
1616
tt.return
1717
}
@@ -27,8 +27,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
2727
tt.func public @async_load_strided_into_lds_with_swizzle(%arg0: tensor<32x32x!tt.ptr<f32>, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>, tt.constancy = dense<[1, 1]> : tensor<2xi32>},
2828
%arg1: !ttg.memdesc<32x32xf32, #shared, #smem, mutable>) {
2929
// Each thread loads 256 contiguous bits so we split into 2 128bit loads. This was not possible on GFX9
30-
// CHECK-COUNT-2: llvm.amdgcn.global.load.async.to.lds.b128
31-
// CHECK-NOT: llvm.amdgcn.global.load.async.to.lds
30+
// CHECK-COUNT-2: rocdl.global.load.async.to.lds.b128
31+
// CHECK-NOT: rocdl.global.load.async.to.lds
3232
%6 = ttg.async_copy_global_to_local %arg0, %arg1 : tensor<32x32x!tt.ptr<f32>, #blocked> -> <32x32xf32, #shared, #smem, mutable>
3333
tt.return
3434
}
@@ -46,8 +46,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
4646
// We need the splat to allow the AxisAnalysis to work during lowering
4747
%1 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<32x32x!tt.ptr<f32>, #blocked>
4848
// Each thread needs to load 8 elements and we load 1 (sizePerThread) per global.load.lds
49-
// CHECK-COUNT-8: llvm.amdgcn.global.load.async.to.lds.b32
50-
// CHECK-NOT: llvm.amdgcn.global.load.async.to.lds
49+
// CHECK-COUNT-8: rocdl.global.load.async.to.lds.b32
50+
// CHECK-NOT: rocdl.global.load.async.to.lds
5151
%2 = ttg.async_copy_global_to_local %1, %arg2 : tensor<32x32x!tt.ptr<f32>, #blocked> -> <32x32xf32, #shared, #smem, mutable>
5252
tt.return
5353
}
@@ -64,7 +64,7 @@ module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
6464
tt.func public @async_load_multicast_to_all_ctas(%arg0: tensor<32x32x!tt.ptr<f32>, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>, tt.constancy = dense<[1, 1]> : tensor<2xi32>},
6565
%arg1: !ttg.memdesc<32x32xf32, #shared, #smem, mutable>) {
6666
// CHECK: %[[GROUP_MASK:.*]] = llvm.mlir.constant(15 : i32) : i32
67-
// CHECK: llvm.amdgcn.cluster.load.async.to.lds{{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[GROUP_MASK]]
67+
// CHECK: rocdl.cluster.load.async.to.lds{{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[GROUP_MASK]]
6868

6969
%6 = ttg.async_copy_global_to_local %arg0, %arg1 : tensor<32x32x!tt.ptr<f32>, #blocked> -> <32x32xf32, #shared, #smem, mutable>
7070
tt.return
@@ -86,7 +86,7 @@ module attributes {"ttg.num-ctas" = 8 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
8686
// CHECK: %[[SHIFT_AMOUNT:.*]] = llvm.and %[[CTA_ID]], %[[NON_FREE_BITS]]
8787
// CHECK: %[[GROUP_MASK:.*]] = llvm.mlir.constant(85 : i32) : i32
8888
// CHECK: %[[CTA_MASK:.*]] = llvm.shl %[[GROUP_MASK]], %[[SHIFT_AMOUNT]]
89-
// CHECK: llvm.amdgcn.cluster.load.async.to.lds{{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[CTA_MASK]]
89+
// CHECK: rocdl.cluster.load.async.to.lds{{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[CTA_MASK]]
9090
%6 = ttg.async_copy_global_to_local %arg0, %arg1 : tensor<32x32x!tt.ptr<f32>, #blocked> -> <32x32xf32, #shared, #smem, mutable>
9191
tt.return
9292
}
@@ -108,7 +108,7 @@ module attributes {"ttg.num-ctas" = 16 : i32, "ttg.num-warps" = 4 : i32, ttg.sha
108108
// CHECK: %[[SHIFT_AMOUNT:.*]] = llvm.and %[[CTA_ID]], %[[NON_FREE_BITS]]
109109
// CHECK: %[[GROUP_MASK:.*]] = llvm.mlir.constant(257 : i32) : i32
110110
// CHECK: %[[CTA_MASK:.*]] = llvm.shl %[[GROUP_MASK]], %[[SHIFT_AMOUNT]]
111-
// CHECK: llvm.amdgcn.cluster.load.async.to.lds{{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[CTA_MASK]]
111+
// CHECK: rocdl.cluster.load.async.to.lds{{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[CTA_MASK]]
112112
%6 = ttg.async_copy_global_to_local %arg0, %arg1 : tensor<32x32x!tt.ptr<f32>, #blocked> -> <32x32xf32, #shared, #smem, mutable>
113113
tt.return
114114
}
@@ -124,9 +124,9 @@ module attributes {"ttg.num-ctas" = 16 : i32, "ttg.num-warps" = 4 : i32, ttg.sha
124124
// CHECK-LABEL: async_load_multi_cta_but_not_data_sharing
125125
tt.func public @async_load_multi_cta_but_not_data_sharing(%arg0: tensor<32x32x!tt.ptr<f32>, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>, tt.constancy = dense<[1, 1]> : tensor<2xi32>},
126126
%arg1: !ttg.memdesc<32x32xf32, #shared, #smem, mutable>) {
127-
// CHECK-NOT: llvm.amdgcn.cluster.load.async.to.lds
128-
// CHECK: llvm.amdgcn.global.load.async.to.lds.b64
129-
// CHECK-NOT: llvm.amdgcn.cluster.load.async.to.lds
127+
// CHECK-NOT: rocdl.cluster.load.async.to.lds
128+
// CHECK: rocdl.global.load.async.to.lds.b64
129+
// CHECK-NOT: rocdl.cluster.load.async.to.lds
130130
%6 = ttg.async_copy_global_to_local %arg0, %arg1 : tensor<32x32x!tt.ptr<f32>, #blocked> -> <32x32xf32, #shared, #smem, mutable>
131131
tt.return
132132
}
@@ -149,7 +149,7 @@ module attributes {"ttg.num-ctas" = 16 : i32, "ttg.num-warps" = 4 : i32, ttg.sha
149149
// CHECK: %[[SHIFT_AMOUNT:.*]] = llvm.and %[[CTA_ID]], %[[NON_FREE_BITS]]
150150
// CHECK: %[[GROUP_MASK:.*]] = llvm.mlir.constant(257 : i32) : i32
151151
// CHECK: %[[CTA_MASK:.*]] = llvm.shl %[[GROUP_MASK]], %[[SHIFT_AMOUNT]]
152-
// CHECK: llvm.amdgcn.cluster.load.async.to.lds{{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[CTA_MASK]]
152+
// CHECK: rocdl.cluster.load.async.to.lds{{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[CTA_MASK]]
153153
%6 = ttg.async_copy_global_to_local %arg0, %arg1 : tensor<32x32x!tt.ptr<f32>, #linear> -> <32x32xf32, #shared, #smem, mutable>
154154
tt.return
155155
}
@@ -204,8 +204,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
204204
%arg1: !ttg.memdesc<32x32xf32, #shared, #smem, mutable>) {
205205
%1 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<32x32x!tt.ptr<f32>, #blocked>
206206
// Each thread loads 8 elements with 32-bit loads
207-
// CHECK-COUNT-8: llvm.amdgcn.global.load.async.to.lds.b32
208-
// CHECK-NOT: llvm.amdgcn.global.load.async.to.lds
207+
// CHECK-COUNT-8: rocdl.global.load.async.to.lds.b32
208+
// CHECK-NOT: rocdl.global.load.async.to.lds
209209
%2 = ttg.async_copy_global_to_local %1, %arg1 : tensor<32x32x!tt.ptr<f32>, #blocked> -> <32x32xf32, #shared, #smem, mutable>
210210
tt.return
211211
}
@@ -246,8 +246,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
246246
%arg1: !ttg.memdesc<32x32xf32, #shared, #smem, mutable>) {
247247
// minInterval=2 limits vectorization to 2 elements (64 bits)
248248
// Each thread handles 8 elements -> 4 x 64-bit loads
249-
// CHECK-COUNT-4: llvm.amdgcn.global.load.async.to.lds.b64
250-
// CHECK-NOT: llvm.amdgcn.global.load.async.to.lds
249+
// CHECK-COUNT-4: rocdl.global.load.async.to.lds.b64
250+
// CHECK-NOT: rocdl.global.load.async.to.lds
251251
%2 = ttg.async_copy_global_to_local %arg0, %arg1 : tensor<32x32x!tt.ptr<f32>, #blocked> -> <32x32xf32, #shared, #smem, mutable>
252252
tt.return
253253
}

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,19 +1027,39 @@ struct AsyncCopyGlobalToLocalOpConversion
10271027
if (cacheMod != triton::CacheModifier::NONE) {
10281028
emitRemark(loc) << "cache modifiers not yet implemented on gfx1250";
10291029
}
1030-
if (multicastMask) {
1031-
std::string intrinsic =
1032-
"llvm.amdgcn.cluster.load.async.to.lds.b" + std::to_string(vecBits);
1033-
auto globalLoadLdsOp = LLVM::createLLVMIntrinsicCallOp(
1034-
rewriter, loc, intrinsic, {},
1035-
{srcPtr, shmemAddr, b.i32_val(0), b.i32_val(cacheModifiers),
1036-
multicastMask});
1037-
} else {
1038-
std::string intrinsic =
1039-
"llvm.amdgcn.global.load.async.to.lds.b" + std::to_string(vecBits);
1040-
auto globalLoadLdsOp = LLVM::createLLVMIntrinsicCallOp(
1041-
rewriter, loc, intrinsic, {},
1042-
{srcPtr, shmemAddr, b.i32_val(0), b.i32_val(cacheModifiers)});
1030+
switch (vecBits) {
1031+
case 32:
1032+
if (multicastMask)
1033+
ROCDL::ClusterLoadAsyncToLDSB32Op::create(
1034+
rewriter, loc, srcPtr, shmemAddr, 0, cacheModifiers,
1035+
multicastMask, nullptr, nullptr, nullptr);
1036+
else
1037+
ROCDL::GlobalLoadAsyncToLDSB32Op::create(rewriter, loc, srcPtr,
1038+
shmemAddr, 0, cacheModifiers,
1039+
nullptr, nullptr, nullptr);
1040+
break;
1041+
case 64:
1042+
if (multicastMask)
1043+
ROCDL::ClusterLoadAsyncToLDSB64Op::create(
1044+
rewriter, loc, srcPtr, shmemAddr, 0, cacheModifiers,
1045+
multicastMask, nullptr, nullptr, nullptr);
1046+
else
1047+
ROCDL::GlobalLoadAsyncToLDSB64Op::create(rewriter, loc, srcPtr,
1048+
shmemAddr, 0, cacheModifiers,
1049+
nullptr, nullptr, nullptr);
1050+
break;
1051+
case 128:
1052+
if (multicastMask)
1053+
ROCDL::ClusterLoadAsyncToLDSB128Op::create(
1054+
rewriter, loc, srcPtr, shmemAddr, 0, cacheModifiers,
1055+
multicastMask, nullptr, nullptr, nullptr);
1056+
else
1057+
ROCDL::GlobalLoadAsyncToLDSB128Op::create(
1058+
rewriter, loc, srcPtr, shmemAddr, 0, cacheModifiers, nullptr,
1059+
nullptr, nullptr);
1060+
break;
1061+
default:
1062+
llvm_unreachable("Unsupported vec size for async load");
10431063
}
10441064
}
10451065
}

0 commit comments

Comments
 (0)