@@ -101,16 +101,19 @@ namespace {
101101//
102102class StreamPipeliner {
103103 // Define categories of scheduling details per Operation types.
104- // The StreamPipeliner schedules 4 types of operations:
105- // 1. GLOBAL_LOAD: tt.load
106- // 2. LOCAL_STORE: ttg.local_store (created by the StreamPipeliner)
107- // 3. LOCAL_LOAD: ttg.local_load (created by the StreamPipeliner)
104+ // The StreamPipeliner schedules 5 types of operations:
105+ // 1. GLOBAL_LOAD: tt.load / ttg.async_copy_global_to_local
106+ // 2. LOCAL_STORE: ttg.local_store
107+ // 3. LOCAL_LOAD: ttg.local_load
108108 // 4. COMPUTE: ops that use the loaded data
109+ // 5. ASYNC_WAIT: ttg.async_wait
110+ // Note that ttg ops mentioned in the above list are created in this pass.
109111 enum SchedType {
110112 SCHED_GLOBAL_LOAD,
111113 SCHED_LOCAL_STORE,
112114 SCHED_LOCAL_LOAD,
113115 SCHED_COMPUTE,
116+ SCHED_ASYNC_WAIT,
114117 SCHED_SIZE
115118 };
116119
@@ -125,6 +128,7 @@ class StreamPipeliner {
125128 stages[SCHED_LOCAL_STORE] = _globalPrefetch;
126129 stages[SCHED_LOCAL_LOAD] = lastStage - _localPrefetch;
127130 stages[SCHED_COMPUTE] = lastStage;
131+ stages[SCHED_ASYNC_WAIT] = stages[SCHED_LOCAL_LOAD];
128132
129133 options.supportDynamicLoops = true ;
130134 options.peelEpilogue = true ;
@@ -212,7 +216,6 @@ class StreamPipeliner {
212216// WARNING: Changing the order of schedule.clusters.newAtBack() calls
213217// can cause invalid schedules to be produced.
214218LogicalResult StreamPipeliner::initSchedule (int maxIndirectionLevel) {
215-
216219 bool pairedGlobalLoadLocalStore = stages[SCHED_LOCAL_STORE] == 0 ;
217220 stages[SCHED_LOCAL_STORE] += maxIndirectionLevel;
218221
@@ -221,6 +224,7 @@ LogicalResult StreamPipeliner::initSchedule(int maxIndirectionLevel) {
221224 << " , LOCAL_STORE stage = " << stages[SCHED_LOCAL_STORE]
222225 << " , LOCAL_LOAD stage = " << stages[SCHED_LOCAL_LOAD]
223226 << " , COMPUTE stage = " << stages[SCHED_COMPUTE]
227+ << " , ASYNC_WAIT stage = " << stages[SCHED_ASYNC_WAIT]
224228 << " ; total = " << numStages);
225229
226230 if (stages[SCHED_LOCAL_STORE] >= numStages ||
@@ -241,15 +245,19 @@ LogicalResult StreamPipeliner::initSchedule(int maxIndirectionLevel) {
241245
242246 LDBG (" deduced max shared memory buffer number = " << numBuffers);
243247
248+ // We place the async wait in the first cluster because we want it to be
249+ // the first op in the main loop after pipelining.
250+ int asyncWaitCluster = 0 ;
251+
244252 // If tt.load and ttg.local_store are in the same stage
245253 // spread them apart to allow overlap with compute
246254 // else
247255 // Initiate ttg.local_store before tt.load
248- int globalLoadCluster = 0 ;
249- int localStoreCluster = 2 ;
256+ int globalLoadCluster = 1 ;
257+ int localStoreCluster = 3 ;
250258 if (!pairedGlobalLoadLocalStore) {
251- globalLoadCluster = 2 ;
252- localStoreCluster = 1 ;
259+ globalLoadCluster = 3 ;
260+ localStoreCluster = 2 ;
253261 }
254262
255263 // If ttg.local_load and ttg.local_store are in the same stage
@@ -260,33 +268,35 @@ LogicalResult StreamPipeliner::initSchedule(int maxIndirectionLevel) {
260268 // schedule ttg.local_load in the middle
261269 int localLoadCluster = globalLoadCluster;
262270 if (stages[SCHED_LOCAL_LOAD] == stages[SCHED_LOCAL_STORE]) {
263- localLoadCluster = std::max (2 , localStoreCluster + 1 );
271+ localLoadCluster = std::max (3 , localStoreCluster + 1 );
264272 } else if (numBuffers == 1 && localLoadCluster >= localStoreCluster) {
265273 // For 1 buffer, ttg.local_load must occur before ttg.local_store
266274 localLoadCluster = localStoreCluster - 1 ;
267275 }
268276
269277 // Schedule compute with ttg.local_load if paired
270278 // otherwise, schedule in the middle
271- int computeCluster = 1 ;
279+ int computeCluster = 2 ;
272280 if (stages[SCHED_LOCAL_LOAD] == stages[SCHED_COMPUTE]) {
273281 computeCluster = localLoadCluster;
274282 }
275283
276284 // Make assignments
277- std::array<tt::CoarseSchedule::Cluster, SCHED_SIZE> clusterVec = {
278- schedule. clusters . newAtBack (), schedule. clusters . newAtBack (),
279- schedule. clusters . newAtBack (), schedule.clusters .newAtBack ()} ;
285+ std::array<tt::CoarseSchedule::Cluster, SCHED_SIZE> clusterVec;
286+ std::generate (clusterVec. begin (), clusterVec. end (),
287+ [&]() { return schedule.clusters .newAtBack (); }) ;
280288
281289 clusters[SCHED_GLOBAL_LOAD] = clusterVec[globalLoadCluster];
282290 clusters[SCHED_LOCAL_STORE] = clusterVec[localStoreCluster];
283291 clusters[SCHED_LOCAL_LOAD] = clusterVec[localLoadCluster];
284292 clusters[SCHED_COMPUTE] = clusterVec[computeCluster];
293+ clusters[SCHED_ASYNC_WAIT] = clusterVec[asyncWaitCluster];
285294
286295 LDBG (" Cluster schedule:" << " GLOBAL_LOAD cluster = " << globalLoadCluster
287296 << " , LOCAL_STORE cluster = " << localStoreCluster
288297 << " , LOCAL_LOAD cluster = " << localLoadCluster
289298 << " , COMPUTE cluster = " << computeCluster
299+ << " , ASYNC_WAIT cluster = " << asyncWaitCluster
290300 << " ; total = " << SCHED_SIZE);
291301
292302 return success ();
@@ -333,30 +343,37 @@ bool StreamPipeliner::createAsyncCopy(tt::LoadOp loadOp, Value alloc,
333343 for (auto alloc : allocsToErase)
334344 alloc.erase ();
335345
336- auto [stage, cluster] = schedule[loadOp];
337-
338- auto newLoadOp = builder.create <ttg::AsyncCopyGlobalToLocalOp>(
346+ auto copyOp = builder.create <ttg::AsyncCopyGlobalToLocalOp>(
339347 loadOp.getLoc (), src, viewLoad, loadOp.getMask (), loadOp.getOther (),
340348 loadOp.getCache (), loadOp.getEvict (), loadOp.getIsVolatile ());
341- schedule.erase (loadOp);
342- schedule.insert (newLoadOp, stage, cluster);
343349
344350 // Insert synchronization primitives to create barriers during lowering
345- auto commit =
346- builder.create <ttg::AsyncCommitGroupOp>(loc, newLoadOp->getResult (0 ));
347- ttg::AsyncWaitOp wait =
348- builder.create <ttg::AsyncWaitOp>(loc, commit->getResult (0 ), 0 );
349- // We need to place the prefetches (AsyncCopy) after the AsyncWaits which
350- // create a barrier to ensure all warps are finished reading the shared buffer
351- // we will write into. This is done by scheduling it as a local_store.
352- scheduleOp (newLoadOp, SCHED_LOCAL_STORE);
353- // Place ttg.async_commit_group op next to async load so the later
354- // UpdateAsyncWaitCount pass can deduce better waitcnts
355- scheduleOp (commit, SCHED_LOCAL_STORE);
351+ auto commitOp =
352+ builder.create <ttg::AsyncCommitGroupOp>(loc, copyOp->getResult (0 ));
353+
354+ ttg::AsyncWaitOp waitOp =
355+ builder.create <ttg::AsyncWaitOp>(loc, commitOp->getResult (0 ), 0 );
356356
357357 // Create local load which consumes the async token from the AsyncWait
358358 auto sharedLoad =
359- builder.create <ttg::LocalLoadOp>(loc, loadOp.getType (), viewLoad, wait);
359+ builder.create <ttg::LocalLoadOp>(loc, loadOp.getType (), viewLoad, waitOp);
360+
361+ auto [loadStage, loadCluster] = schedule[loadOp];
362+ schedule.erase (loadOp);
363+ // Schedule new ops
364+ schedule.insert (copyOp, loadStage, loadCluster);
365+ // Place ttg.async_commit_group op following AsyncCopyGlobalToLocal so the
366+ // later UpdateAsyncWaitCount pass can deduce better waitcnts
367+ schedule.insert (commitOp, loadStage, loadCluster);
368+ // If the LocalLoads are scheduled to a later stage than AsyncCopy we need to
369+ // place the AsyncCopy prefetches after the AsyncWaits which create a barrier
370+ // to ensure all warps are finished reading the shared buffer we will write
371+ // into. This is done by scheduling AsyncWait as the first cluster.
372+ // If AsyncCopy and LocalLoads are in the same stage we do not assign a
373+ // schedule so they are placed before the LocalLoads.
374+ if (loadStage != stages[SCHED_LOCAL_LOAD])
375+ scheduleOp (waitOp, SCHED_ASYNC_WAIT);
376+
360377 if (stages[SCHED_LOCAL_LOAD] != stages[SCHED_COMPUTE])
361378 scheduleOp (sharedLoad, SCHED_LOCAL_LOAD);
362379
0 commit comments