Skip to content

Commit 434aecb

Browse files
authored
[BACKEND] Re-order WS lowering and NVGPU lowering (#9535)
Better layering as NVGPU is meant to be at the same level of abstraction of LLVM. This also avoid bugs when lowering prologue/epilogue of the kernel
1 parent f2895fa commit 434aecb

5 files changed

Lines changed: 72 additions & 70 deletions

File tree

test/Conversion/nvgpu_to_llvm.mlir

Lines changed: 0 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -119,68 +119,6 @@ llvm.func @tensor_memory_base_warpgroup() attributes {nvvm.kernel = 1 : ui1, nvv
119119

120120
}
121121

122-
// -----
123-
124-
module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
125-
126-
// CHECK-LABEL: @warpid_warp_specialize
127-
llvm.func @warpid_warp_specialize() {
128-
// CHECK: [[C32:%.*]] = llvm.mlir.constant(32 : i32)
129-
// CHECK: [[TIDX:%.*]] = nvvm.read.ptx.sreg.tid.x
130-
// CHECK: [[ID:%.*]] = llvm.udiv [[TIDX]], [[C32]]
131-
// CHECK: [[UNIFORM:%.*]] = nvvm.shfl.sync idx {{%[0-9]+}}, [[ID]]
132-
%0 = ttg.warp_id
133-
// CHECK: "use"([[UNIFORM]])
134-
"use"(%0) : (i32) -> ()
135-
136-
// CHECK: ttg.warp_specialize
137-
ttg.warp_specialize() attributes {warpGroupStartIds = array<i32: 6, 4>}
138-
// CHECK: default
139-
default {
140-
// CHECK: [[TIDX:%.*]] = nvvm.read.ptx.sreg.tid.x
141-
// CHECK: [[ID:%.*]] = llvm.udiv [[TIDX]], [[C32]]
142-
// CHECK: [[UNIFORM:%.*]] = nvvm.shfl.sync idx {{%[0-9]+}}, [[ID]]
143-
%1 = ttg.warp_id
144-
// CHECK: "use"([[UNIFORM]])
145-
"use"(%1) : (i32) -> ()
146-
ttg.warp_yield
147-
}
148-
// CHECK: partition0
149-
partition0() num_warps(4) {
150-
// 6*32 = 196
151-
152-
// CHECK: [[C32:%.*]] = llvm.mlir.constant(32 : i32)
153-
// CHECK: [[C192:%.*]] = llvm.mlir.constant(192 : i32)
154-
// CHECK: [[TIDX:%.*]] = nvvm.read.ptx.sreg.tid.x
155-
// CHECK: [[REL_TIDX:%.*]] = llvm.sub [[TIDX]], [[C192]]
156-
// CHECK: [[ID:%.*]] = llvm.udiv [[REL_TIDX]], [[C32]]
157-
// CHECK: [[UNIFORM:%.*]] = nvvm.shfl.sync idx {{%[0-9]+}}, [[ID]]
158-
%1 = ttg.warp_id
159-
// CHECK: "use"([[UNIFORM]])
160-
"use"(%1) : (i32) -> ()
161-
ttg.warp_return
162-
}
163-
partition1() num_warps(2) {
164-
// 4*32 = 128
165-
166-
// CHECK: [[C32:%.*]] = llvm.mlir.constant(32 : i32)
167-
// CHECK: [[C128:%.*]] = llvm.mlir.constant(128 : i32)
168-
// CHECK: [[TIDX:%.*]] = nvvm.read.ptx.sreg.tid.x
169-
// CHECK: [[REL_TIDX:%.*]] = llvm.sub [[TIDX]], [[C128]]
170-
// CHECK: [[ID:%.*]] = llvm.udiv [[REL_TIDX]], [[C32]]
171-
// CHECK: [[UNIFORM:%.*]] = nvvm.shfl.sync idx {{%[0-9]+}}, [[ID]]
172-
%1 = ttg.warp_id
173-
// CHECK: "use"([[UNIFORM]])
174-
"use"(%1) : (i32) -> ()
175-
ttg.warp_return
176-
} : () -> ()
177-
llvm.return
178-
}
179-
180-
}
181-
182-
// -----
183-
184122
module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} {
185123

186124
// CHECK-LABEL: @one_warp

test/Conversion/warp_specialize_to_llvm.mlir

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,51 @@ llvm.func @partition_warpid_order() attributes {allocation.offset = 32 : i32} {
540540

541541
// -----
542542

543+
module attributes {"ttg.num-warps" = 4 : i32, "ttg.total-num-warps" = 18 : i32} {
544+
545+
llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
546+
547+
// CHECK-LABEL: @warpid_warp_specialize
548+
llvm.func @warpid_warp_specialize() attributes {allocation.offset = 32 : i32} {
549+
// CHECK-DAG: [[C4:%.*]] = llvm.mlir.constant(4 : i32)
550+
// CHECK-DAG: [[C6:%.*]] = llvm.mlir.constant(6 : i32)
551+
552+
// Partition warp IDs are rewritten to be relative in this pass, while
553+
// keeping ttg.warp_id for NVGPUToLLVM to lower later.
554+
// CHECK: %{{.*}} = ttg.warp_id
555+
// CHECK-NEXT: [[REL0:%.*]] = llvm.sub %{{.*}}, [[C6]] : i32
556+
// CHECK-NEXT: "use"([[REL0]]) : (i32) -> ()
557+
558+
// CHECK: %{{.*}} = ttg.warp_id
559+
// CHECK-NEXT: [[REL1:%.*]] = llvm.sub %{{.*}}, [[C4]] : i32
560+
// CHECK-NEXT: "use"([[REL1]]) : (i32) -> ()
561+
562+
%0 = ttg.warp_id
563+
"use"(%0) : (i32) -> ()
564+
565+
ttg.warp_specialize() attributes {allocation.offset = 0 : i32, warpGroupStartIds = array<i32: 6, 4>}
566+
default {
567+
%1 = ttg.warp_id
568+
"use"(%1) : (i32) -> ()
569+
ttg.warp_yield
570+
}
571+
partition0() num_warps(4) {
572+
%1 = ttg.warp_id
573+
"use"(%1) : (i32) -> ()
574+
ttg.warp_return
575+
}
576+
partition1() num_warps(2) {
577+
%1 = ttg.warp_id
578+
"use"(%1) : (i32) -> ()
579+
ttg.warp_return
580+
} : () -> ()
581+
llvm.return
582+
}
583+
584+
}
585+
586+
// -----
587+
543588
module attributes {"ttg.num-warps" = 4 : i32, "ttg.total-num-warps" = 12 : i32} {
544589

545590
llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>

third_party/nvidia/backend/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,8 +374,8 @@ def make_llir(self, src, metadata, options, capability):
374374
nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version)
375375
passes.ttgpuir.add_canonicalize_llvm_ir(pm)
376376
passes.common.add_cse(pm)
377-
nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm)
378377
nvidia.passes.ttnvgpuir.add_warp_specialize_to_llvm(pm)
378+
nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm)
379379
passes.common.add_canonicalizer(pm)
380380
passes.common.add_cse(pm)
381381
passes.common.add_symbol_dce(pm)

