Skip to content

Commit 8663b87

Browse files
authored
[NFC][VectorUtils][TargetTransformInfo] Add isVectorIntrinsicWithOverloadTypeAtArg api (llvm#114849)
This changes allows target intrinsics to specify and overwrite overloaded types. - Updates `ReplaceWithVecLib` to not provide TTI as there most probably won't be a use-case - Updates `SLPVectorizer` to use available TTI - Updates `VPTransformState` to pass down TTI - Updates `VPlanRecipe` to use passed-down TTI This change will let us add scalarization for `asdouble`: llvm#114847
1 parent f7497b1 commit 8663b87

File tree

15 files changed

+69
-18
lines changed

15 files changed

+69
-18
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -901,6 +901,12 @@ class TargetTransformInfo {
901901
bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
902902
unsigned ScalarOpdIdx) const;
903903

904+
/// Identifies if the vector form of the intrinsic is overloaded on the type
905+
/// of the operand at index \p OpdIdx, or on the return type if \p OpdIdx is
906+
/// -1.
907+
bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
908+
int ScalarOpdIdx) const;
909+
904910
/// Estimate the overhead of scalarizing an instruction. Insert and Extract
905911
/// are set if the demanded result elements need to be inserted and/or
906912
/// extracted from vectors.
@@ -1993,6 +1999,8 @@ class TargetTransformInfo::Concept {
19931999
virtual bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) = 0;
19942000
virtual bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
19952001
unsigned ScalarOpdIdx) = 0;
2002+
virtual bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
2003+
int ScalarOpdIdx) = 0;
19962004
virtual InstructionCost getScalarizationOverhead(VectorType *Ty,
19972005
const APInt &DemandedElts,
19982006
bool Insert, bool Extract,
@@ -2569,6 +2577,11 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
25692577
return Impl.isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
25702578
}
25712579

2580+
bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
2581+
int ScalarOpdIdx) override {
2582+
return Impl.isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx);
2583+
}
2584+
25722585
InstructionCost getScalarizationOverhead(VectorType *Ty,
25732586
const APInt &DemandedElts,
25742587
bool Insert, bool Extract,

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,11 @@ class TargetTransformInfoImplBase {
396396
return false;
397397
}
398398

399+
bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
400+
int ScalarOpdIdx) const {
401+
return ScalarOpdIdx == -1;
402+
}
403+
399404
InstructionCost getScalarizationOverhead(VectorType *Ty,
400405
const APInt &DemandedElts,
401406
bool Insert, bool Extract,

llvm/include/llvm/Analysis/VectorUtils.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,10 @@ bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
152152

153153
/// Identifies if the vector form of the intrinsic is overloaded on the type of
154154
/// the operand at index \p OpdIdx, or on the return type if \p OpdIdx is -1.
155-
bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int OpdIdx);
155+
/// \p TTI is used to consider target specific intrinsics, if no target specific
156+
/// intrinsics will be considered then it is appropriate to pass in nullptr.
157+
bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int OpdIdx,
158+
const TargetTransformInfo *TTI);
156159

157160
/// Identifies if the vector form of the intrinsic that returns a struct is
158161
/// overloaded at the struct element index \p RetIdx.

llvm/include/llvm/CodeGen/BasicTTIImpl.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -801,6 +801,11 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
801801
return false;
802802
}
803803

804+
bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
805+
int ScalarOpdIdx) const {
806+
return ScalarOpdIdx == -1;
807+
}
808+
804809
/// Helper wrapper for the DemandedElts variant of getScalarizationOverhead.
805810
InstructionCost getScalarizationOverhead(VectorType *InTy, bool Insert,
806811
bool Extract,

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,11 @@ bool TargetTransformInfo::isTargetIntrinsicWithScalarOpAtArg(
615615
return TTIImpl->isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
616616
}
617617

618+
bool TargetTransformInfo::isVectorIntrinsicWithOverloadTypeAtArg(
619+
Intrinsic::ID ID, int ScalarOpdIdx) const {
620+
return TTIImpl->isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx);
621+
}
622+
618623
InstructionCost TargetTransformInfo::getScalarizationOverhead(
619624
VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
620625
TTI::TargetCostKind CostKind) const {

llvm/lib/Analysis/VectorUtils.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,10 +133,13 @@ bool llvm::isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
133133
}
134134
}
135135

