@@ -462,18 +462,29 @@ struct FatPointers {
462462
463463 friend bool operator ==(const FatPtrAttrs &lhs, const FatPtrAttrs &rhs) {
464464 return lhs.canNarrow == rhs.canNarrow &&
465- lhs.attributes == rhs.attributes &&
466- lhs.smallTensorBase == rhs.smallTensorBase ;
465+ lhs.isSmallTensor == rhs.isSmallTensor &&
466+ lhs.attributes . getArrayRef () == rhs.attributes . getArrayRef () ;
467467 }
468468
469469 friend bool operator !=(const FatPtrAttrs &lhs, const FatPtrAttrs &rhs) {
470470 return !(lhs == rhs);
471471 }
472472
473- llvm::DenseMap<StringRef, Attribute> attributes;
474- // If the fat-pointer points to somewhere in a small-tensor, keep track the
475- // base of the tensor.
476- Value smallTensorBase;
473+ static FatPtrAttrs intersect (const FatPtrAttrs &lhs,
474+ const FatPtrAttrs &rhs) {
475+ FatPtrAttrs result;
476+ result.canNarrow = lhs.canNarrow && rhs.canNarrow ;
477+ result.isSmallTensor = lhs.isSmallTensor && rhs.isSmallTensor ;
478+ for (const auto &attr : lhs.attributes ) {
479+ auto it = rhs.attributes .find (attr.first );
480+ if (it != rhs.attributes .end () && it->second == attr.second )
481+ result.attributes [attr.first ] = attr.second ;
482+ }
483+ return result;
484+ }
485+
486+ llvm::SmallMapVector<StringRef, Attribute, 2 > attributes;
487+ bool isSmallTensor = false ;
477488 bool canNarrow = false ;
478489 };
479490
@@ -563,7 +574,7 @@ Value createTensorPointer(RewriterBase &rewriter, Value basePtr, Value offset,
563574 auto addPtrOp =
564575 tt::AddPtrOp::create (rewriter, loc, basePtr.getType (), basePtr, offset);
565576 for (const auto &attribute : fatPtrAttrs.attributes )
566- addPtrOp->setAttr (attribute.getFirst () , attribute.getSecond () );
577+ addPtrOp->setAttr (attribute.first , attribute.second );
567578 return addPtrOp.getResult ();
568579 }
569580
@@ -585,7 +596,7 @@ Value createTensorPointer(RewriterBase &rewriter, Value basePtr, Value offset,
585596 tt::AddPtrOp::create (rewriter, loc, tensorPtrType, tensorPtr, offset);
586597
587598 for (const auto &attribute : fatPtrAttrs.attributes )
588- addPtrOp->setAttr (attribute.getFirst () , attribute.getSecond () );
599+ addPtrOp->setAttr (attribute.first , attribute.second );
589600 return addPtrOp.getResult ();
590601}
591602
@@ -745,7 +756,7 @@ class ConvertAddPtrOp : public PointerCanonicalizationPattern<tt::AddPtrOp> {
745756 RewriterBase::InsertionGuard guard (rewriter);
746757 rewriter.setInsertionPoint (addPtrOp);
747758
748- if (fatPtrs.at ({fatPtrBase, fatPtrOffset}).smallTensorBase )
759+ if (fatPtrs.at ({fatPtrBase, fatPtrOffset}).isSmallTensor )
749760 return rewriteSmallTensorPtr (addPtrOp, adaptor, rewriter);
750761
751762 // Query all discardable attributes that we want to preserve
@@ -861,7 +872,7 @@ class ConvertAddPtrOp : public PointerCanonicalizationPattern<tt::AddPtrOp> {
861872 const auto &oldAttr = fatPtrs.at ({fatPtrBase, fatPtrOffset});
862873
863874 LDBG (" smal-tensor addPtr: " << addPtrOp);
864- LDBG (" - with tensor-base : " << oldAttr.smallTensorBase );
875+ LDBG (" - isSmallTensor : " << oldAttr.isSmallTensor );
865876 LDBG (" - with originl offset: " << origOffset);
866877 LDBG (" - fatPtr base: " << fatPtrBase);
867878 LDBG (" - fatPtr offst: " << fatPtrOffset);
@@ -1362,17 +1373,6 @@ class ConvertArithSelectOp
13621373 // select of base and offset
13631374 ValueRange fatPtrFalse = adaptor.getFalseValue ();
13641375 ValueRange fatPtrTrue = adaptor.getTrueValue ();
1365- // Simple case of a scalar select: update the base pointer
1366- if (!isa<RankedTensorType>(selectOp.getType ())) {
1367- auto newSelectOp = arith::SelectOp::create (
1368- rewriter, selectOp.getLoc (), selectOp.getType (),
1369- selectOp.getCondition (), fatPtrTrue[0 ], selectOp.getFalseValue ());
1370- rewriter.replaceOpWithMultiple (selectOp, {{newSelectOp, fatPtrTrue[1 ]}});
1371- fatPtrs[{newSelectOp, /* fatPtrOffset*/ fatPtrTrue[1 ]}] =
1372- fatPtrs.at ({fatPtrTrue[0 ], fatPtrTrue[1 ]});
1373- return success ();
1374- }
1375-
13761376 // Rewrite to select(fatBaseT, fatBaseF) and select(fatOffsetT, fatOffsetF)
13771377 auto newBase = arith::SelectOp::create (rewriter, selectOp.getLoc (),
13781378 selectOp.getCondition (),
@@ -1381,12 +1381,10 @@ class ConvertArithSelectOp
13811381 selectOp.getCondition (),
13821382 fatPtrTrue[1 ], fatPtrFalse[1 ]);
13831383
1384- assert ((fatPtrs.at ({fatPtrTrue[0 ], fatPtrTrue[1 ]}) ==
1385- fatPtrs.at ({fatPtrFalse[0 ], fatPtrFalse[1 ]})) &&
1386- " expected can narrow to be the same for both fatPtrT and fatPtrF" );
1387-
13881384 rewriter.replaceOpWithMultiple (selectOp, {{newBase, newOffset}});
1389- fatPtrs[{newBase, newOffset}] = fatPtrs.at ({fatPtrTrue[0 ], fatPtrTrue[1 ]});
1385+ fatPtrs[{newBase, newOffset}] = FatPointers::FatPtrAttrs::intersect (
1386+ fatPtrs.at ({fatPtrTrue[0 ], fatPtrTrue[1 ]}),
1387+ fatPtrs.at ({fatPtrFalse[0 ], fatPtrFalse[1 ]}));
13901388
13911389 return success ();
13921390 }
@@ -1434,14 +1432,6 @@ class ConvertSCFIfOp : public PointerCanonicalizationPattern<scf::IfOp> {
14341432 assert (i < ifOp.thenYield ().getNumOperands () &&
14351433 i + 1 < ifOp.thenYield ().getNumOperands () &&
14361434 " expected idx to be within bounds of IfOp's results" );
1437- Value thenFatPtrBase = ifOp.thenYield ().getOperand (i);
1438- Value thenFatPtrOffset = ifOp.thenYield ().getOperand (i + 1 );
1439- Value elseFatPtrBase = ifOp.elseYield ().getOperand (i);
1440- Value elseFatPtrOffset = ifOp.elseYield ().getOperand (i + 1 );
1441- assert ((fatPtrs.at ({thenFatPtrBase, thenFatPtrOffset}) ==
1442- fatPtrs.at ({elseFatPtrBase, elseFatPtrOffset})) &&
1443- " expected then fat ptr canNarrow and else fat ptr canNarrow "
1444- " to be equal" );
14451435 }
14461436 }
14471437 }
@@ -1467,8 +1457,17 @@ class ConvertSCFIfOp : public PointerCanonicalizationPattern<scf::IfOp> {
14671457 for (int64_t idx : yieldPtrOffsets) {
14681458 Value thenFatPtrBase = newIfOp.thenYield ().getOperand (idx);
14691459 Value thenFatPtrOffset = newIfOp.thenYield ().getOperand (idx + 1 );
1470- fatPtrs[{newIfOp.getResult (idx), newIfOp.getResult (idx + 1 )}] =
1471- fatPtrs.at ({thenFatPtrBase, thenFatPtrOffset});
1460+ const auto &thenAttrs = fatPtrs.at ({thenFatPtrBase, thenFatPtrOffset});
1461+ if (withElseRegion) {
1462+ Value elseFatPtrBase = newIfOp.elseYield ().getOperand (idx);
1463+ Value elseFatPtrOffset = newIfOp.elseYield ().getOperand (idx + 1 );
1464+ const auto &elseAttrs = fatPtrs.at ({elseFatPtrBase, elseFatPtrOffset});
1465+ fatPtrs[{newIfOp.getResult (idx), newIfOp.getResult (idx + 1 )}] =
1466+ FatPointers::FatPtrAttrs::intersect (thenAttrs, elseAttrs);
1467+ } else {
1468+ fatPtrs[{newIfOp.getResult (idx), newIfOp.getResult (idx + 1 )}] =
1469+ thenAttrs;
1470+ }
14721471 }
14731472
14741473 ResultRange results = newIfOp.getResults ();
@@ -1708,7 +1707,7 @@ struct InitFuncPtrArgs : OpRewritePattern<tt::FuncOp> {
17081707 rewriter.replaceAllUsesExcept (arg, dummyCast.getResult (0 ), dummyCast);
17091708 fatPtrs[{arg, zeroOffset}].canNarrow = true ;
17101709 if (bitness != 64 )
1711- fatPtrs[{arg, zeroOffset}].smallTensorBase = arg ;
1710+ fatPtrs[{arg, zeroOffset}].isSmallTensor = true ;
17121711 }
17131712
17141713 newOp->setDiscardableAttr (kInitFuncArgsRewritten , rewriter.getUnitAttr ());
0 commit comments