Skip to content

Commit 12138f4

Browse files
pawelszczerbuk and root authored
[CONSAN] Handle memdesc selects in buffer region analysis (#10031)
Selects between different memdescs were not handled in BufferRegion analysis at all. After moving ConSan to LLVM lowering, we started hitting this case.

Co-authored-by: root <root@codex-gb201-0.brix.pawelszczerbuk.svc.cluster.local>
1 parent 441faac commit 12138f4

2 files changed

Lines changed: 53 additions & 0 deletions

File tree

lib/Analysis/BufferRegion.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "triton/Analysis/BufferRegion.h"
22
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
3+
#include "mlir/Dialect/Arith/IR/Arith.h"
34
#include "triton/Dialect/Triton/IR/Utility.h"
45
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
56
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
@@ -267,6 +268,16 @@ LogicalResult BufferRegionAnalysis::visitOperation(
267268
}
268269
return success();
269270
}
271+
if (auto selectOp = dyn_cast<arith::SelectOp>(op)) {
272+
if (isa<ttg::MemDescType>(selectOp.getType())) {
273+
regionInfo =
274+
RegionInfo::join(operands[1]->getValue(), operands[2]->getValue());
275+
for (auto *r : results) {
276+
propagateIfChanged(r, r->join(regionInfo));
277+
}
278+
return success();
279+
}
280+
}
270281
// "Passthrough" ops that don't modify the buffer regions.
271282
if (isa<ttg::MemDescTransOp, ttg::MemDescReshapeOp,
272283
ttg::MemDescReinterpretOp>(op)) {

test/Analysis/test-buffer-region.mlir

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,48 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar
395395

396396
// -----
397397

398+
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}>
399+
#smem = #ttg.shared_memory
400+
#blocked = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
401+
402+
// Regression test for arith.select over shared-memory memdescs: the analysis
// must treat the select result as possibly aliasing either operand, so the
// load below reports the union of both candidate buffer regions.
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} {
403+
tt.func public @select_shared_memory_regions(%cond: i1) {
404+
// Two distinct 4096-byte shared-memory allocations at fixed offsets.
%alloc_a = ttg.local_alloc {allocation.offset = 57344 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable>
405+
%alloc_b = ttg.local_alloc {allocation.offset = 61440 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable>
406+
// The selected memdesc may refer to either allocation at runtime.
%selected = arith.select %cond, %alloc_a, %alloc_b : !ttg.memdesc<32x32xf32, #shared, #smem, mutable>
407+
// expected-remark @below {{Buffers: [57344, 4096], [61440, 4096]}}
408+
ttg.local_load %selected : !ttg.memdesc<32x32xf32, #shared, #smem, mutable> -> tensor<32x32xf32, #blocked>
409+
tt.return
410+
}
411+
412+
// expected-remark @below {{All Shared Regions: [57344, 4096], [61440, 4096]}}
413+
tt.func private @print_all_regions() attributes {test.print_all_used_regions} {
414+
tt.return
415+
}
416+
}
417+
418+
// -----
419+
420+
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, colStride = 1>
421+
422+
// Regression test for arith.select over tensor-memory memdescs: the select
// result may alias either tmem allocation, so the load reports both regions.
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} {
423+
tt.func public @select_tensor_memory_regions(%cond: i1) {
424+
// Two tensor-memory allocations at column offsets 0 and 128.
%tm0 = ttng.tmem_alloc {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : () -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
425+
%tm1 = ttng.tmem_alloc {tensor_memory_col_offset = 128 : i32, tensor_memory_row_offset = 0 : i32} : () -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
426+
// The selected memdesc may refer to either allocation at runtime.
%selected = arith.select %cond, %tm0, %tm1 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
427+
// expected-remark @below {{Buffers: [0, 128], [128, 128]}}
428+
ttng.tmem_load %selected : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32>
429+
tt.return
430+
}
431+
432+
// expected-remark @below {{All Tensor Regions: [0, 128], [128, 128]}}
433+
tt.func private @print_all_regions() attributes {test.print_all_used_regions} {
434+
tt.return
435+
}
436+
}
437+
438+
// -----
439+
398440
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}>
399441
#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
400442
#smem = #ttg.shared_memory

0 commit comments

Comments (0)