diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h index 2efca0d1d754f..6da16c151f8bf 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfo.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h @@ -1420,8 +1420,8 @@ class TargetTransformInfo { /// \return The expected cost of a sign- or zero-extended vector extract. Use /// Index = -1 to indicate that there is no information about the index value. InstructionCost getExtractWithExtendCost(unsigned Opcode, Type *Dst, - VectorType *VecTy, - unsigned Index) const; + VectorType *VecTy, unsigned Index, + TTI::TargetCostKind CostKind) const; /// \return The expected cost of control-flow related instructions such as /// Phi, Ret, Br, Switch. @@ -2196,9 +2196,9 @@ class TargetTransformInfo::Concept { Type *Src, CastContextHint CCH, TTI::TargetCostKind CostKind, const Instruction *I) = 0; - virtual InstructionCost getExtractWithExtendCost(unsigned Opcode, Type *Dst, - VectorType *VecTy, - unsigned Index) = 0; + virtual InstructionCost + getExtractWithExtendCost(unsigned Opcode, Type *Dst, VectorType *VecTy, + unsigned Index, TTI::TargetCostKind CostKind) = 0; virtual InstructionCost getCFInstrCost(unsigned Opcode, TTI::TargetCostKind CostKind, const Instruction *I = nullptr) = 0; @@ -2919,10 +2919,11 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept { const Instruction *I) override { return Impl.getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I); } - InstructionCost getExtractWithExtendCost(unsigned Opcode, Type *Dst, - VectorType *VecTy, - unsigned Index) override { - return Impl.getExtractWithExtendCost(Opcode, Dst, VecTy, Index); + InstructionCost + getExtractWithExtendCost(unsigned Opcode, Type *Dst, VectorType *VecTy, + unsigned Index, + TTI::TargetCostKind CostKind) override { + return Impl.getExtractWithExtendCost(Opcode, Dst, VecTy, Index, CostKind); } InstructionCost getCFInstrCost(unsigned Opcode, TTI::TargetCostKind CostKind, const Instruction *I = nullptr) override { diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h index 3fe0a9101fdee..12f05babb3e05 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -696,8 +696,8 @@ class TargetTransformInfoImplBase { } InstructionCost getExtractWithExtendCost(unsigned Opcode, Type *Dst, - VectorType *VecTy, - unsigned Index) const { + VectorType *VecTy, unsigned Index, + TTI::TargetCostKind CostKind) const { return 1; } diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h index eacf75c24695f..1081e0249c6bf 100644 --- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h +++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h @@ -1333,8 +1333,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase { } InstructionCost getExtractWithExtendCost(unsigned Opcode, Type *Dst, - VectorType *VecTy, unsigned Index) { - TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; + VectorType *VecTy, unsigned Index, + TTI::TargetCostKind CostKind) { return thisT()->getVectorInstrCost(Instruction::ExtractElement, VecTy, CostKind, Index, nullptr, nullptr) + thisT()->getCastInstrCost(Opcode, Dst, VecTy->getElementType(), diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp index 4fea4e5711f5a..a8bc7066e7efb 100644 --- a/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -1049,9 +1049,10 @@ InstructionCost TargetTransformInfo::getCastInstrCost( } InstructionCost TargetTransformInfo::getExtractWithExtendCost( - unsigned Opcode, Type *Dst, VectorType *VecTy, unsigned Index) const { + unsigned Opcode, Type *Dst, VectorType *VecTy, unsigned Index, + TTI::TargetCostKind CostKind) const { InstructionCost Cost = - TTIImpl->getExtractWithExtendCost(Opcode, Dst, VecTy, Index); + TTIImpl->getExtractWithExtendCost(Opcode, Dst, VecTy, Index, CostKind); assert(Cost >= 0 && "TTI should not produce negative costs!"); return Cost; } diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index ca1a486901951..b44e4317e269e 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -3556,10 +3556,10 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I)); } -InstructionCost AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode, - Type *Dst, - VectorType *VecTy, - unsigned Index) { +InstructionCost +AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode, Type *Dst, + VectorType *VecTy, unsigned Index, + TTI::TargetCostKind CostKind) { // Make sure we were given a valid extend opcode. assert((Opcode == Instruction::SExt || Opcode == Instruction::ZExt) && @@ -3574,7 +3574,6 @@ InstructionCost AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode, // Get the cost for the extract. We compute the cost (if any) for the extend // below. - TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; InstructionCost Cost = getVectorInstrCost(Instruction::ExtractElement, VecTy, CostKind, Index, nullptr, nullptr); diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h index ae0df6b895ec8..4d1c7a1e37836 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h @@ -189,7 +189,8 @@ class AArch64TTIImpl : public BasicTTIImplBase { const Instruction *I = nullptr); InstructionCost getExtractWithExtendCost(unsigned Opcode, Type *Dst, - VectorType *VecTy, unsigned Index); + VectorType *VecTy, unsigned Index, + TTI::TargetCostKind CostKind); InstructionCost getCFInstrCost(unsigned Opcode, TTI::TargetCostKind CostKind, const Instruction *I = nullptr); diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index 50c403906daa9..a8b8d32552176 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -5451,7 +5451,7 @@ static InstructionCost getExtractWithExtendCost( TTI.getCastInstrCost(Opcode, Dst, SubTp, TTI::CastContextHint::None, CostKind); } - return TTI.getExtractWithExtendCost(Opcode, Dst, VecTy, Index); + return TTI.getExtractWithExtendCost(Opcode, Dst, VecTy, Index, CostKind); } /// Correctly creates insert_subvector, checking that the index is multiple of @@ -12045,9 +12045,9 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis { all_of(Ext->users(), IsaPred)) { // Use getExtractWithExtendCost() to calculate the cost of // extractelement/ext pair. - Cost -= - TTI.getExtractWithExtendCost(Ext->getOpcode(), Ext->getType(), - EE->getVectorOperandType(), Idx); + Cost -= TTI.getExtractWithExtendCost( + Ext->getOpcode(), Ext->getType(), EE->getVectorOperandType(), + Idx, CostKind); // Add back the cost of s|zext which is subtracted separately. Cost += TTI.getCastInstrCost( Ext->getOpcode(), Ext->getType(), EE->getType(), @@ -12668,7 +12668,8 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef VectorizedVals, // Use getExtractWithExtendCost() to calculate the cost of // extractelement/ext pair. InstructionCost Cost = TTI->getExtractWithExtendCost( - Ext->getOpcode(), Ext->getType(), SrcVecTy, *getExtractIndex(I)); + Ext->getOpcode(), Ext->getType(), SrcVecTy, *getExtractIndex(I), + CostKind); // Subtract the cost of s|zext which is subtracted separately. Cost -= TTI->getCastInstrCost( Ext->getOpcode(), Ext->getType(), I->getType(),