Skip to content

Commit f171598

Browse files
authored
[AMD] Support warp specialization on gfx1250 (triton-lang#8947)
Add warp specialization support for AMD gfx1250.
1 parent 278e956 commit f171598

File tree

15 files changed

+1433
-124
lines changed

15 files changed

+1433
-124
lines changed

bin/RegisterTritonDialects.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
9797

9898
// TritonAMDGPUToLLVM passes
9999
mlir::triton::registerAllocateAMDGPUSharedMemory();
100+
mlir::triton::registerTritonAMDGPUConvertWarpSpecializeToLLVM();
100101
mlir::triton::registerConvertTritonAMDGPUToLLVM();
101102
mlir::triton::registerConvertBuiltinFuncToLLVM();
102103
mlir::triton::registerConvertWarpPipeline();

python/triton/experimental/gluon/language/amd/gfx1250/mbarrier.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def arrive(mbarrier, *, count=1, _semantic=None):
5353
Arrive at an mbarrier with a specified count. The operation requires a `count` attribute
5454
of at least 1, and decreases the pending arrival count of the mbarrier by the specified count.
5555
If the pending count reaches zero, the phase changes (is decremented in a wraparound manner) and the
56-
pending count is reloaded with the init count value. Returns the mbarrier's phase prior to the "arrive" operation.
56+
pending count is reloaded with the init count value. Returns the mbarrier's phase parity (0 for even, 1 for odd) prior to the "arrive" operation.
5757
5858
Args:
5959
mbarrier (shared_memory_descriptor): Barrier to be signalled.

test/Conversion/warp_specialize_to_llvm.mlir

Lines changed: 393 additions & 91 deletions
Large diffs are not rendered by default.

third_party/amd/backend/compiler.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@ def gluon_to_ttgir(src, metadata, options):
278278
passes.gluon.add_canonicalizer(pm)
279279
passes.ttgpuir.add_combine_tensor_select_and_if(pm)
280280
amd.passes.ttgpuir.add_warp_pipeline(pm)
281+
passes.ttgpuir.add_allocate_warp_groups(pm)
281282

282283
pm.run(mod, 'gluon_to_ttgir')
283284
metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
@@ -308,6 +309,7 @@ def make_llir(src, metadata, options):
308309
## For now it is used as a controller for developers only.
309310
__HIP_FTZ = True
310311
amd.passes.ttgpuir.add_to_llvmir(pm, options.arch, __HIP_FTZ)
312+
amd.passes.ttgpuir.add_warp_specialize_to_llvm(pm, options.arch)
311313
passes.common.add_canonicalizer(pm)
312314
passes.common.add_cse(pm)
313315

@@ -371,7 +373,12 @@ def make_llir(src, metadata, options):
371373
fns = [fn for fn in llvm_mod.get_functions() if not fn.is_declaration()]
372374
# The public kernel should be kernel 0.
373375
fns[0].set_calling_conv(amd.CALLING_CONV_AMDGPU_KERNEL)
374-
fns[0].add_fn_attr("amdgpu-flat-work-group-size", f"1,{options.num_warps*options.warp_size}")
376+
# warp-specialization mutates num_warps
377+
total_warps_num = options.num_warps
378+
total_num_warps = src.get_int_attr("ttg.total-num-warps")
379+
if total_num_warps is not None:
380+
total_warps_num = total_num_warps
381+
fns[0].add_fn_attr("amdgpu-flat-work-group-size", f"1,{total_warps_num*options.warp_size}")
375382
if "memory-bound-attention" in options.schedule_hint.split(','):
376383
fns[0].add_fn_attr("amdgpu-sched-strategy", "iterative-ilp")
377384
fns[0].add_fn_attr("uniform-work-group-size", "true")
@@ -425,6 +432,7 @@ def make_llir(src, metadata, options):
425432
amd.add_scalarize_packed_fops_llvm_pass(fns[0])
426433

427434
# Get some metadata
435+
metadata["num_warps"] = total_warps_num
428436
metadata["shared"] = src.get_int_attr("ttg.shared")
429437
metadata["profile_scratch_size"] = src.get_int_attr("ttg.profile_scratch_memory_size") or 0
430438
metadata["profile_scratch_align"] = src.get_int_attr("ttg.profile_scratch_memory_alignment") or 1

third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -893,7 +893,7 @@ def ArriveBarrierOp : TT_AMDGPU_Op<"arrive_barrier"> {
893893
Performs the "arrive" operation on an mbarrier object in shared memory. The operation requires a `count` attribute
894894
of at least 1, and decreases the pending arrival count of the mbarrier by the specified count. If the pending count reaches
895895
zero, the phase changes (is decremented in a wraparound manner) and the pending count is reloaded with the init count value. Returns the phase
896-
of the mbarrier object prior to the "arrive" operation.
896+
parity (0 for even, 1 for odd) of the mbarrier object prior to the "arrive" operation.
897897

898898
Example:
899899

third_party/amd/include/TritonAMDGPUToLLVM/Passes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@ namespace mlir::triton {
2424
} // namespace mlir::triton
2525

2626
namespace mlir::triton::AMD {
27+
2728
std::unique_ptr<OperationPass<ModuleOp>> createConvertWarpPipelinePass();
29+
std::unique_ptr<OperationPass<ModuleOp>>
30+
createTritonAMDGPUConvertWarpSpecializeToLLVMPass(StringRef arch);
2831
void runScalarizePackedFOpsPass(llvm::Function &F);
2932

3033
} // namespace mlir::triton::AMD

third_party/amd/include/TritonAMDGPUToLLVM/Passes.td

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,24 @@ def ConvertWarpPipeline : Pass<"convert-warp-pipeline", "mlir::ModuleOp"> {
8484
"mlir::gpu::GPUDialect",
8585
"mlir::ROCDL::ROCDLDialect",
8686
"mlir::triton::amdgpu::TritonAMDGPUDialect"];
87+
}
88+
89+
def TritonAMDGPUConvertWarpSpecializeToLLVM : Pass<"triton-amdgpu-convert-warp-specialize-to-llvm", "mlir::ModuleOp"> {
90+
let summary = "lower `ttg.warp_specialize` to LLVM";
91+
let constructor = "mlir::triton::AMD::createTritonAMDGPUConvertWarpSpecializeToLLVMPass(\"\")";
92+
let description = [{
93+
The `triton-amdgpu-convert-warp-specialize-to-llvm` pass performs codegen for warp
94+
specialization. It is a function-level transformation that rewrites
95+
warp-specialized kernels by using shared memory and barriers to communicate
96+
states between the default warpgroup and the worker warps.
97+
}];
98+
99+
let dependentDialects = ["mlir::LLVM::LLVMDialect", "mlir::ROCDL::ROCDLDialect"];
87100

101+
let options = [
102+
Option<"arch", "arch", "std::string", /*default*/"\"\"",
103+
"target device architecture, e.g., gfx1250">,
104+
];
88105
}
89106

90107
#endif

third_party/amd/lib/TritonAMDGPUToLLVM/BarrierOpToLLVM.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@ using namespace mlir;
99
using namespace mlir::triton;
1010

1111
constexpr int kBarrierCountBitWidth = 29;
12-
constexpr int kBarrierPhaseMask = ((1ULL << (32 - kBarrierCountBitWidth)) - 1);
12+
// NOTE: We only care for the parity of the phase (0: even, 1: odd), so use 1
13+
// bit constexpr int kBarrierPhaseMask = ((1ULL << (32 - kBarrierCountBitWidth))
14+
// - 1);
15+
constexpr int kBarrierPhaseMask = 1;
1316
constexpr int kInitCountPos = 32;
1417

1518
namespace {

third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ add_triton_library(TritonAMDGPUToLLVM
66
TensorPtrOpsToLLVM.cpp
77
ConvertLayoutOpToLLVM.cpp
88
ConvertWarpPipeline.cpp
9+
ConvertWarpSpecializeToLLVM.cpp
910
MemoryOpToLLVM.cpp
1011
MaskedOpsToLLVM.cpp
1112
DotOpToLLVM/FMA.cpp
@@ -35,6 +36,7 @@ add_triton_library(TritonAMDGPUToLLVM
3536
LLVMIRIncGen
3637

3738
LINK_LIBS PUBLIC
39+
MLIRReconcileUnrealizedCasts
3840
TritonGPUToLLVM
3941
TritonAMDGPUIR
4042
LLVMCore

0 commit comments

Comments
 (0)