Skip to content

Commit 3ebfe54

Browse files
authored
[AMD][BACKEND] Support gfx1250 async load to LDS multicast (#8719)
When lowering `ttg.async_copy_global_to_local` with `num_cta > 1`, where broadcasting happens between CTAs in the same cluster, we emit cluster loads. These require a CTA mask signalling which CTAs request the same memory addresses. Other cluster-load variants will be added in follow-up PRs.
1 parent 130d73c commit 3ebfe54

4 files changed

Lines changed: 255 additions & 20 deletions

File tree

test/Conversion/amd/async_ops_to_llvm_gfx1250.mlir

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,127 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
3333
tt.return
3434
}
3535
}
36+
37+
// -----
38+
39+
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
40+
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 8, order = [1, 0]}>
41+
#smem = #ttg.shared_memory
42+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32} {
43+
// CHECK-LABEL: async_copy_with_swizzle
44+
tt.func public @async_copy_with_swizzle(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
45+
%arg2: !ttg.memdesc<32x32xf32, #shared, #smem, mutable>) {
46+
// We need the splat to allow the AxisAnalysis to work during lowering
47+
%1 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<32x32x!tt.ptr<f32>, #blocked>
48+
// Each thread needs to load 8 elements and we load 1 (sizePerThread) per global.load.lds
49+
// CHECK-COUNT-8: llvm.amdgcn.global.load.async.to.lds.b32
50+
// CHECK-NOT: llvm.amdgcn.global.load.async.to.lds
51+
%2 = ttg.async_copy_global_to_local %1, %arg2 : tensor<32x32x!tt.ptr<f32>, #blocked> -> <32x32xf32, #shared, #smem, mutable>
52+
tt.return
53+
}
54+
}
55+
56+
// -----
57+
58+
// Broadcast to all CTAs so we should just see 15 (0b1111) as the broadcast mask since we have 4 CTAs per CGA
59+
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [2, 2], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
60+
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [2, 2], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
61+
#smem = #ttg.shared_memory
62+
module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32} {
63+
// CHECK-LABEL: async_load_multicast_to_all_ctas
64+
tt.func public @async_load_multicast_to_all_ctas(%arg0: tensor<32x32x!tt.ptr<f32>, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>, tt.constancy = dense<[1, 1]> : tensor<2xi32>},
65+
%arg1: !ttg.memdesc<32x32xf32, #shared, #smem, mutable>) {
66+
// CHECK: %[[GROUP_MASK:.*]] = llvm.mlir.constant(15 : i32) : i32
67+
// CHECK: llvm.amdgcn.cluster.load.async.to.lds{{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[GROUP_MASK]]
68+
69+
%6 = ttg.async_copy_global_to_local %arg0, %arg1 : tensor<32x32x!tt.ptr<f32>, #blocked> -> <32x32xf32, #shared, #smem, mutable>
70+
tt.return
71+
}
72+
}
73+
74+
// -----
75+
76+
// 8 CTAs, 2 multicast groups of 4 CTAs each. Group members are interleaved with stride 2 (even CTAs / odd CTAs), so the base mask should be 0b1010101 (85) and the non-free mask is -7 (~0b110)
77+
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [8, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}>
78+
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [8, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}>
79+
#smem = #ttg.shared_memory
80+
module attributes {"ttg.num-ctas" = 8 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32} {
81+
// CHECK-LABEL: async_load_multicast_to_half_ctas
82+
tt.func public @async_load_multicast_to_half_ctas(%arg0: tensor<32x32x!tt.ptr<f32>, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>, tt.constancy = dense<[1, 1]> : tensor<2xi32>},
83+
%arg1: !ttg.memdesc<32x32xf32, #shared, #smem, mutable>) {
84+
// CHECK: llvm.amdgcn.cluster.workgroup.id.x
85+
// CHECK: %[[CTA_ID:.*]] = {{.*}}llvm.amdgcn.cluster.workgroup.id.x
86+
// CHECK: %[[NON_FREE_BITS:.*]] = llvm.mlir.constant(-7 : i32) : i32
87+
// CHECK: %[[SHIFT_AMOUNT:.*]] = llvm.and %[[CTA_ID]], %[[NON_FREE_BITS]]
88+
// CHECK: %[[GROUP_MASK:.*]] = llvm.mlir.constant(85 : i32) : i32
89+
// CHECK: %[[CTA_MASK:.*]] = llvm.shl %[[GROUP_MASK]], %[[SHIFT_AMOUNT]]
90+
// CHECK: llvm.amdgcn.cluster.load.async.to.lds{{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[CTA_MASK]]
91+
%6 = ttg.async_copy_global_to_local %arg0, %arg1 : tensor<32x32x!tt.ptr<f32>, #blocked> -> <32x32xf32, #shared, #smem, mutable>
92+
tt.return
93+
}
94+
}
95+
96+
// -----
97+
98+
// 16 CTAs, 8 multicast groups of 2 CTAs each, each group is strided by 8 so the base mask should be 0b100000001 (257) and the non free mask is -9 (~0b1000)
99+
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 16], CTASplitNum = [1, 8], CTAOrder = [1, 0]}>
100+
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 16], CTASplitNum = [1, 8], CTAOrder = [1, 0]}>
101+
#smem = #ttg.shared_memory
102+
module attributes {"ttg.num-ctas" = 16 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32} {
103+
// CHECK-LABEL: async_load_multicast_group_of_2_strided_by_8
104+
tt.func public @async_load_multicast_group_of_2_strided_by_8(%arg0: tensor<32x32x!tt.ptr<f32>, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>, tt.constancy = dense<[1, 1]> : tensor<2xi32>},
105+
%arg1: !ttg.memdesc<32x32xf32, #shared, #smem, mutable>) {
106+
// Skip the first cluster id because it's emitted for address calculation
107+
// CHECK: llvm.amdgcn.cluster.workgroup.id.x
108+
// CHECK: %[[CTA_ID:.*]] = {{.*}}llvm.amdgcn.cluster.workgroup.id.x
109+
// CHECK: %[[NON_FREE_BITS:.*]] = llvm.mlir.constant(-9 : i32) : i32
110+
// CHECK: %[[SHIFT_AMOUNT:.*]] = llvm.and %[[CTA_ID]], %[[NON_FREE_BITS]]
111+
// CHECK: %[[GROUP_MASK:.*]] = llvm.mlir.constant(257 : i32) : i32
112+
// CHECK: %[[CTA_MASK:.*]] = llvm.shl %[[GROUP_MASK]], %[[SHIFT_AMOUNT]]
113+
// CHECK: llvm.amdgcn.cluster.load.async.to.lds{{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[CTA_MASK]]
114+
%6 = ttg.async_copy_global_to_local %arg0, %arg1 : tensor<32x32x!tt.ptr<f32>, #blocked> -> <32x32xf32, #shared, #smem, mutable>
115+
tt.return
116+
}
117+
}
118+
119+
// -----
120+
121+
// 16 CTAs split into 16 multicast groups, so we should not emit a cluster load since no data is shared between CTAs
122+
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 16], CTASplitNum = [1, 16], CTAOrder = [1, 0]}>
123+
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 16], CTASplitNum = [1, 16], CTAOrder = [1, 0]}>
124+
#smem = #ttg.shared_memory
125+
module attributes {"ttg.num-ctas" = 16 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32} {
126+
// CHECK-LABEL: async_load_multi_cta_but_not_data_sharing
127+
tt.func public @async_load_multi_cta_but_not_data_sharing(%arg0: tensor<32x32x!tt.ptr<f32>, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>, tt.constancy = dense<[1, 1]> : tensor<2xi32>},
128+
%arg1: !ttg.memdesc<32x32xf32, #shared, #smem, mutable>) {
129+
// CHECK-NOT: llvm.amdgcn.cluster.load.async.to.lds
130+
// CHECK: llvm.amdgcn.global.load.async.to.lds.b64
131+
// CHECK-NOT: llvm.amdgcn.cluster.load.async.to.lds
132+
%6 = ttg.async_copy_global_to_local %arg0, %arg1 : tensor<32x32x!tt.ptr<f32>, #blocked> -> <32x32xf32, #shared, #smem, mutable>
133+
tt.return
134+
}
135+
}
136+
137+
// -----
138+
139+
// Test with linear layout as src layout
140+
// 16 CTAs, 8 multicast groups of 2 CTAs each, each group is strided by 8 so the base mask should be 0b100000001 (257) and the non free mask is -9 (~0b1000)
141+
#linear = #ttg.linear<{register = [[0, 1], [0, 2]], lane = [[0, 0], [0, 0], [1, 0], [2, 0], [4, 0]], warp = [[8, 0], [16, 0]], block = [[0, 4], [0, 8], [0, 16], [0, 0]], order = [1, 0]}>
142+
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 16], CTASplitNum = [1, 8], CTAOrder = [1, 0]}>
143+
#smem = #ttg.shared_memory
144+
module attributes {"ttg.num-ctas" = 16 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32} {
145+
// CHECK-LABEL: async_load_multi_cta_linear_layout
146+
tt.func public @async_load_multi_cta_linear_layout(%arg0: tensor<32x32x!tt.ptr<f32>, #linear> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>, tt.constancy = dense<[1, 1]> : tensor<2xi32>},
147+
%arg1: !ttg.memdesc<32x32xf32, #shared, #smem, mutable>) {
148+
// Skip the first cluster id because it's emitted for address calculation
149+
// CHECK: llvm.amdgcn.cluster.workgroup.id.x
150+
// CHECK: %[[CTA_ID:.*]] = {{.*}}llvm.amdgcn.cluster.workgroup.id.x
151+
// CHECK: %[[NON_FREE_BITS:.*]] = llvm.mlir.constant(-9 : i32) : i32
152+
// CHECK: %[[SHIFT_AMOUNT:.*]] = llvm.and %[[CTA_ID]], %[[NON_FREE_BITS]]
153+
// CHECK: %[[GROUP_MASK:.*]] = llvm.mlir.constant(257 : i32) : i32
154+
// CHECK: %[[CTA_MASK:.*]] = llvm.shl %[[GROUP_MASK]], %[[SHIFT_AMOUNT]]
155+
// CHECK: llvm.amdgcn.cluster.load.async.to.lds{{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[CTA_MASK]]
156+
%6 = ttg.async_copy_global_to_local %arg0, %arg1 : tensor<32x32x!tt.ptr<f32>, #linear> -> <32x32xf32, #shared, #smem, mutable>
157+
tt.return
158+
}
159+
}

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 57 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -479,12 +479,13 @@ struct DirectToLdsLoadConversionBase : public LoadStoreConversionBase {
479479
}
480480
}
481481

