
Commit 5b856e6

[AMD] Replace ReorderInstructions with MoveUpPrologueLoads (#9328)
After the recent changes, the ReorderInstructions pass had only one optimization left: moving prologue loads early for prefetching. Add a new pass for that optimization, refactor the implementation, and add more tests.
1 parent 72d0d90 commit 5b856e6

9 files changed

Lines changed: 262 additions & 321 deletions
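
The single optimization the new pass keeps is the one exercised by the new test file below: a tt.load tagged with the amd.pipeliner_part = "prologue" attribute is hoisted above unrelated ops such as ttg.local_alloc, together with the ops that compute its mask, so the global load is issued as early as possible. A condensed before/after sketch, distilled from the move_up_slice test case below (value names renamed for readability; this sketch is not itself part of the commit):

// Before the pass: the prologue load is issued after the shared-memory allocation.
%alloc = ttg.local_alloc : () -> !ttg.memdesc<1x32x128xf16, #shared, #smem, mutable>
%cond = arith.cmpi sgt, %arg1, %c0_i32 : i32
%mask = tt.splat %cond : i1 -> tensor<32x128xi1, #blocked>
%val = tt.load %arg0, %mask {amd.pipeliner_part = "prologue"} : tensor<32x128x!tt.ptr<f16>, #blocked>

// After the pass: the load and the ops computing its mask move above the allocation,
// so the global load starts earlier; everything the load depends on stays above it.
%cond = arith.cmpi sgt, %arg1, %c0_i32 : i32
%mask = tt.splat %cond : i1 -> tensor<32x128xi1, #blocked>
%val = tt.load %arg0, %mask {amd.pipeliner_part = "prologue"} : tensor<32x128x!tt.ptr<f16>, #blocked>
%alloc = ttg.local_alloc : () -> !ttg.memdesc<1x32x128xf16, #shared, #smem, mutable>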

bin/RegisterTritonDialects.h

Lines changed: 1 addition & 1 deletion
@@ -114,7 +114,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   mlir::registerTritonAMDGPUHoistLayoutConversions();
   mlir::registerTritonAMDGPUSinkLayoutConversions();
   mlir::registerTritonAMDGPUPrepareIfCombining();
-  mlir::registerTritonAMDGPUReorderInstructions();
+  mlir::registerTritonAMDGPUMoveUpPrologueLoads();
   mlir::registerTritonAMDGPUBlockPingpong();
   mlir::registerTritonAMDGPUPipeline();
   mlir::registerTritonAMDGPUScheduleLoops();
Lines changed: 140 additions & 0 deletions
@@ -0,0 +1,140 @@
// RUN: triton-opt %s -split-input-file -tritonamdgpu-move-up-prologue-loads | FileCheck %s

// CHECK-LABEL: move_up_slice
// CHECK: arith.cmpi
// CHECK: tt.splat
// CHECK: tt.load
// CHECK: ttg.local_alloc
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 16, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  tt.func @move_up_slice(%arg0: tensor<32x128x!tt.ptr<f16>, #blocked>, %arg1: i32) {
    %c0_i32 = arith.constant 0 : i32
    %0 = ttg.local_alloc : () -> !ttg.memdesc<1x32x128xf16, #shared, #smem, mutable>
    %1 = arith.cmpi sgt, %arg1, %c0_i32 : i32
    %2 = tt.splat %1 : i1 -> tensor<32x128xi1, #blocked>
    %3 = tt.load %arg0, %2 {amd.pipeliner_part = "prologue"} : tensor<32x128x!tt.ptr<f16>, #blocked>
    tt.return
  }
}

// -----

// CHECK-LABEL: keep_load_order
// CHECK: arith.cmpi sgt
// CHECK: tt.splat
// CHECK: tt.load %arg0
// CHECK: tt.addptr
// CHECK: arith.cmpi slt
// CHECK: tt.splat
// CHECK: tt.load
// CHECK: ttg.local_alloc
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 16, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  tt.func @keep_load_order(%arg0: tensor<32x128x!tt.ptr<f16>, #blocked>, %arg1: i32, %arg2: i32) {
    %c0_i32 = arith.constant 0 : i32
    %cst = arith.constant dense<128> : tensor<32x128xi32, #blocked>
    %0 = tt.addptr %arg0, %cst : tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<32x128xi32, #blocked>
    %1 = ttg.local_alloc : () -> !ttg.memdesc<1x32x128xf16, #shared, #smem, mutable>
    %2 = arith.cmpi sgt, %arg1, %c0_i32 : i32
    %3 = tt.splat %2 : i1 -> tensor<32x128xi1, #blocked>
    %4 = tt.load %arg0, %3 {amd.pipeliner_part = "prologue"} : tensor<32x128x!tt.ptr<f16>, #blocked>
    %5 = arith.cmpi slt, %arg2, %c0_i32 : i32
    %6 = tt.splat %5 : i1 -> tensor<32x128xi1, #blocked>
    %7 = tt.load %0, %6 {amd.pipeliner_part = "prologue"} : tensor<32x128x!tt.ptr<f16>, #blocked>
    tt.return
  }
}

// -----

// CHECK-LABEL: break_at_atomic
// CHECK: tt.atomic_rmw
// CHECK: arith.cmpi
// CHECK: tt.splat
// CHECK: tt.load
// CHECK: ttg.local_alloc
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 16, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  tt.func @break_at_atomic(%arg0: tensor<32x128x!tt.ptr<f16>, #blocked>, %arg1: i32, %arg2: !tt.ptr<i32>) {
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 1 : i32
    %0 = tt.atomic_rmw fadd, relaxed, gpu, %arg2, %c1_i32 : (!tt.ptr<i32>, i32) -> i32
    %1 = ttg.local_alloc : () -> !ttg.memdesc<1x32x128xf16, #shared, #smem, mutable>
    %2 = arith.cmpi sgt, %arg1, %c0_i32 : i32
    %3 = tt.splat %2 : i1 -> tensor<32x128xi1, #blocked>
    %4 = tt.load %arg0, %3 {amd.pipeliner_part = "prologue"} : tensor<32x128x!tt.ptr<f16>, #blocked>
    tt.return
  }
}

// -----

// CHECK-LABEL: break_at_barrier
// CHECK: gpu.barrier
// CHECK: arith.cmpi
// CHECK: tt.splat
// CHECK: tt.load
// CHECK: ttg.local_alloc
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 16, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  tt.func @break_at_barrier(%arg0: tensor<32x128x!tt.ptr<f16>, #blocked>, %arg1: i32) {
    %c0_i32 = arith.constant 0 : i32
    gpu.barrier
    %0 = ttg.local_alloc : () -> !ttg.memdesc<1x32x128xf16, #shared, #smem, mutable>
    %1 = arith.cmpi sgt, %arg1, %c0_i32 : i32
    %2 = tt.splat %1 : i1 -> tensor<32x128xi1, #blocked>
    %3 = tt.load %arg0, %2 {amd.pipeliner_part = "prologue"} : tensor<32x128x!tt.ptr<f16>, #blocked>
    tt.return
  }
}

// -----

// CHECK-LABEL: break_at_loop
// CHECK: scf.for
// CHECK: tt.load %arg0
// CHECK: ttg.local_alloc
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 16, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  tt.func @break_at_loop(%arg0: tensor<32x128x!tt.ptr<f16>, #blocked>, %arg1: i32) {
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 1 : i32
    scf.for %arg2 = %c0_i32 to %arg1 step %c1_i32 : i32 {
    }
    %0 = ttg.local_alloc : () -> !ttg.memdesc<1x32x128xf16, #shared, #smem, mutable>
    %1 = tt.load %arg0 {amd.pipeliner_part = "prologue"} : tensor<32x128x!tt.ptr<f16>, #blocked>
    tt.return
  }
}

// -----

// Negative test: load without amd.pipeliner_part attribute should not be moved
// CHECK-LABEL: no_prologue_attribute
// CHECK: ttg.local_alloc
// CHECK: arith.cmpi
// CHECK: tt.splat
// CHECK: tt.load
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 16, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  tt.func @no_prologue_attribute(%arg0: tensor<32x128x!tt.ptr<f16>, #blocked>, %arg1: i32) {
    %c0_i32 = arith.constant 0 : i32
    %0 = ttg.local_alloc : () -> !ttg.memdesc<1x32x128xf16, #shared, #smem, mutable>
    %1 = arith.cmpi sgt, %arg1, %c0_i32 : i32
    %2 = tt.splat %1 : i1 -> tensor<32x128xi1, #blocked>
    %3 = tt.load %arg0, %2 : tensor<32x128x!tt.ptr<f16>, #blocked>
    tt.return
  }
}

test/TritonGPU/amd/amd-reorder-instructions.mlir

Lines changed: 0 additions & 176 deletions
This file was deleted.

third_party/amd/backend/compiler.py

Lines changed: 1 addition & 1 deletion
@@ -251,7 +251,7 @@ def make_ttgir(mod, metadata, options):
         if is_in_thread_transpose_enabled(options.arch):
             amd.passes.ttgpuir.add_in_thread_transpose(pm)
         passes.ttgpuir.add_remove_layout_conversions(pm)
-        amd.passes.ttgpuir.add_reorder_instructions(pm)
+        amd.passes.ttgpuir.add_move_up_prologue_loads(pm)
         if use_block_pingpong and options.num_stages > 1:
             amd.passes.ttgpuir.add_block_pingpong(pm, options.num_stages)

third_party/amd/include/TritonAMDGPUTransforms/Passes.td

Lines changed: 8 additions & 6 deletions
@@ -154,12 +154,14 @@ def TritonAMDGPUCanonicalizePointers : Pass<"tritonamdgpu-canonicalize-pointers"
   ];
 }
 
-def TritonAMDGPUReorderInstructions: Pass<"tritonamdgpu-reorder-instructions", "mlir::ModuleOp"> {
-  let summary = "Reorder instructions";
-
-  let description = "This pass reorder instructions so as to (1) decrease register pressure (e.g., by moving "
-                    "conversions from shared memory before their first use) and (2) promote LLVM instruction "
-                    "order more friendly to `ptxas`.";
+def TritonAMDGPUMoveUpPrologueLoads
+    : Pass<"tritonamdgpu-move-up-prologue-loads", "mlir::triton::FuncOp"> {
+  let summary = "Move up global loads in prologue for better GEMM performance";
+
+  let description =
+      "This pass moves global load ops early to prefetch in the prologue. "
+      "This may increase register pressure but it enables issuing global loads "
+      "early.";
 
   let dependentDialects = [];
 }
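
As the break_at_atomic, break_at_barrier, and break_at_loop tests above show, the hoisting is conservative: a prologue load is moved above ops it does not depend on (such as ttg.local_alloc), but never above side-effecting ops or control flow. A condensed sketch of that boundary behavior, distilled from the break_at_barrier test above (value names renamed for readability; not itself part of this diff):

// Before: barrier, shared-memory allocation, then the prologue load.
gpu.barrier
%alloc = ttg.local_alloc : () -> !ttg.memdesc<1x32x128xf16, #shared, #smem, mutable>
%cond = arith.cmpi sgt, %arg1, %c0_i32 : i32
%mask = tt.splat %cond : i1 -> tensor<32x128xi1, #blocked>
%val = tt.load %arg0, %mask {amd.pipeliner_part = "prologue"} : tensor<32x128x!tt.ptr<f16>, #blocked>

// After: the load rises above the allocation but stays below the barrier.
gpu.barrier
%cond = arith.cmpi sgt, %arg1, %c0_i32 : i32
%mask = tt.splat %cond : i1 -> tensor<32x128xi1, #blocked>
%val = tt.load %arg0, %mask {amd.pipeliner_part = "prologue"} : tensor<32x128x!tt.ptr<f16>, #blocked>
%alloc = ttg.local_alloc : () -> !ttg.memdesc<1x32x128xf16, #shared, #smem, mutable>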

third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ add_triton_library(TritonAMDGPUTransforms
   HoistLayoutConversions.cpp
   SinkLayoutConversions.cpp
   PrepareIfCombining.cpp
-  ReorderInstructions.cpp
+  MoveUpPrologueLoads.cpp
   Pipeline.cpp
   ScheduleLoops.cpp
   LowerLoops.cpp
