Skip to content

Commit a0f063c

Browse files
committed
[Membar] Fix non-trivial function smem offsets
Codex rightly identified that we were not considering the offsets of functions in our membar analysis (see the comment on #9318). Codex then went on, fixed it, and added a regression test. stack-info: PR: #9327, branch: lezcano/stack/11
1 parent 6a6cf6e commit a0f063c

4 files changed

Lines changed: 85 additions & 4 deletions

File tree

include/triton/Analysis/Membar.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,16 @@ struct AllocationSlice {
5555

5656
Allocation::BufferId getBufferId() const { return bufferId; }
5757

58+
AllocationSlice translated(size_t offset,
59+
bool invalidateBufferId = false) const {
60+
AllocationSlice shifted = *this;
61+
shifted.allocationInterval = Interval<size_t>(
62+
allocationInterval.start() + offset, allocationInterval.end() + offset);
63+
if (invalidateBufferId)
64+
shifted.bufferId = Allocation::InvalidBufferId;
65+
return shifted;
66+
}
67+
5868
void print(raw_ostream &os) const;
5969

6070
private:
@@ -167,6 +177,26 @@ struct BlockInfo {
167177
}
168178
};
169179

180+
/// Rebase a callee's BlockInfo onto a call site.
///
/// Each pending read/write slice recorded for the callee is shifted by
/// `callOffset` (the call's shared-memory offset inside the caller) and its
/// buffer id is invalidated, since callee buffer ids do not identify buffers
/// in the caller's allocation. The operation sets attached to each slice are
/// merged into the resulting maps.
inline BlockInfo translateBlockInfoToCallsite(const BlockInfo &calleeBlockInfo,
                                              size_t callOffset) {
  BlockInfo rebasedInfo;
  auto rebase = [callOffset](const BlockInfo::SliceMapT &src,
                             BlockInfo::SliceMapT &dst) {
    for (const auto &[slice, ops] : src) {
      auto rebasedSlice = slice.translated(callOffset,
                                           /*invalidateBufferId=*/true);
      dst[rebasedSlice].insert(ops.begin(), ops.end());
    }
  };
  rebase(calleeBlockInfo.syncReadSlices, rebasedInfo.syncReadSlices);
  rebase(calleeBlockInfo.syncWriteSlices, rebasedInfo.syncWriteSlices);
  return rebasedInfo;
}
199+
170200
//===----------------------------------------------------------------------===//
171201
// Shared Memory Barrier Analysis
172202
//===----------------------------------------------------------------------===//

lib/Analysis/Membar.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -279,8 +279,14 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
279279
// Inter-function dependencies
280280
auto callOpInterface = dyn_cast<CallOpInterface>(op);
281281
if (auto callee =
282-
dyn_cast<FunctionOpInterface>(callOpInterface.resolveCallable()))
283-
curBlockInfo = funcBlockInfoMap->lookup(callee);
282+
dyn_cast<FunctionOpInterface>(callOpInterface.resolveCallable())) {
283+
auto calleeBlockInfo = funcBlockInfoMap->lookup(callee);
284+
auto callBufferId = allocation->getBufferId(op);
285+
size_t callOffset = 0;
286+
if (callBufferId != Allocation::InvalidBufferId)
287+
callOffset = allocation->getAllocatedInterval(callBufferId).start();
288+
curBlockInfo = translateBlockInfoToCallsite(calleeBlockInfo, callOffset);
289+
}
284290
} else {
285291
// Intra-function dependencies
286292
if (auto memoryEffectOpInterface = dyn_cast<MemoryEffectOpInterface>(op)) {

lib/Dialect/TritonNvidiaGPU/Transforms/ClusterBarrierInsertion.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,14 @@ void ClusterBarrierAnalysis::update(Operation *op, BlockInfo *blockInfo,
9292
if (isa<triton::CallOp>(op)) {
9393
auto callOpInterface = dyn_cast<CallOpInterface>(op);
9494
if (auto callee =
95-
dyn_cast<FunctionOpInterface>(callOpInterface.resolveCallable()))
96-
curBlockInfo = funcBlockInfoMap->lookup(callee);
95+
dyn_cast<FunctionOpInterface>(callOpInterface.resolveCallable())) {
96+
auto calleeBlockInfo = funcBlockInfoMap->lookup(callee);
97+
auto callBufferId = allocation->getBufferId(op);
98+
size_t callOffset = 0;
99+
if (callBufferId != Allocation::InvalidBufferId)
100+
callOffset = allocation->getAllocatedInterval(callBufferId).start();
101+
curBlockInfo = translateBlockInfoToCallsite(calleeBlockInfo, callOffset);
102+
}
97103
} else {
98104
if (auto memEffects = dyn_cast<MemoryEffectOpInterface>(op)) {
99105
SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>>

test/Analysis/test-membar.mlir

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,3 +1232,42 @@ module attributes {ttg.target = "cuda:90", "ttg.num-warps" = 8 : i32} {
12321232
tt.return
12331233
}
12341234
}
1235+
1236+
// -----
1237+
1238+
#blockedLarge = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#sharedLarge = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
#blockedCallSrc = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
#mmaCall = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}>
#blockedCallDst = #ttg.dot_op<{opIdx = 0, parent = #mmaCall, kWidth = 2}>
#smem = #ttg.shared_memory

module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
  // Callee whose convert_layout uses shared-memory scratch space; the caller
  // must account for the call's allocation offset when translating it.
  tt.func private @callee_call_offset_membar() -> tensor<128x32xf16, #blockedCallDst> {
    %cst = arith.constant dense<0.0> : tensor<128x32xf16, #blockedCallSrc>
    %cvt = ttg.convert_layout %cst : tensor<128x32xf16, #blockedCallSrc> -> tensor<128x32xf16, #blockedCallDst>
    tt.return %cvt : tensor<128x32xf16, #blockedCallDst>
  }

  // The call's virtual buffer is offset by the large allocation. The
  // subsequent scratch op should alias at the same offset and require a membar.
  // CHECK-LABEL: @caller_call_offset_membar
  // CHECK: tt.call @callee_call_offset_membar{{.*}}allocation.offset = [[CALL_OFFSET:[1-9][0-9]*]]
  // CHECK: ttg.barrier local
  // CHECK-NEXT: ttg.convert_layout{{.*}}allocation.offset = [[CALL_OFFSET]]
  tt.func @caller_call_offset_membar() -> tensor<128x32xf16, #blockedCallDst> {
    // 64 KiB live local_alloc that pushes the call's scratch past offset 0.
    %large = arith.constant dense<0> : tensor<65536xi8, #blockedLarge>
    %buf = ttg.local_alloc : () -> !ttg.memdesc<65536xi8, #sharedLarge, #smem, mutable>
    ttg.local_store %large, %buf : tensor<65536xi8, #blockedLarge> -> !ttg.memdesc<65536xi8, #sharedLarge, #smem, mutable>

    %call = tt.call @callee_call_offset_membar() : () -> tensor<128x32xf16, #blockedCallDst>

    // Same-layout convert_layout in the caller: its scratch buffer should be
    // placed at the same offset as the call's, forcing a barrier in between.
    %cst = arith.constant dense<0.0> : tensor<128x32xf16, #blockedCallSrc>
    %cvt = ttg.convert_layout %cst : tensor<128x32xf16, #blockedCallSrc> -> tensor<128x32xf16, #blockedCallDst>
    %sum = arith.addf %call, %cvt : tensor<128x32xf16, #blockedCallDst>

    // Keep %buf live across the call so the allocator cannot reuse offset 0.
    %ld = ttg.local_load %buf : !ttg.memdesc<65536xi8, #sharedLarge, #smem, mutable> -> tensor<65536xi8, #blockedLarge>
    ttg.local_dealloc %buf : !ttg.memdesc<65536xi8, #sharedLarge, #smem, mutable>
    tt.return %sum : tensor<128x32xf16, #blockedCallDst>
  }
}

0 commit comments

Comments
 (0)