Skip to content

Commit 8663b87

Browse files
authored
[NFC][VectorUtils][TargetTransformInfo] Add isVectorIntrinsicWithOverloadTypeAtArg api (llvm#114849)
This change allows target intrinsics to specify and overwrite overloaded types.
- Updates `ReplaceWithVecLib` to not provide TTI, as there most probably won't be a use case
- Updates `SLPVectorizer` to use the available TTI
- Updates `VPTransformState` to pass down TTI
- Updates `VPlanRecipe` to use the passed-down TTI

This change will let us add scalarization for `asdouble`: llvm#114847
1 parent f7497b1 commit 8663b87

15 files changed

+69
-18
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

+13
Original file line numberDiff line numberDiff line change
@@ -901,6 +901,12 @@ class TargetTransformInfo {
901901
bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
902902
unsigned ScalarOpdIdx) const;
903903

904+
/// Identifies if the vector form of the intrinsic is overloaded on the type
905+
/// of the operand at index \p ScalarOpdIdx, or on the return type if
906+
/// \p ScalarOpdIdx is -1.
907+
bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
908+
int ScalarOpdIdx) const;
909+
904910
/// Estimate the overhead of scalarizing an instruction. Insert and Extract
905911
/// are set if the demanded result elements need to be inserted and/or
906912
/// extracted from vectors.
@@ -1993,6 +1999,8 @@ class TargetTransformInfo::Concept {
19931999
virtual bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) = 0;
19942000
virtual bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
19952001
unsigned ScalarOpdIdx) = 0;
2002+
virtual bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
2003+
int ScalarOpdIdx) = 0;
19962004
virtual InstructionCost getScalarizationOverhead(VectorType *Ty,
19972005
const APInt &DemandedElts,
19982006
bool Insert, bool Extract,
@@ -2569,6 +2577,11 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
25692577
return Impl.isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
25702578
}
25712579

2580+
bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
2581+
int ScalarOpdIdx) override {
2582+
return Impl.isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx);
2583+
}
2584+
25722585
InstructionCost getScalarizationOverhead(VectorType *Ty,
25732586
const APInt &DemandedElts,
25742587
bool Insert, bool Extract,

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

+5
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,11 @@ class TargetTransformInfoImplBase {
396396
return false;
397397
}
398398

399+
bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
400+
int ScalarOpdIdx) const {
401+
return ScalarOpdIdx == -1;
402+
}
403+
399404
InstructionCost getScalarizationOverhead(VectorType *Ty,
400405
const APInt &DemandedElts,
401406
bool Insert, bool Extract,

llvm/include/llvm/Analysis/VectorUtils.h

+4-1
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,10 @@ bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
152152

153153
/// Identifies if the vector form of the intrinsic is overloaded on the type of
154154
/// the operand at index \p OpdIdx, or on the return type if \p OpdIdx is -1.
155-
bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int OpdIdx);
155+
/// \p TTI is used to consider target-specific intrinsics; if no target-specific
156+
/// intrinsics will be considered, it is appropriate to pass in nullptr.
157+
bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int OpdIdx,
158+
const TargetTransformInfo *TTI);
156159

157160
/// Identifies if the vector form of the intrinsic that returns a struct is
158161
/// overloaded at the struct element index \p RetIdx.

llvm/include/llvm/CodeGen/BasicTTIImpl.h

+5
Original file line numberDiff line numberDiff line change
@@ -801,6 +801,11 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
801801
return false;
802802
}
803803

804+
bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
805+
int ScalarOpdIdx) const {
806+
return ScalarOpdIdx == -1;
807+
}
808+
804809
/// Helper wrapper for the DemandedElts variant of getScalarizationOverhead.
805810
InstructionCost getScalarizationOverhead(VectorType *InTy, bool Insert,
806811
bool Extract,

llvm/lib/Analysis/TargetTransformInfo.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,11 @@ bool TargetTransformInfo::isTargetIntrinsicWithScalarOpAtArg(
615615
return TTIImpl->isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
616616
}
617617

618+
bool TargetTransformInfo::isVectorIntrinsicWithOverloadTypeAtArg(
619+
Intrinsic::ID ID, int ScalarOpdIdx) const {
620+
return TTIImpl->isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx);
621+
}
622+
618623
InstructionCost TargetTransformInfo::getScalarizationOverhead(
619624
VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
620625
TTI::TargetCostKind CostKind) const {

llvm/lib/Analysis/VectorUtils.cpp

+5-2
Original file line numberDiff line numberDiff line change
@@ -133,10 +133,13 @@ bool llvm::isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
133133
}
134134
}
135135