482-
void lowerDirectToLDSLoad(
482+
LogicalResult lowerDirectToLDSLoad(
483483
RewriterBase &rewriter, Location loc, RankedTensorType srcTy,
484484
MemDescType dstTy, SmallVector<Value> loadVals, Value llDst,
485485
Type resElemTy, unsigned vec, triton::AMD::ISAFamily isaFamily,
486486
std::function<SmallVector<Value>(RewriterBase &, Location,
487-
ArrayRef<Value>, Value, int, VectorType)>
487+
ArrayRef<Value>, Value, int, VectorType,
488+
Value)>
488489
lowerInst) const {
489490
TritonLLVMOpBuilder b(loc, rewriter);
490491
auto *ctx = rewriter.getContext();
@@ -503,10 +504,21 @@ struct DirectToLdsLoadConversionBase : public LoadStoreConversionBase {
503504
sharedLayout = triton::gpu::toLinearLayout(dstTy);
504505
}
505506
auto cvt = srcLayout.invertAndCompose(sharedLayout);
507+
if (!cvt.isTrivialOver({str_attr("block")})) {
508+
return emitError(
509+
loc,
510+
"direct to lds loads do not support non-trivial block dimension");
511+
}
506512
cvt = cvt.sublayout(
507513
{str_attr("register"), str_attr("lane"), str_attr("warp")},
508514
{str_attr("offset")});
509515

516+
Value ctaMulticastMask;
517+
if (isaFamily == ISAFamily::GFX1250) {
518+
ctaMulticastMask = LLVM::AMD::emitCtaMulticastMask(
519+
rewriter, loc, targetInfo.getClusterCTAId(rewriter, loc), srcLayout);
520+
}
521+
510522
auto smemObj =
511523
LLVM::getSharedMemoryObjectFromStruct(loc, llDst, resElemTy, rewriter);
512524
auto affineOffset = smemObj.getShmemOffset(loc, rewriter, dstTy);
@@ -557,12 +569,21 @@ struct DirectToLdsLoadConversionBase : public LoadStoreConversionBase {
557569
}
558570
return smemOffset;
559571
};
572+
573+
auto lowerInstForwardMulticastMask =
574+
[&](RewriterBase &rewriter, Location loc, ArrayRef<Value> vals,
575+
Value shmemAddr, int idx, VectorType vecTy) {
576+
return lowerInst(rewriter, loc, vals, shmemAddr, idx, vecTy,
577+
ctaMulticastMask);
578+
};
579+
560580
// If we do not support scattering (GFX9) the address should be the start
561581
// address (scalar) of the warp
562582
laneId = targetInfo.supportsDirectToLDSScattering() ? laneId : b.i32_val(0);
563583
lowerLdSt(loc, ctx, cvt, loadVals, resElemTy, smemObj.getBase(),
564584
calcPaddedOffset, affineOffset, maskSpanAffineOffset, laneId,
565-
warpId, rewriter, targetInfo, vec, lowerInst);
585+
warpId, rewriter, targetInfo, vec, lowerInstForwardMulticastMask);
586+
return success();
566587
}
567588

