Skip to content

Commit 3932686

Browse files
AlexAUT and antiagainst
authored
[AMD] Schedule AsyncWait in front of AsyncCopy and LocalLoad (#6621)
Changes the scheduling to place `AsyncWait` in a separate cluster at the start of the schedule if `GlobalLoadStage != LocalLoadStage`. This allows us to reorder the `AsyncCopy` and `LocalLoads` more freely because the `AsyncWait` will always be before both `Ops`. --------- Co-authored-by: Lei Zhang <antiagainst@gmail.com>
1 parent 415af4b commit 3932686

2 files changed

Lines changed: 53 additions & 34 deletions

File tree

test/TritonGPU/loop-pipeline-hip.mlir

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,9 @@ module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.n
4646
// ASYNC: ttg.async_copy_global_to_local
4747
// ASYNC: scf.for
4848
// ASYNC: ttg.async_wait
49+
// ASYNC: ttg.async_copy_global_to_local
4950
// ASYNC: tt.dot
5051
// ASYNC: tt.dot
51-
// ASYNC: ttg.async_copy_global_to_local
5252
// ASYNC: scf.yield
5353
%17:2 = scf.for %arg2 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg3 = %cst_1, %arg4 = %cst_2) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) : i32 {
5454
%18 = tt.load %16 : tensor<64x16x!tt.ptr<f16>, #blocked>
@@ -500,9 +500,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
500500
// SYNC: scf.yield
501501
//
502502
// ASYNC: ttg.async_wait
503-
// ASYNC-COUNT-2: ttg.local_load
503+
// ASYNC: ttg.async_copy_global_to_local
504+
// ASYNC: ttg.local_load
505+
// ASYNC: ttg.async_copy_global_to_local
506+
// ASYNC: ttg.local_load
504507
// ASYNC: ttg.dot
505-
// ASYNC-COUNT-2: ttg.async_copy_global_to_local
506508

507509
// Epilogue
508510
// ASYNC: ttg.async_wait

third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp

Lines changed: 48 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -101,16 +101,19 @@ namespace {
101101
//
102102
class StreamPipeliner {
103103
// Define categories of scheduling details per Operation types.
104-
// The StreamPipeliner schedules 4 types of operations:
105-
// 1. GLOBAL_LOAD: tt.load
106-
// 2. LOCAL_STORE: ttg.local_store (created by the StreamPipeliner)
107-
// 3. LOCAL_LOAD: ttg.local_load (created by the StreamPipeliner)
104+
// The StreamPipeliner schedules 5 types of operations:
105+
// 1. GLOBAL_LOAD: tt.load / ttg.async_copy_global_to_local
106+
// 2. LOCAL_STORE: ttg.local_store
107+
// 3. LOCAL_LOAD: ttg.local_load
108108
// 4. COMPUTE: ops that use the loaded data
109+
// 5. ASYNC_WAIT: ttg.async_wait
110+
// Note that ttg ops mentioned in the above list are created in this pass.
109111
enum SchedType {
110112
SCHED_GLOBAL_LOAD,
111113
SCHED_LOCAL_STORE,
112114
SCHED_LOCAL_LOAD,
113115
SCHED_COMPUTE,
116+
SCHED_ASYNC_WAIT,
114117
SCHED_SIZE
115118
};
116119

@@ -125,6 +128,7 @@ class StreamPipeliner {
125128
stages[SCHED_LOCAL_STORE] = _globalPrefetch;
126129
stages[SCHED_LOCAL_LOAD] = lastStage - _localPrefetch;
127130
stages[SCHED_COMPUTE] = lastStage;
131+
stages[SCHED_ASYNC_WAIT] = stages[SCHED_LOCAL_LOAD];
128132

129133
options.supportDynamicLoops = true;
130134
options.peelEpilogue = true;
@@ -212,7 +216,6 @@ class StreamPipeliner {
212216
// WARNING: Changing the order of schedule.clusters.newAtBack() calls
213217
// can cause invalid schedules to be produced.
214218
LogicalResult StreamPipeliner::initSchedule(int maxIndirectionLevel) {
215-
216219
bool pairedGlobalLoadLocalStore = stages[SCHED_LOCAL_STORE] == 0;
217220
stages[SCHED_LOCAL_STORE] += maxIndirectionLevel;
218221

@@ -221,6 +224,7 @@ LogicalResult StreamPipeliner::initSchedule(int maxIndirectionLevel) {
221224
<< ", LOCAL_STORE stage = " << stages[SCHED_LOCAL_STORE]
222225
<< ", LOCAL_LOAD stage = " << stages[SCHED_LOCAL_LOAD]
223226
<< ", COMPUTE stage = " << stages[SCHED_COMPUTE]
227+
<< ", ASYNC_WAIT stage = " << stages[SCHED_ASYNC_WAIT]
224228
<< "; total = " << numStages);
225229

226230
if (stages[SCHED_LOCAL_STORE] >= numStages ||
@@ -241,15 +245,19 @@ LogicalResult StreamPipeliner::initSchedule(int maxIndirectionLevel) {
241245

242246
LDBG("deduced max shared memory buffer number = " << numBuffers);
243247

248+
// We place async wait as the first cluster because we want to have it being
249+
// the first in the main loop after pipelining.
250+
int asyncWaitCluster = 0;
251+
244252
// If tt.load and ttg.local_store are in the same stage
245253
// spread them apart to allow overlap with compute
246254
// else
247255
// Initiate ttg.local_store before tt.load
248-
int globalLoadCluster = 0;
249-
int localStoreCluster = 2;
256+
int globalLoadCluster = 1;
257+
int localStoreCluster = 3;
250258
if (!pairedGlobalLoadLocalStore) {
251-
globalLoadCluster = 2;
252-
localStoreCluster = 1;
259+
globalLoadCluster = 3;
260+
localStoreCluster = 2;
253261
}
254262

255263
// If ttg.local_load and ttg.local_store are in the same stage
@@ -260,33 +268,35 @@ LogicalResult StreamPipeliner::initSchedule(int maxIndirectionLevel) {
260268
// schedule ttg.local_load in the middle
261269
int localLoadCluster = globalLoadCluster;
262270
if (stages[SCHED_LOCAL_LOAD] == stages[SCHED_LOCAL_STORE]) {
263-
localLoadCluster = std::max(2, localStoreCluster + 1);
271+
localLoadCluster = std::max(3, localStoreCluster + 1);
264272
} else if (numBuffers == 1 && localLoadCluster >= localStoreCluster) {
265273
// For 1 buffer, ttg.local_load must occur before ttg.local_store
266274
localLoadCluster = localStoreCluster - 1;
267275
}
268276

269277
// Schedule compute with ttg.local_load if paired
270278
// otherwise, schedule in the middle
271-
int computeCluster = 1;
279+
int computeCluster = 2;
272280
if (stages[SCHED_LOCAL_LOAD] == stages[SCHED_COMPUTE]) {
273281
computeCluster = localLoadCluster;
274282
}
275283

276284
// Make assignments
277-
std::array<tt::CoarseSchedule::Cluster, SCHED_SIZE> clusterVec = {
278-
schedule.clusters.newAtBack(), schedule.clusters.newAtBack(),
279-
schedule.clusters.newAtBack(), schedule.clusters.newAtBack()};
285+
std::array<tt::CoarseSchedule::Cluster, SCHED_SIZE> clusterVec;
286+
std::generate(clusterVec.begin(), clusterVec.end(),
287+
[&]() { return schedule.clusters.newAtBack(); });
280288

281289
clusters[SCHED_GLOBAL_LOAD] = clusterVec[globalLoadCluster];
282290
clusters[SCHED_LOCAL_STORE] = clusterVec[localStoreCluster];
283291
clusters[SCHED_LOCAL_LOAD] = clusterVec[localLoadCluster];
284292
clusters[SCHED_COMPUTE] = clusterVec[computeCluster];
293+
clusters[SCHED_ASYNC_WAIT] = clusterVec[asyncWaitCluster];
285294

286295
LDBG("Cluster schedule:" << " GLOBAL_LOAD cluster = " << globalLoadCluster
287296
<< ", LOCAL_STORE cluster = " << localStoreCluster
288297
<< ", LOCAL_LOAD cluster = " << localLoadCluster
289298
<< ", COMPUTE cluster = " << computeCluster
299+
<< ", ASYNC_WAIT cluster = " << asyncWaitCluster
290300
<< "; total = " << SCHED_SIZE);
291301

292302
return success();
@@ -333,30 +343,37 @@ bool StreamPipeliner::createAsyncCopy(tt::LoadOp loadOp, Value alloc,
333343
for (auto alloc : allocsToErase)
334344
alloc.erase();
335345

336-
auto [stage, cluster] = schedule[loadOp];
337-
338-
auto newLoadOp = builder.create<ttg::AsyncCopyGlobalToLocalOp>(
346+
auto copyOp = builder.create<ttg::AsyncCopyGlobalToLocalOp>(
339347
loadOp.getLoc(), src, viewLoad, loadOp.getMask(), loadOp.getOther(),
340348
loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile());
341-
schedule.erase(loadOp);
342-
schedule.insert(newLoadOp, stage, cluster);
343349

344350
// Insert synchronization primitives to create barriers during lowering
345-
auto commit =
346-
builder.create<ttg::AsyncCommitGroupOp>(loc, newLoadOp->getResult(0));
347-
ttg::AsyncWaitOp wait =
348-
builder.create<ttg::AsyncWaitOp>(loc, commit->getResult(0), 0);
349-
// We need to place the prefetches (AsyncCopy) after the AsyncWaits which
350-
// create a barrier to ensure all warps are finished reading the shared buffer
351-
// we will write into. This is done by scheduling it as a local_store.
352-
scheduleOp(newLoadOp, SCHED_LOCAL_STORE);
353-
// Place ttg.async_commit_group op next to async load so the later
354-
// UpdateAsyncWaitCount pass can deduce better waitcnts
355-
scheduleOp(commit, SCHED_LOCAL_STORE);
351+
auto commitOp =
352+
builder.create<ttg::AsyncCommitGroupOp>(loc, copyOp->getResult(0));
353+
354+
ttg::AsyncWaitOp waitOp =
355+
builder.create<ttg::AsyncWaitOp>(loc, commitOp->getResult(0), 0);
356356

357357
// Create local load which consumes the async token from the AsyncWait
358358
auto sharedLoad =
359-
builder.create<ttg::LocalLoadOp>(loc, loadOp.getType(), viewLoad, wait);
359+
builder.create<ttg::LocalLoadOp>(loc, loadOp.getType(), viewLoad, waitOp);
360+
361+
auto [loadStage, loadCluster] = schedule[loadOp];
362+
schedule.erase(loadOp);
363+
// Schedule new ops
364+
schedule.insert(copyOp, loadStage, loadCluster);
365+
// Place ttg.async_commit_group op following AsyncCopyGlobalToLocal so the
366+
// later UpdateAsyncWaitCount pass can deduce better waitcnts
367+
schedule.insert(commitOp, loadStage, loadCluster);
368+
// If the LocalLoads are scheduled to a later stage than AsyncCopy we need to
369+
// place the AsyncCopy prefetches after the AsyncWaits which create a barrier
370+
// to ensure all warps are finished reading the shared buffer we will write
371+
// into. This is done by scheduling AsyncWait as the first cluster.
372+
// If AsyncCopy and LocalLoads are in the same stage we do not assign a
373+
// schedule so they are placed before the LocalLoads
374+
if (loadStage != stages[SCHED_LOCAL_LOAD])
375+
scheduleOp(waitOp, SCHED_ASYNC_WAIT);
376+
360377
if (stages[SCHED_LOCAL_LOAD] != stages[SCHED_COMPUTE])
361378
scheduleOp(sharedLoad, SCHED_LOCAL_LOAD);
362379

0 commit comments

Comments
 (0)