136-
bool llvm::isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
137-
int OpdIdx) {
136+
bool llvm::isVectorIntrinsicWithOverloadTypeAtArg(
137+
Intrinsic::ID ID, int OpdIdx, const TargetTransformInfo *TTI) {
138138
assert(ID != Intrinsic::not_intrinsic && "Not an intrinsic!");
139139

140+
if (TTI && Intrinsic::isTargetIntrinsic(ID))
141+
return TTI->isVectorIntrinsicWithOverloadTypeAtArg(ID, OpdIdx);
142+
140143
switch (ID) {
141144
case Intrinsic::fptosi_sat:
142145
case Intrinsic::fptoui_sat:

llvm/lib/CodeGen/ReplaceWithVeclib.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -110,15 +110,17 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
110110

111111
// OloadTys collects types used in scalar intrinsic overload name.
112112
SmallVector<Type *, 3> OloadTys;
113-
if (!RetTy->isVoidTy() && isVectorIntrinsicWithOverloadTypeAtArg(IID, -1))
113+
if (!RetTy->isVoidTy() &&
114+
isVectorIntrinsicWithOverloadTypeAtArg(IID, -1, /*TTI=*/nullptr))
114115
OloadTys.push_back(ScalarRetTy);
115116

116117
// Compute the argument types of the corresponding scalar call and check that
117118
// all vector operands match the previously found EC.
118119
SmallVector<Type *, 8> ScalarArgTypes;
119120
for (auto Arg : enumerate(II->args())) {
120121
auto *ArgTy = Arg.value()->getType();
121-
bool IsOloadTy = isVectorIntrinsicWithOverloadTypeAtArg(IID, Arg.index());
122+
bool IsOloadTy = isVectorIntrinsicWithOverloadTypeAtArg(IID, Arg.index(),
123+
/*TTI=*/nullptr);
122124
if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
123125
ScalarArgTypes.push_back(ArgTy);
124126
if (IsOloadTy)

llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,14 @@ bool DirectXTTIImpl::isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
2525
}
2626
}
2727

28+
bool DirectXTTIImpl::isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
29+
int ScalarOpdIdx) {
30+
switch (ID) {
31+
default:
32+
return ScalarOpdIdx == -1;
33+
}
34+
}
35+
2836
bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
2937
Intrinsic::ID ID) const {
3038
switch (ID) {

llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h

+2
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ class DirectXTTIImpl : public BasicTTIImplBase<DirectXTTIImpl> {
3737
bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) const;
3838
bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
3939
unsigned ScalarOpdIdx);
40+
bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
41+
int ScalarOpdIdx);
4042
};
4143
} // namespace llvm
4244

llvm/lib/Transforms/Scalar/Scalarizer.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -727,7 +727,7 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
727727

728728
SmallVector<llvm::Type *, 3> Tys;
729729
// Add return type if intrinsic is overloaded on it.
730-
if (isVectorIntrinsicWithOverloadTypeAtArg(ID, -1))
730+
if (isVectorIntrinsicWithOverloadTypeAtArg(ID, -1, TTI))
731731
Tys.push_back(VS->SplitTy);
732732

733733
if (AreAllVectorsOfMatchingSize) {
@@ -767,13 +767,13 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
767767
}
768768

769769
Scattered[I] = scatter(&CI, OpI, *OpVS);
770-
if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I)) {
770+
if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I, TTI)) {
771771
OverloadIdx[I] = Tys.size();
772772
Tys.push_back(OpVS->SplitTy);
773773
}
774774
} else {
775775
ScalarOperands[I] = OpI;
776-
if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I))
776+
if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I, TTI))
777777
Tys.push_back(OpI->getType());
778778
}
779779
}

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -7684,7 +7684,8 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
76847684
LLVM_DEBUG(BestVPlan.dump());
76857685

76867686
// Perform the actual loop transformation.
7687-
VPTransformState State(BestVF, BestUF, LI, DT, ILV.Builder, &ILV, &BestVPlan);
7687+
VPTransformState State(&TTI, BestVF, BestUF, LI, DT, ILV.Builder, &ILV,
7688+
&BestVPlan);
76887689

