Skip to content

Commit dbab3f7

Browse files
authored
[VPlan] Move IV predicate handling to VPlan. (llvm#192876)
1 parent ad6366d commit dbab3f7

14 files changed

Lines changed: 490 additions & 314 deletions

llvm/include/llvm/Analysis/IVDescriptors.h

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class Loop;
2828
class PredicatedScalarEvolution;
2929
class ScalarEvolution;
3030
class SCEV;
31+
class SCEVPredicate;
3132
class StoreInst;
3233

3334
/// These are the kinds of recurrences that we support.
@@ -405,9 +406,13 @@ class InductionDescriptor {
405406
/// analysis, it can be passed through \p Expr. If the def-use chain
406407
/// associated with the phi includes casts (that we know we can ignore
407408
/// under proper runtime checks), they are passed through \p CastsToIgnore.
409+
/// SCEV predicates checking potential overflow for \p Phi to be an induction,
410+
/// if any, are passed via \p NoWrapPreds and recorded.
408411
LLVM_ABI static bool
409412
isInductionPHI(PHINode *Phi, const Loop *L, ScalarEvolution *SE,
410-
InductionDescriptor &D, const SCEV *Expr = nullptr,
413+
InductionDescriptor &D,
414+
ArrayRef<const SCEVPredicate *> NoWrapPreds = {},
415+
const SCEV *Expr = nullptr,
411416
SmallVectorImpl<Instruction *> *CastsToIgnore = nullptr);
412417

413418
/// Returns true if \p Phi is a floating point induction in the loop \p L.
@@ -449,12 +454,18 @@ class InductionDescriptor {
449454
/// SCEV overflow check.
450455
ArrayRef<Instruction *> getCastInsts() const { return RedundantCasts; }
451456

457+
/// Returns the SCEV predicates associated with this induction.
458+
ArrayRef<const SCEVPredicate *> getNoWrapPredicates() const {
459+
return NoWrapPredicates;
460+
}
461+
452462
private:
453463
/// Private constructor - used by \c isInductionPHI and
454464
/// \c getCanonicalIntInduction.
455465
InductionDescriptor(Value *Start, InductionKind K, const SCEV *Step,
456466
BinaryOperator *InductionBinOp = nullptr,
457-
SmallVectorImpl<Instruction *> *Casts = nullptr);
467+
SmallVectorImpl<Instruction *> *Casts = nullptr,
468+
ArrayRef<const SCEVPredicate *> NoWrapPreds = {});
458469

459470
/// Start value.
460471
TrackingVH<Value> StartValue;
@@ -467,6 +478,8 @@ class InductionDescriptor {
467478
// Instructions used for type-casts of the induction variable,
468479
// that are redundant when guarded with a runtime SCEV overflow check.
469480
SmallVector<Instruction *, 2> RedundantCasts;
481+
// SCEV predicates checking overflow needed for this induction.
482+
SmallVector<const SCEVPredicate *, 2> NoWrapPredicates;
470483
};
471484

472485
} // end namespace llvm

llvm/include/llvm/Analysis/ScalarEvolution.h

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2657,8 +2657,11 @@ class PredicatedScalarEvolution {
26572657
/// Attempts to produce an AddRecExpr for V by adding additional SCEV
26582658
/// predicates. If we can't transform the expression into an AddRecExpr we
26592659
/// return nullptr and not add additional SCEV predicates to the current
2660-
/// context.
2661-
LLVM_ABI const SCEVAddRecExpr *getAsAddRec(Value *V);
2660+
/// context. If \p WrapPredsAdded is non-null, the required predicates are
2661+
/// collected there instead of being added to this context.
2662+
LLVM_ABI const SCEVAddRecExpr *
2663+
getAsAddRec(Value *V,
2664+
SmallVectorImpl<const SCEVPredicate *> *WrapPredsAdded = nullptr);
26622665

26632666
/// Proves that V doesn't overflow by adding SCEV predicate.
26642667
LLVM_ABI void setNoOverflow(Value *V,
@@ -2680,9 +2683,10 @@ class PredicatedScalarEvolution {
26802683
LLVM_ABI void print(raw_ostream &OS, unsigned Depth) const;
26812684

26822685
/// Check if \p AR1 and \p AR2 are equal, while taking into account
2683-
/// Equal predicates in Preds.
2684-
LLVM_ABI bool areAddRecsEqualWithPreds(const SCEVAddRecExpr *AR1,
2685-
const SCEVAddRecExpr *AR2) const;
2686+
/// Equal predicates in Preds and \p ExtraPreds.
2687+
LLVM_ABI bool areAddRecsEqualWithPreds(
2688+
const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2,
2689+
ArrayRef<const SCEVPredicate *> ExtraPreds = {}) const;
26862690

26872691
private:
26882692
/// Increments the version number of the predicate. This needs to be called

llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -659,8 +659,7 @@ class LoopVectorizationLegality {
659659
/// Updates the vectorization state by adding \p Phi to the inductions list.
660660
/// This can set \p Phi as the main induction of the loop if \p Phi is a
661661
/// better choice for the main induction than the existing one.
662-
void addInductionPhi(PHINode *Phi, const InductionDescriptor &ID,
663-
SmallPtrSetImpl<Value *> &AllowedExit);
662+
void addInductionPhi(PHINode *Phi, const InductionDescriptor &ID);
664663

665664
/// The loop that we evaluate.
666665
Loop *TheLoop;
@@ -718,10 +717,6 @@ class LoopVectorizationLegality {
718717
/// Holds the widest induction type encountered.
719718
IntegerType *WidestIndTy = nullptr;
720719

721-
/// Allowed outside users. This holds the variables that can be accessed from
722-
/// outside the loop.
723-
SmallPtrSet<Value *, 4> AllowedExit;
724-
725720
/// Vectorization requirements that will go through late-evaluation.
726721
LoopVectorizationRequirements *Requirements;
727722

llvm/lib/Analysis/IVDescriptors.cpp

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1381,9 +1381,10 @@ RecurrenceDescriptor::getReductionOpChain(PHINode *Phi, Loop *L) const {
13811381
return ReductionOperations;
13821382
}
13831383

1384-
InductionDescriptor::InductionDescriptor(Value *Start, InductionKind K,
1385-
const SCEV *Step, BinaryOperator *BOp,
1386-
SmallVectorImpl<Instruction *> *Casts)
1384+
InductionDescriptor::InductionDescriptor(
1385+
Value *Start, InductionKind K, const SCEV *Step, BinaryOperator *BOp,
1386+
SmallVectorImpl<Instruction *> *Casts,
1387+
ArrayRef<const SCEVPredicate *> NoWrapPreds)
13871388
: StartValue(Start), IK(K), Step(Step), InductionBinOp(BOp) {
13881389
assert(IK != IK_NoInduction && "Not an induction");
13891390

@@ -1412,6 +1413,7 @@ InductionDescriptor::InductionDescriptor(Value *Start, InductionKind K,
14121413

14131414
if (Casts)
14141415
llvm::append_range(RedundantCasts, *Casts);
1416+
llvm::append_range(NoWrapPredicates, NoWrapPreds);
14151417
}
14161418

14171419
InductionDescriptor
@@ -1511,15 +1513,22 @@ bool InductionDescriptor::isFPInductionPHI(PHINode *Phi, const Loop *TheLoop,
15111513
/// If we are able to find such sequence, we return the instructions
15121514
/// we found, namely %casted_phi and the instructions on its use-def chain up
15131515
/// to the phi (not including the phi).
1514-
static bool getCastsForInductionPHI(PredicatedScalarEvolution &PSE,
1515-
const SCEVUnknown *PhiScev,
1516-
const SCEVAddRecExpr *AR,
1517-
SmallVectorImpl<Instruction *> &CastInsts) {
1516+
static bool
1517+
getCastsForInductionPHI(PredicatedScalarEvolution &PSE,
1518+
const SCEVUnknown *PhiScev, const SCEVAddRecExpr *AR,
1519+
SmallVectorImpl<Instruction *> &CastInsts,
1520+
ArrayRef<const SCEVPredicate *> NoWrapPreds) {
15181521

15191522
assert(CastInsts.empty() && "CastInsts is expected to be empty.");
15201523
auto *PN = cast<PHINode>(PhiScev->getValue());
1521-
assert(PSE.getSCEV(PN) == AR && "Unexpected phi node SCEV expression");
1524+
1525+
// Build a predicate to rewrite SCEVs of values in the cast chain using the
1526+
// predicates needed for this induction.
1527+
ScalarEvolution &SE = *PSE.getSE();
1528+
SCEVUnionPredicate NoWrapUnionPred(NoWrapPreds, SE);
15221529
const Loop *L = AR->getLoop();
1530+
assert(SE.rewriteUsingPredicate(SE.getSCEV(PN), L, NoWrapUnionPred) == AR &&
1531+
"Unexpected phi node SCEV expression");
15231532

15241533
// Find any cast instructions that participate in the def-use chain of
15251534
// PhiScev in the loop.
@@ -1564,8 +1573,10 @@ static bool getCastsForInductionPHI(PredicatedScalarEvolution &PSE,
15641573
if (!Inst || !L->contains(Inst)) {
15651574
return false;
15661575
}
1567-
auto *AddRec = dyn_cast<SCEVAddRecExpr>(PSE.getSCEV(Val));
1568-
if (AddRec && PSE.areAddRecsEqualWithPreds(AddRec, AR))
1576+
// Create AddRec with NoWrapPredicates applied.
1577+
auto *AddRec = dyn_cast<SCEVAddRecExpr>(
1578+
SE.rewriteUsingPredicate(SE.getSCEV(Val), L, NoWrapUnionPred));
1579+
if (AddRec && PSE.areAddRecsEqualWithPreds(AddRec, AR, NoWrapPreds))
15691580
InCastSequence = true;
15701581
if (InCastSequence) {
15711582
// Only the last instruction in the cast sequence is expected to have
@@ -1603,9 +1614,12 @@ bool InductionDescriptor::isInductionPHI(PHINode *Phi, const Loop *TheLoop,
16031614
const SCEV *PhiScev = PSE.getSCEV(Phi);
16041615
const auto *AR = dyn_cast<SCEVAddRecExpr>(PhiScev);
16051616

1617+
// Collect predicates needed to force the SCEV into an AddRecExpr.
1618+
SmallVector<const SCEVPredicate *, 2> Preds;
1619+
16061620
// We need this expression to be an AddRecExpr.
16071621
if (Assume && !AR)
1608-
AR = PSE.getAsAddRec(Phi);
1622+
AR = PSE.getAsAddRec(Phi, &Preds);
16091623

16101624
if (!AR) {
16111625
LLVM_DEBUG(dbgs() << "LV: PHI is not a poly recurrence.\n");
@@ -1621,17 +1635,17 @@ bool InductionDescriptor::isInductionPHI(PHINode *Phi, const Loop *TheLoop,
16211635
// induction.
16221636
if (PhiScev != AR && SymbolicPhi) {
16231637
SmallVector<Instruction *, 2> Casts;
1624-
if (getCastsForInductionPHI(PSE, SymbolicPhi, AR, Casts))
1625-
return isInductionPHI(Phi, TheLoop, PSE.getSE(), D, AR, &Casts);
1638+
if (getCastsForInductionPHI(PSE, SymbolicPhi, AR, Casts, Preds))
1639+
return isInductionPHI(Phi, TheLoop, PSE.getSE(), D, Preds, AR, &Casts);
16261640
}
16271641

1628-
return isInductionPHI(Phi, TheLoop, PSE.getSE(), D, AR);
1642+
return isInductionPHI(Phi, TheLoop, PSE.getSE(), D, Preds, AR);
16291643
}
16301644

16311645
bool InductionDescriptor::isInductionPHI(
16321646
PHINode *Phi, const Loop *TheLoop, ScalarEvolution *SE,
1633-
InductionDescriptor &D, const SCEV *Expr,
1634-
SmallVectorImpl<Instruction *> *CastsToIgnore) {
1647+
InductionDescriptor &D, ArrayRef<const SCEVPredicate *> Preds,
1648+
const SCEV *Expr, SmallVectorImpl<Instruction *> *CastsToIgnore) {
16351649
Type *PhiTy = Phi->getType();
16361650
// isSCEVable returns true for integer and pointer types.
16371651
if (!SE->isSCEVable(PhiTy))
@@ -1671,13 +1685,14 @@ bool InductionDescriptor::isInductionPHI(
16711685
BinaryOperator *BOp =
16721686
dyn_cast<BinaryOperator>(Phi->getIncomingValueForBlock(Latch));
16731687
D = InductionDescriptor(StartValue, IK_IntInduction, Step, BOp,
1674-
CastsToIgnore);
1688+
CastsToIgnore, Preds);
16751689
return true;
16761690
}
16771691

16781692
assert(PhiTy->isPointerTy() && "The PHI must be a pointer");
16791693

16801694
// This allows induction variables w/non-constant steps.
1681-
D = InductionDescriptor(StartValue, IK_PtrInduction, Step);
1695+
D = InductionDescriptor(StartValue, IK_PtrInduction, Step,
1696+
/*InductionBinOp=*/nullptr, /*Casts=*/nullptr, Preds);
16821697
return true;
16831698
}

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5884,14 +5884,17 @@ ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {
58845884
// even when the following Equal predicate exists:
58855885
// "%step == (sext ix (trunc iy to ix) to iy)".
58865886
bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
5887-
const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
5887+
const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2,
5888+
ArrayRef<const SCEVPredicate *> NoWrapPreds) const {
58885889
if (AR1 == AR2)
58895890
return true;
58905891

5892+
SCEVUnionPredicate NoWrapUnionPred(NoWrapPreds, SE);
5893+
SCEVUnionPredicate AllPreds = Preds->getUnionWith(&NoWrapUnionPred, SE);
58915894
auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
58925895
if (Expr1 != Expr2 &&
5893-
!Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
5894-
!Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
5896+
!AllPreds.implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
5897+
!AllPreds.implies(SE.getEqualPredicate(Expr2, Expr1), SE))
58955898
return false;
58965899
return true;
58975900
};
@@ -15625,14 +15628,20 @@ bool PredicatedScalarEvolution::hasNoOverflow(
1562515628
return Flags == SCEVWrapPredicate::IncrementAnyWrap;
1562615629
}
1562715630

15628-
const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
15631+
const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(
15632+
Value *V, SmallVectorImpl<const SCEVPredicate *> *ExtraPreds) {
1562915633
const SCEV *Expr = this->getSCEV(V);
1563015634
SmallVector<const SCEVPredicate *, 4> NewPreds;
1563115635
auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
1563215636

1563315637
if (!New)
1563415638
return nullptr;
1563515639

15640+
if (ExtraPreds) {
15641+
ExtraPreds->append(NewPreds);
15642+
return New;
15643+
}
15644+
1563615645
for (const auto *P : NewPreds)
1563715646
addPredicate(*P);
1563815647

0 commit comments

Comments
 (0)