Skip to content

Commit 6a6cf6e

Browse files
committed
[Membar] Membar pass for clusters
The main invariant here is that Membar for CTAs only synchronises CTAs when their buffers did not alias in the ttgir but do alias after the Allocation pass. In other words, in Gluon, the user is in charge of manually synchronising the buffers they declare. For now, we always emit a full cluster barrier. We can improve this in the future by emitting `mbarrier`s that synchronise just subsets of the CTAs. For that we would need to be a bit more clever, as we would need to allocate some `mbarrier`s, but the Allocation pass has already run... We add a number of test cases with comments on which of them are expected and which can be improved. stack-info: PR: #9318, branch: lezcano/stack/10
1 parent a8d2f7c commit 6a6cf6e

15 files changed

Lines changed: 866 additions & 31 deletions

File tree

include/triton/Analysis/Allocation.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
#include "llvm/ADT/DenseMap.h"
66
#include "llvm/ADT/MapVector.h"
77
#include "llvm/ADT/SetVector.h"
8-
#include "llvm/Support/raw_ostream.h"
98

109
#include <limits>
1110

@@ -145,6 +144,16 @@ class Allocation {
145144
return bufferSet.at(bufferId).kind == BufferT::BufferKind::Virtual;
146145
}
147146

147+
/// Returns if the given buffer is a scratch buffer.
148+
bool isScratchBuffer(BufferId bufferId) const {
149+
return bufferSet.at(bufferId).kind == BufferT::BufferKind::Scratch;
150+
}
151+
152+
/// Returns if the given buffer is an explicit buffer.
153+
bool isExplicitBuffer(BufferId bufferId) const {
154+
return bufferSet.at(bufferId).kind == BufferT::BufferKind::Explicit;
155+
}
156+
148157
/// Returns the size of total shared memory allocated
149158
size_t getSharedMemorySize() const { return sharedMemorySize; }
150159

include/triton/Analysis/Membar.h

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,31 +4,41 @@
44
#include "Allocation.h"
55

66
#include "llvm/Support/raw_ostream.h"
7+
#include <functional>
78
#include <set>
89
#include <tuple>
910