568589
void emitOtherStore(RewriterBase &rewriter, Location loc,
@@ -866,8 +887,8 @@ struct BufferLoadToLocalOpConversion
866887
[this, &op, &b, &bufferEmitter, &rsrcDesc, laneId = laneId, threadPred,
867888
offsetTy, otherTy, hasOther, requiresSrcPtrSwizzling](
868889
RewriterBase &rewriter, Location loc, ArrayRef<Value> loadVals,
869-
Value shmemAddr, int startIdx,
870-
VectorType vecTy) -> SmallVector<Value> {
890+
Value shmemAddr, int startIdx, VectorType vecTy,
891+
Value multicastMask) -> SmallVector<Value> {
871892
auto [offsetElem, maskElem, otherElems, swizzleLaneOffset] =
872893
unzipLoadValues(rewriter, loc, startIdx, loadVals, offsetTy, otherTy,
873894
hasOther, vecTy.getNumElements());
@@ -905,9 +926,12 @@ struct BufferLoadToLocalOpConversion
905926
return {};
906927
};
907928

908-
lowerDirectToLDSLoad(rewriter, loc, ptrType, flatDstTy, loadVals, llDst,
909-
resElemTy, vec, targetInfo.getISAFamily(),
910-
emitBufferLoadLds);
929+
auto res = lowerDirectToLDSLoad(
930+
rewriter, loc, ptrType, flatDstTy, loadVals, llDst, resElemTy, vec,
931+
targetInfo.getISAFamily(), emitBufferLoadLds);
932+
if (failed(res)) {
933+
return failure();
934+
}
911935

912936
// Drop the result token.
913937
Value zero = LLVM::ConstantOp::create(rewriter, op.getLoc(),
@@ -1001,8 +1025,8 @@ struct AsyncCopyGlobalToLocalOpConversion
10011025
[this, &op, &b, laneId = laneId, threadPred, srcPtrTy, otherTy,
10021026
hasOther, requiresSrcPtrSwizzling](
10031027
RewriterBase &rewriter, Location loc, ArrayRef<Value> loadValues,
1004-
Value shmemAddr, int startIdx,
1005-
VectorType vecTy) -> SmallVector<Value> {
1028+
Value shmemAddr, int startIdx, VectorType vecTy,
1029+
Value multicastMask) -> SmallVector<Value> {
10061030
auto [srcElem, maskElem, otherElems, swizzleLaneOffset] =
10071031
unzipLoadValues(rewriter, loc, startIdx, loadValues, srcPtrTy,
10081032
otherTy, hasOther, vecTy.getNumElements());
@@ -1019,7 +1043,7 @@ struct AsyncCopyGlobalToLocalOpConversion
10191043
auto [loadBlock, afterLoadBlock] = emitBranch(rewriter, loc, cond);
10201044

10211045
emitAsyncLoad(rewriter, loc, targetInfo, vecBits, srcElem, shmemAddr,
1022-
op.getCache());
1046+
op.getCache(), multicastMask);
10231047

10241048
rewriter.setInsertionPointToStart(afterLoadBlock);
10251049

@@ -1032,9 +1056,12 @@ struct AsyncCopyGlobalToLocalOpConversion
10321056
return {};
10331057
};
10341058

1035-
lowerDirectToLDSLoad(rewriter, loc, srcTy, flatDstTy, loadVals, llDst,
1036-
resElemTy, vec, targetInfo.getISAFamily(),
1037-
emitGlobalLoadLds);
1059+
auto res = lowerDirectToLDSLoad(
1060+
rewriter, loc, srcTy, flatDstTy, loadVals, llDst, resElemTy, vec,
1061+
targetInfo.getISAFamily(), emitGlobalLoadLds);
1062+
if (failed(res)) {
1063+
return failure();
1064+
}
10381065

10391066
// Drop the result token.
10401067
Value zero = LLVM::ConstantOp::create(rewriter, op.getLoc(),
@@ -1046,7 +1073,8 @@ struct AsyncCopyGlobalToLocalOpConversion
10461073

10471074
void emitAsyncLoad(RewriterBase &rewriter, Location loc,
10481075
AMD::TargetInfo targetInfo, int vecBits, Value srcPtr,
1049-
Value shmemAddr, triton::CacheModifier cacheMod) const {
1076+
Value shmemAddr, triton::CacheModifier cacheMod,
1077+
Value multicastMask) const {
10501078
auto b = TritonLLVMOpBuilder(loc, rewriter);
10511079
int32_t cacheModifiers =
10521080
mlir::LLVM::AMD::getCtrlBitsForCacheModifierOnTarget(
@@ -1063,11 +1091,20 @@ struct AsyncCopyGlobalToLocalOpConversion
10631091
if (cacheMod != triton::CacheModifier::NONE) {
10641092
emitRemark(loc) << "cache modifiers not yet implemented on gfx1250";
10651093
}
1066-
std::string intrinsic =
1067-
"llvm.amdgcn.global.load.async.to.lds.b" + std::to_string(vecBits);
1068-
auto globalLoadLdsOp = LLVM::createLLVMIntrinsicCallOp(
1069-
rewriter, loc, intrinsic, {},
1070-
{srcPtr, shmemAddr, b.i32_val(0), b.i32_val(cacheModifiers)});
1094+
if (multicastMask) {
1095+
std::string intrinsic =
1096+
"llvm.amdgcn.cluster.load.async.to.lds.b" + std::to_string(vecBits);
1097+
auto globalLoadLdsOp = LLVM::createLLVMIntrinsicCallOp(
1098+
rewriter, loc, intrinsic, {},
1099+
{srcPtr, shmemAddr, b.i32_val(0), b.i32_val(cacheModifiers),
1100+
multicastMask});
1101+
} else {
1102+
std::string intrinsic =
1103+
"llvm.amdgcn.global.load.async.to.lds.b" + std::to_string(vecBits);
1104+
auto globalLoadLdsOp = LLVM::createLLVMIntrinsicCallOp(
1105+
rewriter, loc, intrinsic, {},
1106+
{srcPtr, shmemAddr, b.i32_val(0), b.i32_val(cacheModifiers)});
1107+
}
10711108
}
10721109
}
10731110
};

0 commit comments

Comments
 (0)