Skip to content

Commit 54bed51

Browse files
[FPSAN] Fix fpsan crash with warp specialization + tmem (#9415)
FPSan replaces tensor memory with global scratch memory, but was missing correct handling of passing global memory pointers into the warp_specialize op's partition regions. Also, use the `-Ofc mid` PTX compilation mode for FPSan compilation.
1 parent 3e7c88c commit 54bed51

3 files changed

Lines changed: 87 additions & 5 deletions

File tree

lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,20 @@ namespace ttng = mlir::triton::nvidia_gpu;
2525

2626
namespace {
2727

28+
// Returns true if `value` is visible inside `scope`: i.e. its home region
// (the region owning its block argument, or the region of its defining op)
// is `scope` itself or a region nested under `scope`.
static bool isValueAvailableInScope(Value value, Region *scope) {
  if (scope == nullptr)
    return false;
  // Resolve the region the value "lives" in, for both kinds of values.
  Region *homeRegion = nullptr;
  if (auto blockArg = dyn_cast<BlockArgument>(value))
    homeRegion = blockArg.getOwner()->getParent();
  else if (Operation *definingOp = value.getDefiningOp())
    homeRegion = definingOp->getParentRegion();
  if (homeRegion == nullptr)
    return false;
  return homeRegion == scope || scope->isAncestor(homeRegion);
}
41+
2842
constexpr int64_t kTileM = 8;
2943
constexpr int64_t kTileN = 8;
3044

@@ -162,6 +176,8 @@ class TmemScratchManager {
162176
return std::nullopt;
163177
}
164178

179+
ptr = remapToScope(ptr, rewriter, scope, loc);
180+
165181
ScratchInfo info{ptr, tensorTy};
166182
scratchMap[memdesc][scope] = info;
167183
return info;
@@ -189,8 +205,9 @@ class TmemScratchManager {
189205
rewriter, loc, rewriter.getI32IntegerAttr(stride));
190206
auto offsetEls = arith::MulIOp::create(
191207
rewriter, loc, rewriter.getI32Type(), offsetVal, strideVal);
192-
auto ptr = tt::AddPtrOp::create(rewriter, loc, baseInfo->ptr.getType(),
193-
baseInfo->ptr, offsetEls);
208+
Value ptr = tt::AddPtrOp::create(rewriter, loc, baseInfo->ptr.getType(),
209+
baseInfo->ptr, offsetEls);
210+
ptr = remapToScope(ptr, rewriter, scope, loc);
194211
auto layout = getScratchEncoding(rewriter, memdesc, memTy);
195212
auto tensorTy = RankedTensorType::get(memTy.getShape(),
196213
memTy.getElementType(), layout);
@@ -218,8 +235,9 @@ class TmemScratchManager {
218235
rewriter, loc, rewriter.getI32IntegerAttr(stride));
219236
auto offset = arith::MulIOp::create(rewriter, loc, rewriter.getI32Type(),
220237
idx, strideVal);
221-
auto ptr = tt::AddPtrOp::create(rewriter, loc, baseInfo->ptr.getType(),
222-
baseInfo->ptr, offset);
238+
Value ptr = tt::AddPtrOp::create(rewriter, loc, baseInfo->ptr.getType(),
239+
baseInfo->ptr, offset);
240+
ptr = remapToScope(ptr, rewriter, scope, loc);
223241
auto layout = getScratchEncoding(rewriter, memdesc, memTy);
224242
auto tensorTy = RankedTensorType::get(memTy.getShape(),
225243
memTy.getElementType(), layout);
@@ -241,6 +259,7 @@ class TmemScratchManager {
241259
if (ptr.getType() != ptrTy) {
242260
ptr = tt::BitcastOp::create(rewriter, loc, ptrTy, ptr);
243261
}
262+
ptr = remapToScope(ptr, rewriter, scope, loc);
244263

245264
auto layout = getScratchEncoding(rewriter, memdesc, memTy);
246265
auto tensorTy = RankedTensorType::get(memTy.getShape(),
@@ -254,6 +273,36 @@ class TmemScratchManager {
254273
}
255274

256275
private:
276+
// Returns a value that is legal to use inside `scope`. If `value` is already
// visible there (defined in `scope` or an enclosing region), it is returned
// unchanged. When `scope` is a warp-specialize partition region, the value is
// routed in as an explicit capture of the partitions op and the matching
// region argument is returned. For any other kind of scope the value is
// returned as-is.
Value remapToScope(Value value, PatternRewriter &rewriter, Region *scope,
                   Location loc) {
  if (!scope || isValueAvailableInScope(value, scope))
    return value;

  // Only warp-specialize partition regions need explicit capture plumbing.
  // dyn_cast_or_null already tolerates a null parent op, so no extra guard
  // is needed around getParentOp().
  auto partitions =
      dyn_cast_or_null<ttg::WarpSpecializePartitionsOp>(scope->getParentOp());
  if (!partitions)
    return value;

  // Reuse an existing capture of the same value if one is already present.
  unsigned captureIdx = partitions.getNumOperands();
  for (auto [i, capture] :
       llvm::enumerate(partitions.getExplicitCaptures())) {
    if (capture == value) {
      captureIdx = i;
      break;
    }
  }

  // Not captured yet: append the value as a new explicit capture and add a
  // matching block argument to every partition region so that capture
  // operands and region arguments stay index-aligned.
  if (captureIdx == partitions.getNumOperands()) {
    partitions->insertOperands(captureIdx, value);
    for (Region &region : partitions.getPartitionRegions())
      region.addArgument(value.getType(), loc);
  }

  return scope->getArgument(captureIdx);
}
305+
257306
DenseMap<Value, DenseMap<Region *, ScratchInfo>> scratchMap;
258307
};
259308

test/TritonGPU/fpsan.mlir

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,3 +276,36 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
276276
tt.return
277277
}
278278
}
279+
280+
// -----
281+
282+
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}>
283+
#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
284+
#smem = #ttg.shared_memory
285+
#blocked = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
286+
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, colStride = 1>
287+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 8 : i32} {
288+
// CHECK-LABEL: @ws_partition_tmem_load
289+
tt.func public @ws_partition_tmem_load() {
290+
// CHECK: %[[SCRATCH:.*]] = ttg.global_scratch_alloc
291+
// CHECK: ttg.warp_specialize(%{{.*}}, %{{.*}}, %{{.*}}, %[[SCRATCH]])
292+
// CHECK: partition0(%{{.*}}: !ttg.memdesc<1xi64, #{{[^,>]+}}, #smem, mutable>, %{{.*}}: !ttg.memdesc<128x128xf32, #{{[^,>]+}}, #smem, mutable>, %{{.*}}: !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, %[[SCRATCH_ARG:.*]]: !tt.ptr<f32>) num_warps(4)
293+
// CHECK: %[[PTRS:.*]] = tt.splat %[[SCRATCH_ARG]] : !tt.ptr<f32> -> tensor<128x128x!tt.ptr<f32>, #blocked>
294+
// CHECK: tt.load
295+
// CHECK: ttg.local_store
296+
// CHECK-NOT: ttng.tmem_load
297+
%bar = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable>
298+
%smem = ttg.local_alloc {allocation.offset = 4096 : i32} : () -> !ttg.memdesc<128x128xf32, #shared, #smem, mutable>
299+
%buf = ttng.tmem_alloc {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : () -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
300+
ttg.warp_specialize(%bar, %smem, %buf) attributes {actualRegisters = array<i32: 32, 32>, allocation.offset = 512 : i32, requestedRegisters = array<i32: 32>, warpGroupStartIds = array<i32: 4>}
301+
default {
302+
ttg.warp_yield
303+
}
304+
partition0(%arg0: !ttg.memdesc<1xi64, #shared1, #smem, mutable>, %arg1: !ttg.memdesc<128x128xf32, #shared, #smem, mutable>, %arg2: !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>) num_warps(4) {
305+
%val = ttng.tmem_load %arg2 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
306+
ttg.local_store %val, %arg1 : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #shared, #smem, mutable>
307+
ttg.warp_return
308+
} : (!ttg.memdesc<1xi64, #shared1, #smem, mutable>, !ttg.memdesc<128x128xf32, #shared, #smem, mutable>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>) -> ()
309+
tt.return
310+
}
311+
}

third_party/nvidia/backend/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,7 @@ def make_cubin(self, src, metadata, opt, capability):
499499
ptx_extra_options = opt.ptx_options.split(" ") if opt.ptx_options else []
500500

501501
# Use -Ofc mid to compile ConSan/FPSan code, if nothing else is specified.
502-
if "consan" in knobs.compilation.instrumentation_mode:
502+
if any(mode in knobs.compilation.instrumentation_mode for mode in ["consan", "fpsan"]):
503503
ptx_extra_options += ["-Ofc", "mid"]
504504

505505
# Add --regAllocOptLevel=2 to work around ptxas 13.x bug

0 commit comments

Comments
 (0)