1011
namespace mlir {
1112

1213
class OpBuilder;
14+
struct AllocationSlice;
1315

1416
/// Callback to allow backend to provide more information on whether a barrier
1517
/// is needed between two operations. Even though two operations access the same
1618
/// shared memory they may not require a barrier in between them.
1719
using MembarFilterFn =
18-
std::function<bool(Operation *, Operation *, Allocation *)>;
20+
std::function<bool(Operation *, Operation *, bool /*lhsIsRead*/,
21+
bool /*rhsIsRead*/, Allocation *)>;
22+
23+
/// Slice-level filter to allow backends to ignore specific aliasing cases.
24+
using MembarSliceFilterFn =
25+
std::function<bool(const AllocationSlice &, const AllocationSlice &,
26+
bool /*lhsIsRead*/, bool /*rhsIsRead*/, Allocation *)>;
1927

2028
// Represents the access to a slice of an allocation
2129
// It contains information both on physical memory (the interval) and a
2230
// logical view on it (layout, subslice offsets and shape for the access)
2331
struct AllocationSlice {
2432
public:
2533
// Create allocation slice from a value, collecting subslice offsets
26-
AllocationSlice(Value value, Interval<size_t> allocationInterval);
34+
AllocationSlice(Value value, Interval<size_t> allocationInterval,
35+
Allocation::BufferId bufferId);
2736

2837
// Builder for accesses that represent accesses to the whole
2938
// allocation (scratch buffers, ArriveBarrierOp, ..)
3039
AllocationSlice(Interval<size_t> interval)
31-
: allocationInterval(interval), accessTy(nullptr) {}
40+
: allocationInterval(interval), accessTy(nullptr),
41+
bufferId(Allocation::InvalidBufferId) {}
3242

3343
bool operator<(const AllocationSlice &other) const {
3444
return asTuple() < other.asTuple();
@@ -43,19 +53,25 @@ struct AllocationSlice {
4353
// Returns true if it can't prove the AllocationSlices are disjoint.
4454
bool intersects(const AllocationSlice &other) const;
4555

56+
Allocation::BufferId getBufferId() const { return bufferId; }
57+
4658
void print(raw_ostream &os) const;
4759

4860
private:
49-
std::tuple<Interval<size_t>, const void *, llvm::ArrayRef<int64_t>>
61+
std::tuple<Interval<size_t>, Allocation::BufferId, const void *,
62+
llvm::ArrayRef<int64_t>>
5063
asTuple() const {
51-
return {allocationInterval, accessTy.getAsOpaquePointer(), subsliceOffsets};
64+
return {allocationInterval, bufferId, accessTy.getAsOpaquePointer(),
65+
subsliceOffsets};
5266
}
5367
// Offsets from subslice. Empty when offsets are unknown
5468
SmallVector<int64_t> subsliceOffsets;
5569
// The allocated interval for this buffer
5670
Interval<size_t> allocationInterval;
5771
// Type of the memory descriptor for this access
5872
triton::gpu::MemDescType accessTy;
73+
// Buffer id for partial sync on wait_barrier deps.
74+
Allocation::BufferId bufferId;
5975
};
6076

6177
struct BlockInfo {
@@ -103,15 +119,19 @@ struct BlockInfo {
103119

104120
/// Returns true if Slices in two BlockInfo objects are intersected.
105121
bool isIntersected(const BlockInfo &other, MembarFilterFn filter,
106-
Allocation *allocation) const {
107-
return /*RAW*/ isIntersected(syncWriteSlices, other.syncReadSlices, filter,
108-
allocation) ||
122+
Allocation *allocation,
123+
MembarSliceFilterFn sliceFilter = nullptr) const {
124+
return /*RAW*/ isIntersected(syncWriteSlices, other.syncReadSlices,
125+
/*lhsIsRead=*/false, /*rhsIsRead=*/true,
126+
filter, sliceFilter, allocation) ||
109127
/*WAR*/
110-
isIntersected(syncReadSlices, other.syncWriteSlices, filter,
111-
allocation) ||
128+
isIntersected(syncReadSlices, other.syncWriteSlices,
129+
/*lhsIsRead=*/true, /*rhsIsRead=*/false, filter,
130+
sliceFilter, allocation) ||
112131
/*WAW*/
113-
isIntersected(syncWriteSlices, other.syncWriteSlices, filter,
114-
allocation);
132+
isIntersected(syncWriteSlices, other.syncWriteSlices,
133+
/*lhsIsRead=*/false, /*rhsIsRead=*/false, filter,
134+
sliceFilter, allocation);
115135
}
116136

117137
/// Clears the slices because a barrier is inserted.
@@ -130,14 +150,19 @@ struct BlockInfo {
130150

131151
private:
132152
bool isIntersected(const SliceMapT &lhsSlices, const SliceMapT &rhsSlices,
133-
MembarFilterFn filter, Allocation *allocation) const {
153+
bool lhsIsRead, bool rhsIsRead, MembarFilterFn filter,
154+
MembarSliceFilterFn sliceFilter,
155+
Allocation *allocation) const {
134156
for (auto &lhs : lhsSlices)
135157
for (auto &rhs : rhsSlices)
136158
if (lhs.first.intersects(rhs.first))
137-
for (auto lhsOp : lhs.second)
138-
for (auto rhsOp : rhs.second)
139-
if (!filter || !filter(lhsOp, rhsOp, allocation))
140-
return true;
159+
if (!sliceFilter || !sliceFilter(lhs.first, rhs.first, lhsIsRead,
160+
rhsIsRead, allocation))
161+
for (auto lhsOp : lhs.second)
162+
for (auto rhsOp : rhs.second)
163+
if (!filter ||
164+
!filter(lhsOp, rhsOp, lhsIsRead, rhsIsRead, allocation))
165+
return true;
141166
return false;
142167
}
143168
};
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#ifndef TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_CLUSTERBARRIERINSERTION_H_
2+
#define TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_CLUSTERBARRIERINSERTION_H_
3+
4+
#include "triton/Analysis/Allocation.h"
5+
6+
namespace mlir {
7+
namespace triton {
8+
namespace nvidia_gpu {
9+
10+
/// Inserts cluster barriers (cluster_arrive + cluster_wait) using the provided
11+
/// shared-memory allocation analysis.
12+
void runClusterBarrierInsertion(ModuleAllocation &moduleAllocation,
13+
int computeCapability);
14+
15+
} // namespace nvidia_gpu
16+
} // namespace triton
17+
} // namespace mlir
18+
19+
#endif // TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_CLUSTERBARRIERINSERTION_H_

lib/Analysis/Membar.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010
namespace mlir {
1111

1212
AllocationSlice::AllocationSlice(Value value,
13-
Interval<size_t> allocationInterval)
14-
: allocationInterval(allocationInterval) {
13+
Interval<size_t> allocationInterval,
14+
Allocation::BufferId bufferId)
15+
: allocationInterval(allocationInterval), bufferId(bufferId) {
1516
auto accessTy = cast<triton::gpu::MemDescType>(value.getType());
1617
this->accessTy = accessTy;
1718

@@ -69,6 +70,9 @@ void AllocationSlice::print(raw_ostream &os) const {
6970
os << "interval=[" << allocationInterval.start() << ","
7071
<< allocationInterval.end() << ")";
7172

73+
if (bufferId != Allocation::InvalidBufferId)
74+
os << " buffer=" << bufferId;
75+
7276
os << " offsets=[";
7377
if (!subsliceOffsets.empty()) {
7478
llvm::interleaveComma(subsliceOffsets, os);
@@ -244,6 +248,8 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
244248
auto containsLocalBarrier = [](Operation *op) {
245249
if (isa<gpu::BarrierOp>(op))
246250
return true;
251+
if (isa<triton::nvidia_gpu::ClusterWaitOp>(op))
252+
return true;
247253
if (isa<triton::gpu::WarpSpecializePartitionsOp>(op))
248254
return true;
249255
if (auto barrier = dyn_cast<triton::gpu::BarrierOp>(op))
@@ -287,7 +293,7 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
287293
for (auto bufferId : allocation->getAllBufferIdsWithAliases(value)) {
288294
if (bufferId != Allocation::InvalidBufferId) {
289295
auto interval = allocation->getAllocatedInterval(bufferId);
290-
auto slice = AllocationSlice(value, interval);
296+
auto slice = AllocationSlice(value, interval, bufferId);
291297

292298
if (isa<MemoryEffects::Write>(effectInstance.getEffect()))
293299
curBlockInfo.syncWriteSlices[slice].insert(op);

lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_triton_library(TritonNvidiaGPUTransforms
2+
ClusterBarrierInsertion.cpp
23
CheckMatmulTwoCTAs.cpp
34
FenceInsertion.cpp
45
InterleaveTMem.cpp

0 commit comments

Comments
 (0)