76897690
// 0. Generate SCEV-dependent code into the preheader, including TripCount,
76907691
// before making any changes to the CFG.

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -15655,7 +15655,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
1565515655
SmallVector<Value *> OpVecs;
1565615656
SmallVector<Type *, 2> TysForDecl;
1565715657
// Add return type if intrinsic is overloaded on it.
15658-
if (UseIntrinsic && isVectorIntrinsicWithOverloadTypeAtArg(ID, -1))
15658+
if (UseIntrinsic && isVectorIntrinsicWithOverloadTypeAtArg(ID, -1, TTI))
1565915659
TysForDecl.push_back(VecTy);
1566015660
auto *CEI = cast<CallInst>(VL0);
1566115661
for (unsigned I : seq<unsigned>(0, CI->arg_size())) {
@@ -15670,7 +15670,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
1567015670
It->second.first < DL->getTypeSizeInBits(CEI->getType()))
1567115671
ScalarArg = Builder.getFalse();
1567215672
OpVecs.push_back(ScalarArg);
15673-
if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I))
15673+
if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I, TTI))
1567415674
TysForDecl.push_back(ScalarArg->getType());
1567515675
continue;
1567615676
}
@@ -15692,7 +15692,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
1569215692
}
1569315693
LLVM_DEBUG(dbgs() << "SLP: OpVec[" << I << "]: " << *OpVec << "\n");
1569415694
OpVecs.push_back(OpVec);
15695-
if (UseIntrinsic && isVectorIntrinsicWithOverloadTypeAtArg(ID, I))
15695+
if (UseIntrinsic && isVectorIntrinsicWithOverloadTypeAtArg(ID, I, TTI))
1569615696
TysForDecl.push_back(OpVec->getType());
1569715697
}
1569815698

llvm/lib/Transforms/Vectorize/VPlan.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -219,10 +219,11 @@ VPBasicBlock::iterator VPBasicBlock::getFirstNonPhi() {
219219
return It;
220220
}
221221

222-
VPTransformState::VPTransformState(ElementCount VF, unsigned UF, LoopInfo *LI,
222+
VPTransformState::VPTransformState(const TargetTransformInfo *TTI,
223+
ElementCount VF, unsigned UF, LoopInfo *LI,
223224
DominatorTree *DT, IRBuilderBase &Builder,
224225
InnerLoopVectorizer *ILV, VPlan *Plan)
225-
: VF(VF), CFG(DT), LI(LI), Builder(Builder), ILV(ILV), Plan(Plan),
226+
: TTI(TTI), VF(VF), CFG(DT), LI(LI), Builder(Builder), ILV(ILV), Plan(Plan),
226227
LVer(nullptr), TypeAnalysis(Plan->getCanonicalIV()->getScalarType()) {}
227228

228229
Value *VPTransformState::get(VPValue *Def, const VPLane &Lane) {

llvm/lib/Transforms/Vectorize/VPlan.h

+4-2
Original file line numberDiff line numberDiff line change
@@ -234,9 +234,11 @@ class VPLane {
234234
/// VPTransformState holds information passed down when "executing" a VPlan,
235235
/// needed for generating the output IR.
236236
struct VPTransformState {
237-
VPTransformState(ElementCount VF, unsigned UF, LoopInfo *LI,
238-
DominatorTree *DT, IRBuilderBase &Builder,
237+
VPTransformState(const TargetTransformInfo *TTI, ElementCount VF, unsigned UF,
238+
LoopInfo *LI, DominatorTree *DT, IRBuilderBase &Builder,
239239
InnerLoopVectorizer *ILV, VPlan *Plan);
240+
/// Target Transform Info.
241+
const TargetTransformInfo *TTI;
240242

241243
/// The chosen Vectorization Factor of the loop being vectorized.
242244
ElementCount VF;

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -941,7 +941,7 @@ void VPWidenIntrinsicRecipe::execute(VPTransformState &State) {
941941

942942
SmallVector<Type *, 2> TysForDecl;
943943
// Add return type if intrinsic is overloaded on it.
944-
if (isVectorIntrinsicWithOverloadTypeAtArg(VectorIntrinsicID, -1))
944+
if (isVectorIntrinsicWithOverloadTypeAtArg(VectorIntrinsicID, -1, State.TTI))
945945
TysForDecl.push_back(VectorType::get(getResultType(), State.VF));
946946
SmallVector<Value *, 4> Args;
947947
for (const auto &I : enumerate(operands())) {
@@ -952,7 +952,8 @@ void VPWidenIntrinsicRecipe::execute(VPTransformState &State) {
952952
Arg = State.get(I.value(), VPLane(0));
953953
else
954954
Arg = State.get(I.value(), onlyFirstLaneUsed(I.value()));
955-
if (isVectorIntrinsicWithOverloadTypeAtArg(VectorIntrinsicID, I.index()))
955+
if (isVectorIntrinsicWithOverloadTypeAtArg(VectorIntrinsicID, I.index(),
956+
State.TTI))
956957
TysForDecl.push_back(Arg->getType());
957958
Args.push_back(Arg);
958959
}

0 commit comments

Comments
 (0)