third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -210,13 +210,7 @@ class WarpIdOpPattern : public OpRewritePattern<mlir::triton::gpu::WarpIdOp> {
210210
return success();
211211
}
212212

213-
// If this is inside a warp specialize op, compute the relative thread ID
214-
// within the warp group.
215213
Value tid = NVVM::ThreadIdXOp::create(rewriter, loc, i32_ty);
216-
if (std::optional<int> startId =
217-
getWarpGroupStartThreadId(rewriter.getInsertionBlock()))
218-
tid = LLVM::SubOp::create(rewriter, loc, tid, b.i32_val(*startId));
219-
220214
Value warpId = b.udiv(tid, b.i32_val(32));
221215
if (!op.getOmitUniformHint()) {
222216
// This indicates to PTXAS that the result and its derived values are

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertWarpSpecializeToLLVM.cpp

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
#include "mlir/IR/BuiltinOps.h"
88
#include "mlir/IR/ImplicitLocOpBuilder.h"
99
#include "mlir/Pass/PassManager.h"
10-
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1110
#include "triton/Conversion/TritonGPUToLLVM/Passes.h"
1211
#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h"
1312
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
@@ -83,6 +82,30 @@ class NVIDIAWarpSpecializeBarrierHelper : public WarpSpecializeBarrierHelper {
8382
unsigned numThreadsPerWarp;
8483
};
8584

85+
static void rewriteWarpSpecializeWarpIdsOnce(ModuleOp mod) {
86+
SmallVector<mlir::triton::gpu::WarpIdOp> wsWarpIds;
87+
mod.walk([&](mlir::triton::gpu::WarpIdOp op) {
88+
if (getWarpGroupStartWarpId(op->getBlock()))
89+
wsWarpIds.push_back(op);
90+
});
91+
92+
for (mlir::triton::gpu::WarpIdOp op : wsWarpIds) {
93+
std::optional<int> startWarpId = getWarpGroupStartWarpId(op->getBlock());
94+
assert(startWarpId &&
95+
"expected warp-specialize warp_id to have a start warp ID");
96+
97+
auto loc = op.getLoc();
98+
TritonLLVMIRRewriter b(loc, op);
99+
100+
// Keep `ttg.warp_id` for NVGPUToLLVM and only make it relative here.
101+
Value absWarpId =
102+
mlir::triton::gpu::WarpIdOp::create(b, loc, op.getOmitUniformHint());
103+
Value relWarpId =
104+
LLVM::SubOp::create(b, loc, absWarpId, b.i32_val(*startWarpId));
105+
b.replaceOp(op, relWarpId);
106+
}
107+
}
108+
86109
//===----------------------------------------------------------------------===//
87110
// lowerWarpSpecialize
88111
//===----------------------------------------------------------------------===//
@@ -249,6 +272,8 @@ struct ConvertWarpSpecializeToLLVM
249272
if (failed(runPipeline(pm, mod)))
250273
return signalPassFailure();
251274

275+
rewriteWarpSpecializeWarpIdsOnce(mod);
276+
252277
unsigned threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod);
253278
NVIDIAWarpSpecializeBarrierHelper barrierHelper(threadsPerWarp);
254279
if (failed(lowerWarpSpecializeBarriers(mod, barrierHelper)))

0 commit comments

Comments
 (0)