Skip to content

Commit 05fc47f

Browse files
authored
[Membar] Membar pass for clusters (#9318)
Stacked PRs: * #9327 * __->__#9318 --- --- --- ### [Membar] Membar pass for clusters The main invariant here is that: Membar for CTAs only synchronises CTAs when their buffers did not alias in the ttgir, but they alias after the Allocation pass In other words, in Gluon, the user is in charge of manually synchronising the buffers they declare. For now, we always emit a full cluster barrier. We can improve this in the future by emitting `mbarrier`s that just synchronise subsets of the CTAs. For that we would need to be a bit more clever, as we would need to allocate some `mbarrier`s but the Allocation pass has already run... We add a number of test cases with comments on which of them are expected and which can be improved.
1 parent c155e4a commit 05fc47f

15 files changed

Lines changed: 861 additions & 31 deletions

File tree

include/triton/Analysis/Allocation.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
#include "llvm/ADT/DenseMap.h"
66
#include "llvm/ADT/MapVector.h"
77
#include "llvm/ADT/SetVector.h"
8-
#include "llvm/Support/raw_ostream.h"
98

109
#include <limits>
1110

@@ -145,6 +144,11 @@ class Allocation {
145144
return bufferSet.at(bufferId).kind == BufferT::BufferKind::Virtual;
146145
}
147146

147+
/// Returns if the given buffer is an explicit buffer.
148+
bool isExplicitBuffer(BufferId bufferId) const {
149+
return bufferSet.at(bufferId).kind == BufferT::BufferKind::Explicit;
150+
}
151+
148152
/// Returns the size of total shared memory allocated
149153
size_t getSharedMemorySize() const { return sharedMemorySize; }
150154

include/triton/Analysis/Membar.h

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,31 +4,41 @@
44
#include "Allocation.h"
55

66
#include "llvm/Support/raw_ostream.h"
7+
#include <functional>
78
#include <set>
89
#include <tuple>
910

1011
namespace mlir {
1112

1213
class OpBuilder;
14+
struct AllocationSlice;
1315

1416
/// Callback to allow backend to provide more information on whether a barrier
1517
/// is needed between two operations. Even though two operations access the same
1618
/// shared memory they may not require a barrier in between them.
1719
using MembarFilterFn =
18-
std::function<bool(Operation *, Operation *, Allocation *)>;
20+
std::function<bool(Operation *, Operation *, bool /*lhsIsRead*/,
21+
bool /*rhsIsRead*/, Allocation *)>;
22+
23+
/// Slice-level filter to allow backends to ignore specific aliasing cases.
24+
using MembarSliceFilterFn =
25+
std::function<bool(const AllocationSlice &, const AllocationSlice &,
26+
bool /*lhsIsRead*/, bool /*rhsIsRead*/, Allocation *)>;
1927

2028
// Represents the access to a slice of an allocation
2129
// It contains information both on physical memory (the interval) and a
2230
// logical view on it (layout, subslice offsets and shape for the access)
2331
struct AllocationSlice {
2432
public:
2533
// Create allocation slice from a value, collecting subslice offsets
26-
AllocationSlice(Value value, Interval<size_t> allocationInterval);
34+
AllocationSlice(Value value, Interval<size_t> allocationInterval,
35+
Allocation::BufferId bufferId);
2736

2837
// Builder for accesses that represent accesses to the whole
2938
// allocation (scratch buffers, ArriveBarrierOp, ..)
3039
AllocationSlice(Interval<size_t> interval)
31-
: allocationInterval(interval), accessTy(nullptr) {}
40+
: allocationInterval(interval), accessTy(nullptr),
41+
bufferId(Allocation::InvalidBufferId) {}
3242

3343
bool operator<(const AllocationSlice &other) const {
3444
return asTuple() < other.asTuple();
@@ -43,19 +53,25 @@ struct AllocationSlice {
4353
// Returns true if it can't prove the AllocationSlices are disjoint.
4454
bool intersects(const AllocationSlice &other) const;
4555

56+
Allocation::BufferId getBufferId() const { return bufferId; }
57+
4658
void print(raw_ostream &os) const;
4759

4860
private:
49-
std::tuple<Interval<size_t>, const void *, llvm::ArrayRef<int64_t>>
61+
std::tuple<Interval<size_t>, Allocation::BufferId, const void *,
62+
llvm::ArrayRef<int64_t>>
5063
asTuple() const {
51-
return {allocationInterval, accessTy.getAsOpaquePointer(), subsliceOffsets};
64+
return {allocationInterval, bufferId, accessTy.getAsOpaquePointer(),
65+
subsliceOffsets};
5266
}
5367
// Offsets from subslice. Empty when offsets are unknown
5468
SmallVector<int64_t> subsliceOffsets;
5569
// The allocated interval for this buffer
5670
Interval<size_t> allocationInterval;
5771
// Type of the memory descriptor for this access
5872
triton::gpu::MemDescType accessTy;
73+
// Buffer id for partial sync on wait_barrier deps.
74+
Allocation::BufferId bufferId;
5975
};
6076

6177
struct BlockInfo {
@@ -103,15 +119,19 @@ struct BlockInfo {
103119

104120
/// Returns true if Slices in two BlockInfo objects are intersected.
105121
bool isIntersected(const BlockInfo &other, MembarFilterFn filter,
106-
Allocation *allocation) const {
107-
return /*RAW*/ isIntersected(syncWriteSlices, other.syncReadSlices, filter,
108-
allocation) ||
122+
Allocation *allocation,
123+
MembarSliceFilterFn sliceFilter = nullptr) const {
124+
return /*RAW*/ isIntersected(syncWriteSlices, other.syncReadSlices,
125+
/*lhsIsRead=*/false, /*rhsIsRead=*/true,
126+
filter, sliceFilter, allocation) ||
109127
/*WAR*/
110-
isIntersected(syncReadSlices, other.syncWriteSlices, filter,
111-
allocation) ||
128+
isIntersected(syncReadSlices, other.syncWriteSlices,
129+
/*lhsIsRead=*/true, /*rhsIsRead=*/false, filter,
130+
sliceFilter, allocation) ||
112131
/*WAW*/
113-
isIntersected(syncWriteSlices, other.syncWriteSlices, filter,
114-
allocation);
132+
isIntersected(syncWriteSlices, other.syncWriteSlices,
133+
/*lhsIsRead=*/false, /*rhsIsRead=*/false, filter,
134+
sliceFilter, allocation);
115135
}
116136

117137
/// Clears the slices because a barrier is inserted.
@@ -130,14 +150,19 @@ struct BlockInfo {
130150

131151
private:
132152
bool isIntersected(const SliceMapT &lhsSlices, const SliceMapT &rhsSlices,
133-
MembarFilterFn filter, Allocation *allocation) const {
153+
bool lhsIsRead, bool rhsIsRead, MembarFilterFn filter,
154+
MembarSliceFilterFn sliceFilter,
155+
Allocation *allocation) const {
134156
for (auto &lhs : lhsSlices)
135157
for (auto &rhs : rhsSlices)
136158
if (lhs.first.intersects(rhs.first))
137-
for (auto lhsOp : lhs.second)
138-
for (auto rhsOp : rhs.second)
139-
if (!filter || !filter(lhsOp, rhsOp, allocation))
140-
return true;
159+
if (!sliceFilter || !sliceFilter(lhs.first, rhs.first, lhsIsRead,
160+
rhsIsRead, allocation))
161+
for (auto lhsOp : lhs.second)
162+
for (auto rhsOp : rhs.second)
163+
if (!filter ||
164+
!filter(lhsOp, rhsOp, lhsIsRead, rhsIsRead, allocation))
165+
return true;
141166
return false;
142167
}
143168
};
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#ifndef TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_CLUSTERBARRIERINSERTION_H_
2+
#define TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_CLUSTERBARRIERINSERTION_H_
3+
4+
#include "triton/Analysis/Allocation.h"
5+
6+
namespace mlir {
7+
namespace triton {
8+
namespace nvidia_gpu {
9+
10+
/// Inserts cluster barriers (cluster_arrive + cluster_wait) using the provided
11+
/// shared-memory allocation analysis.
12+
void runClusterBarrierInsertion(ModuleAllocation &moduleAllocation,
13+
int computeCapability);
14+
15+
} // namespace nvidia_gpu
16+
} // namespace triton
17+
} // namespace mlir
18+
19+
#endif // TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_CLUSTERBARRIERINSERTION_H_

lib/Analysis/Membar.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010
namespace mlir {
1111

1212
AllocationSlice::AllocationSlice(Value value,
13-
Interval<size_t> allocationInterval)
14-
: allocationInterval(allocationInterval) {
13+
Interval<size_t> allocationInterval,
14+
Allocation::BufferId bufferId)
15+
: allocationInterval(allocationInterval), bufferId(bufferId) {
1516
auto accessTy = cast<triton::gpu::MemDescType>(value.getType());
1617
this->accessTy = accessTy;
1718

@@ -69,6 +70,9 @@ void AllocationSlice::print(raw_ostream &os) const {
6970
os << "interval=[" << allocationInterval.start() << ","
7071
<< allocationInterval.end() << ")";
7172

73+
if (bufferId != Allocation::InvalidBufferId)
74+
os << " buffer=" << bufferId;
75+
7276
os << " offsets=[";
7377
if (!subsliceOffsets.empty()) {
7478
llvm::interleaveComma(subsliceOffsets, os);
@@ -244,6 +248,8 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
244248
auto containsLocalBarrier = [](Operation *op) {
245249
if (isa<gpu::BarrierOp>(op))
246250
return true;
251+
if (isa<triton::nvidia_gpu::ClusterWaitOp>(op))
252+
return true;
247253
if (isa<triton::gpu::WarpSpecializePartitionsOp>(op))
248254
return true;
249255
if (auto barrier = dyn_cast<triton::gpu::BarrierOp>(op))
@@ -287,7 +293,7 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
287293
for (auto bufferId : allocation->getAllBufferIdsWithAliases(value)) {
288294
if (bufferId != Allocation::InvalidBufferId) {
289295
auto interval = allocation->getAllocatedInterval(bufferId);
290-
auto slice = AllocationSlice(value, interval);
296+
auto slice = AllocationSlice(value, interval, bufferId);
291297

292298
if (isa<MemoryEffects::Write>(effectInstance.getEffect()))
293299
curBlockInfo.syncWriteSlices[slice].insert(op);

lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_triton_library(TritonNvidiaGPUTransforms
2+
ClusterBarrierInsertion.cpp
23
CheckMatmulTwoCTAs.cpp
34
FenceInsertion.cpp
45
InterleaveTMem.cpp
Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
#include "triton/Dialect/TritonNvidiaGPU/Transforms/ClusterBarrierInsertion.h"
2+
#include "triton/Analysis/Allocation.h"
3+
#include "triton/Analysis/Membar.h"
4+
#include "triton/Analysis/Utility.h"
5+
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
6+
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
7+
8+
#include "mlir/Interfaces/FunctionInterfaces.h"
9+
#include "mlir/Interfaces/SideEffectInterfaces.h"
10+
#include "llvm/ADT/SmallVector.h"
11+
#include "llvm/Support/ErrorHandling.h"
12+
13+
namespace mlir {
14+
namespace triton {
15+
namespace nvidia_gpu {
16+
17+
namespace {
18+
19+
namespace ttg = mlir::triton::gpu;
20+
namespace ttng = mlir::triton::nvidia_gpu;
21+
22+
/// Returns true if `op` performs a shared-memory access that is distributed
/// across CTAs of a cluster. `isRead` selects which side of the access is
/// being queried; for layout conversions and reductions only the read side
/// can cross CTAs.
static bool isDistributedMultiCTAOp(Operation *op, bool isRead) {
  // Layout conversions read remote CTAs' data when the minimal conversion
  // has a "block" input dimension; their writes stay CTA-local.
  if (auto layoutCvt = dyn_cast<ttg::ConvertLayoutOp>(op)) {
    if (!isRead)
      return false;
    auto minimal = minimalCvtLayout(layoutCvt.getSrc().getType(),
                                    layoutCvt.getType());
    auto blockDim = StringAttr::get(op->getContext(), "block");
    return minimal.hasInDim(blockDim);
  }
  // Reductions cross CTAs only when the reduced axis is split among them.
  if (auto red = dyn_cast<triton::ReduceOp>(op)) {
    if (!isRead)
      return false;
    auto inputTy = red.getInputTypes()[0];
    auto ctaSplit = ttg::getCTASplitNum(inputTy.getEncoding());
    return ctaSplit[red.getAxis()] > 1;
  }
  if (auto mmaOp = dyn_cast<ttng::TCGen5MMAOp>(op))
    return mmaOp.getTwoCtas();
  if (isa<ttng::TCGen5MMAScaledOp>(op)) {
    // TODO: Change when we support scaled MMA with 2CTAs
    assert(!ttng::getModuleTwoCTAs(op->getParentOfType<ModuleOp>()) &&
           "Scaled MMA with 2CTAs not supported");
    return false;
  }
  if (auto tmaCopy = dyn_cast<ttng::AsyncTMACopyGlobalToLocalOp>(op))
    return tmaCopy.getMulticast();
  return false;
}
51+
52+
static bool isPreAllocAliasSliceFilter(const AllocationSlice &lhsSlice,
53+
const AllocationSlice &rhsSlice,
54+
bool /*lhsIsRead*/, bool /*rhsIsRead*/,
55+
Allocation *allocation) {
56+
auto bufferId = lhsSlice.getBufferId();
57+
return bufferId != Allocation::InvalidBufferId &&
58+
bufferId == rhsSlice.getBufferId() &&
59+
allocation->isExplicitBuffer(bufferId);
60+
}
61+
62+
class ClusterBarrierAnalysis : public MembarOrFenceAnalysis {
63+
public:
64+
ClusterBarrierAnalysis() = default;
65+
explicit ClusterBarrierAnalysis(Allocation *allocation, MembarFilterFn filter)
66+
: MembarOrFenceAnalysis(allocation, filter) {}
67+
68+
private:
69+
void update(Operation *op, BlockInfo *blockInfo,
70+
FuncBlockInfoMapT *funcBlockInfoMap, OpBuilder *builder) override;
71+
72+
void insertClusterBarrier(Operation *op, OpBuilder *builder);
73+
};
74+
75+
void ClusterBarrierAnalysis::insertClusterBarrier(Operation *op,
                                                  OpBuilder *builder) {
  OpBuilder::InsertionGuard guard(*builder);
  Location loc = op->getLoc();
  // Full cluster barrier: every CTA arrives, then waits for all others.
  // For now this is always a whole-cluster sync; partial (mbarrier-based)
  // synchronization is a possible future refinement.
  ttng::ClusterArriveOp::create(*builder, loc, /*relaxed=*/false);
  ttng::ClusterWaitOp::create(*builder, loc);
}
81+
82+
void ClusterBarrierAnalysis::update(Operation *op, BlockInfo *blockInfo,
                                    FuncBlockInfoMapT *funcBlockInfoMap,
                                    OpBuilder *builder) {
  // An explicit cluster wait synchronizes the whole cluster, so all pending
  // read/write dependencies are resolved at this point.
  if (isa<ttng::ClusterWaitOp>(op)) {
    blockInfo->sync();
    return;
  }

  BlockInfo curBlockInfo;
  auto scratchBufferId = Allocation::InvalidBufferId;
  if (isa<triton::CallOp>(op)) {
    // Calls are summarized by the callee's accumulated block info.
    // cast<> (not dyn_cast<>) since triton::CallOp always implements
    // CallOpInterface; the previous dyn_cast result was dereferenced
    // unconditionally anyway.
    auto callOpInterface = cast<CallOpInterface>(op);
    if (auto callee =
            dyn_cast<FunctionOpInterface>(callOpInterface.resolveCallable()))
      curBlockInfo = funcBlockInfoMap->lookup(callee);
  } else {
    // Record this op's shared-memory reads and writes, one slice per aliased
    // buffer, so hazards can be detected against earlier accesses.
    if (auto memEffects = dyn_cast<MemoryEffectOpInterface>(op)) {
      SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>>
          effectInstances;
      memEffects.getEffects(effectInstances);
      // const& avoids copying each EffectInstance per iteration.
      for (const auto &effectInstance : effectInstances) {
        auto value = effectInstance.getValue();
        if (!value)
          continue;
        for (auto bufferId : allocation->getBufferIds(value)) {
          if (bufferId == Allocation::InvalidBufferId)
            continue;
          auto interval = allocation->getAllocatedInterval(bufferId);
          AllocationSlice slice(value, interval, bufferId);
          if (isa<MemoryEffects::Write>(effectInstance.getEffect()))
            curBlockInfo.syncWriteSlices[slice].insert(op);
          else if (isa<MemoryEffects::Read>(effectInstance.getEffect()))
            curBlockInfo.syncReadSlices[slice].insert(op);
        }
      }
    }
    scratchBufferId = allocation->getBufferId(op);
  }

  // Scratch buffer operations consist of a series of shared memory operations
  // starting from a shared memory write, followed by a series of shared memory
  // read/write operations, and ending with a shared memory read, i.e., shared
  // memory write -> ... -> shared memory read.
  if (scratchBufferId != Allocation::InvalidBufferId) {
    if (!curBlockInfo.syncReadSlices.empty() ||
        !curBlockInfo.syncWriteSlices.empty()) {
      llvm::report_fatal_error(
          "scratch buffer operations should not have any shared memory "
          "dependencies");
    }

    // Model the scratch op first as a write (its initial access) when
    // checking for hazards against prior accesses...
    auto interval = allocation->getAllocatedInterval(scratchBufferId);
    AllocationSlice scratchSlice(interval);
    curBlockInfo.syncWriteSlices[scratchSlice].insert(op);

    bool insertClusterBarrierNeeded = blockInfo->isIntersected(
        curBlockInfo, filter, allocation, isPreAllocAliasSliceFilter);
    if (insertClusterBarrierNeeded) {
      builder->setInsertionPoint(op);
      insertClusterBarrier(op, builder);
    }

    // Clear prior distributed dependencies if we have inserted a cluster
    // barrier, or if the scratch op itself performs a cluster-level sync.
    bool hasClusterSync = isDistributedMultiCTAOp(op, /*isRead=*/true);
    if (insertClusterBarrierNeeded || hasClusterSync)
      blockInfo->sync();

    // ...and then as a read (its final access) for subsequent hazard checks.
    curBlockInfo.syncReadSlices[scratchSlice].insert(op);
  } else if (blockInfo->isIntersected(curBlockInfo, filter, allocation,
                                      isPreAllocAliasSliceFilter)) {
    // Cross-CTA hazard against a prior access: emit a full cluster barrier
    // right before this op and forget the now-synchronized accesses.
    builder->setInsertionPoint(op);
    insertClusterBarrier(op, builder);
    blockInfo->sync();
  }

  // Propagate this op's accesses into the running block state.
  blockInfo->join(curBlockInfo);
}
159+
160+
} // namespace
161+
162+
/// Inserts cluster barriers (cluster_arrive + cluster_wait) using the
/// provided shared-memory allocation analysis.
///
/// \param moduleAllocation per-function shared-memory allocation results.
/// \param computeCapability target SM version; the pass is a no-op below
///        sm_90 and for single-CTA clusters.
void runClusterBarrierInsertion(ModuleAllocation &moduleAllocation,
                                int computeCapability) {
  ModuleOp mod = moduleAllocation.getModuleOp();
  // Clusters (and cluster barriers) require Hopper (sm_90) or newer.
  if (computeCapability < 90)
    return;
  // A single CTA per cluster means there is nothing to synchronize.
  if (ttg::TritonGPUDialect::getNumCTAs(mod) == 1)
    return;

  MembarFilterFn filterFn = [](Operation *lhs, Operation *rhs, bool lhsIsRead,
                               bool rhsIsRead, Allocation * /*allocation*/) {
    // Filter out (return true for) op pairs where neither side touches
    // distributed shared memory. Whether the aliasing was already present in
    // TTGIR is handled per-allocation slice.
    return !isDistributedMultiCTAOp(lhs, lhsIsRead) &&
           !isDistributedMultiCTAOp(rhs, rhsIsRead);
  };

  ModuleMembarOrFenceAnalysis<ClusterBarrierAnalysis> analysis(
      &moduleAllocation, filterFn);
  analysis.run();
}
185+
186+
} // namespace nvidia_gpu
187+
} // namespace triton
188+
} // namespace mlir

0 commit comments

Comments
 (0)