136-
bool llvm::isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
137-
int OpdIdx) {
136+
bool llvm::isVectorIntrinsicWithOverloadTypeAtArg(
137+
Intrinsic::ID ID, int OpdIdx, const TargetTransformInfo *TTI) {
138138
assert(ID != Intrinsic::not_intrinsic && "Not an intrinsic!");
139139

140+
if (TTI && Intrinsic::isTargetIntrinsic(ID))
141+
return TTI->isVectorIntrinsicWithOverloadTypeAtArg(ID, OpdIdx);
142+
140143
switch (ID) {
141144
case Intrinsic::fptosi_sat:
142145
case Intrinsic::fptoui_sat:

llvm/lib/CodeGen/ReplaceWithVeclib.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,15 +110,17 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
110110

111111
// OloadTys collects types used in scalar intrinsic overload name.
112112
SmallVector<Type *, 3> OloadTys;
113-
if (!RetTy->isVoidTy() && isVectorIntrinsicWithOverloadTypeAtArg(IID, -1))
113+
if (!RetTy->isVoidTy() &&
114+
isVectorIntrinsicWithOverloadTypeAtArg(IID, -1, /*TTI=*/nullptr))
114115
OloadTys.push_back(ScalarRetTy);
115116

116117
// Compute the argument types of the corresponding scalar call and check that
117118
// all vector operands match the previously found EC.
118119
SmallVector<Type *, 8> ScalarArgTypes;
119120
for (auto Arg : enumerate(II->args())) {
120121
auto *ArgTy = Arg.value()->getType();
121-
bool IsOloadTy = isVectorIntrinsicWithOverloadTypeAtArg(IID, Arg.index());
122+
bool IsOloadTy = isVectorIntrinsicWithOverloadTypeAtArg(IID, Arg.index(),
123+
/*TTI=*/nullptr);
122124
if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
123125
ScalarArgTypes.push_back(ArgTy);
124126
if (IsOloadTy)

llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,14 @@ bool DirectXTTIImpl::isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
2525
}
2626
}
2727

28+
bool DirectXTTIImpl::isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
29+
int ScalarOpdIdx) {
30+
switch (ID) {
31+
default:
32+
return ScalarOpdIdx == -1;
33+
}
34+
}
35+
2836
bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
2937
Intrinsic::ID ID) const {
3038
switch (ID) {

llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ class DirectXTTIImpl : public BasicTTIImplBase<DirectXTTIImpl> {
3737
bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) const;
3838
bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
3939
unsigned ScalarOpdIdx);
40+
bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
41+
int ScalarOpdIdx);
4042
};
4143
} // namespace llvm
4244

llvm/lib/Transforms/Scalar/Scalarizer.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -727,7 +727,7 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
727727

728728
SmallVector<llvm::Type *, 3> Tys;
729729
// Add return type if intrinsic is overloaded on it.
730-
if (isVectorIntrinsicWithOverloadTypeAtArg(ID, -1))
730+
if (isVectorIntrinsicWithOverloadTypeAtArg(ID, -1, TTI))
731731
Tys.push_back(VS->SplitTy);
732732

733733
if (AreAllVectorsOfMatchingSize) {
@@ -767,13 +767,13 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
767767
}
768768

769769
Scattered[I] = scatter(&CI, OpI, *OpVS);
770-
if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I)) {
770+
if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I, TTI)) {
771771
OverloadIdx[I] = Tys.size();
772772
Tys.push_back(OpVS->SplitTy);
773773
}
774774
} else {
775775
ScalarOperands[I] = OpI;
776-
if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I))
776+
if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I, TTI))
777777
Tys.push_back(OpI->getType());
778778
}
779779
}

0 commit comments

Comments
 (0)