@@ -328,6 +328,22 @@ LogicalObjectFifoFromMemrefOp CircularDmaCpyNdOp::getTargetObjectFifo() {
328
328
return dyn_cast<LogicalObjectFifoFromMemrefOp>(getTarget ().getDefiningOp ());
329
329
};
330
330
331
+ // ===----------------------------------------------------------------------===//
332
+ // AMDAIE_LogicalObjectFifoAccessOp
333
+ // ===----------------------------------------------------------------------===//
334
+
335
+ void LogicalObjectFifoAccessOp::build (OpBuilder &b,
336
+ mlir::OperationState &result, Value input,
337
+ MemoryAccess accessType) {
338
+ auto type = llvm::cast<LogicalObjectFifoType>(input.getType ());
339
+ build (b, result, type.getElementType (), input, accessType);
340
+ }
341
+
342
+ LogicalObjectFifoFromMemrefOp
343
+ LogicalObjectFifoAccessOp::getLogicalObjectFifo () {
344
+ return dyn_cast<LogicalObjectFifoFromMemrefOp>(getInput ().getDefiningOp ());
345
+ };
346
+
331
347
// ===----------------------------------------------------------------------===//
332
348
// AMDAIE_LogicalObjectFifoAcquire
333
349
// ===----------------------------------------------------------------------===//
@@ -341,6 +357,26 @@ void LogicalObjectFifoAcquire::build(OpBuilder &b, mlir::OperationState &result,
341
357
// AMDAIE_LogicalObjectFifoFromMemrefOp
342
358
// ===----------------------------------------------------------------------===//
343
359
360
+ // / Build with an array of static tile locations.
361
+ void LogicalObjectFifoFromMemrefOp::build (
362
+ OpBuilder &b, mlir::OperationState &result, Value memref,
363
+ ArrayRef<std::pair<int64_t , int64_t >> tileLocations) {
364
+ SmallVector<Value> tiles;
365
+ tiles.reserve (tileLocations.size ());
366
+ for (auto [column, row] : tileLocations) {
367
+ auto colIndex = b.create <arith::ConstantIndexOp>(b.getUnknownLoc (), column);
368
+ auto rowIndex = b.create <arith::ConstantIndexOp>(b.getUnknownLoc (), row);
369
+ auto tileOp =
370
+ b.create <AMDAIE::TileOp>(b.getUnknownLoc (), colIndex, rowIndex);
371
+ tiles.push_back (tileOp.getResult ());
372
+ }
373
+ // For deterministic order.
374
+ llvm::sort (tiles.begin (), tiles.end (),
375
+ TileOp::tileValueColumnAndRowComparator);
376
+ auto type = LogicalObjectFifoType::get (cast<MemRefType>(memref.getType ()));
377
+ build (b, result, type, memref, tiles);
378
+ }
379
+
344
380
LogicalResult LogicalObjectFifoFromMemrefOp::canonicalize (
345
381
LogicalObjectFifoFromMemrefOp logicalObjectFifo,
346
382
PatternRewriter &rewriter) {
@@ -349,23 +385,19 @@ LogicalResult LogicalObjectFifoFromMemrefOp::canonicalize(
349
385
return success ();
350
386
}
351
387
352
- auto comparator = [](Value a, Value b) -> bool {
353
- TileOp tileA = dyn_cast<TileOp>(a.getDefiningOp ());
354
- TileOp tileB = dyn_cast<TileOp>(b.getDefiningOp ());
355
- int64_t colA = getConstantIntValue (tileA.getCol ()).value ();
356
- int64_t rowA = getConstantIntValue (tileA.getRow ()).value ();
357
- int64_t colB = getConstantIntValue (tileB.getCol ()).value ();
358
- int64_t rowB = getConstantIntValue (tileB.getRow ()).value ();
359
- if (colA == colB) return rowA < rowB;
360
- return colA < colB;
361
- };
362
388
SmallVector<Value> tiles = logicalObjectFifo.getTiles ();
363
- if (llvm::is_sorted (tiles, comparator)) {
389
+ if (llvm::is_sorted (tiles, TileOp::tileValueColumnAndRowComparator)) {
390
+ // Still erase duplicates.
391
+ tiles.erase (std::unique (tiles.begin (), tiles.end ()), tiles.end ());
364
392
return success ();
365
393
}
366
394
367
- // If tiles are not sorted, sort them and replace the logical objectfifo
368
- llvm::sort (tiles.begin (), tiles.end (), comparator);
395
+ // If tiles are not sorted, sort them, erase duplicates and replace the
396
+ // logical objectfifo.
397
+ llvm::sort (tiles.begin (), tiles.end (),
398
+ TileOp::tileValueColumnAndRowComparator);
399
+ tiles.erase (std::unique (tiles.begin (), tiles.end ()), tiles.end ());
400
+
369
401
rewriter.replaceOpWithNewOp <AMDAIE::LogicalObjectFifoFromMemrefOp>(
370
402
logicalObjectFifo,
371
403
llvm::cast<LogicalObjectFifoType>(
@@ -532,6 +564,23 @@ bool TileOp::hasStaticLocation() {
532
564
return getConstantIntValue (getCol ()) && getConstantIntValue (getRow ());
533
565
}
534
566
567
+ bool TileOp::tileColumnComparator (AMDAIE::TileOp &a, AMDAIE::TileOp &b) {
568
+ int64_t colA = getConstantIntValue (a.getCol ()).value ();
569
+ int64_t colB = getConstantIntValue (b.getCol ()).value ();
570
+ return colA < colB;
571
+ }
572
+
573
+ bool TileOp::tileValueColumnAndRowComparator (Value a, Value b) {
574
+ TileOp tileA = dyn_cast<AMDAIE::TileOp>(a.getDefiningOp ());
575
+ TileOp tileB = dyn_cast<AMDAIE::TileOp>(b.getDefiningOp ());
576
+ int64_t colA = getConstantIntValue (tileA.getCol ()).value ();
577
+ int64_t rowA = getConstantIntValue (tileA.getRow ()).value ();
578
+ int64_t colB = getConstantIntValue (tileB.getCol ()).value ();
579
+ int64_t rowB = getConstantIntValue (tileB.getRow ()).value ();
580
+ if (colA == colB) return rowA < rowB;
581
+ return colA < colB;
582
+ };
583
+
535
584
// ===----------------------------------------------------------------------===//
536
585
// AMDAIE_WorkgroupOp
537
586
// ===----------------------------------------------------------------------===//
0 commit comments