diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h index ab5306b7b614e..1c9648af566c5 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfo.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h @@ -1420,8 +1420,8 @@ class TargetTransformInfo { /// \return The expected cost of a sign- or zero-extended vector extract. Use /// Index = -1 to indicate that there is no information about the index value. InstructionCost getExtractWithExtendCost(unsigned Opcode, Type *Dst, - VectorType *VecTy, - unsigned Index) const; + VectorType *VecTy, unsigned Index, + TTI::TargetCostKind CostKind) const; /// \return The expected cost of control-flow related instructions such as /// Phi, Ret, Br, Switch. @@ -2210,9 +2210,10 @@ class TargetTransformInfo::Concept { Type *Src, CastContextHint CCH, TTI::TargetCostKind CostKind, const Instruction *I) const = 0; - virtual InstructionCost getExtractWithExtendCost(unsigned Opcode, Type *Dst, - VectorType *VecTy, - unsigned Index) const = 0; + virtual InstructionCost + getExtractWithExtendCost(unsigned Opcode, Type *Dst, VectorType *VecTy, + unsigned Index, + TTI::TargetCostKind CostKind) const = 0; virtual InstructionCost getCFInstrCost(unsigned Opcode, TTI::TargetCostKind CostKind, const Instruction *I = nullptr) const = 0; @@ -2947,10 +2948,11 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept { const Instruction *I) const override { return Impl.getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I); } - InstructionCost getExtractWithExtendCost(unsigned Opcode, Type *Dst, - VectorType *VecTy, - unsigned Index) const override { - return Impl.getExtractWithExtendCost(Opcode, Dst, VecTy, Index); + InstructionCost + getExtractWithExtendCost(unsigned Opcode, Type *Dst, VectorType *VecTy, + unsigned Index, + TTI::TargetCostKind CostKind) const override { + return Impl.getExtractWithExtendCost(Opcode, Dst, VecTy, Index, CostKind); } InstructionCost getCFInstrCost(unsigned Opcode, TTI::TargetCostKind CostKind, diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h index b46eb349c2249..0828eb2ad8be6 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -700,8 +700,8 @@ class TargetTransformInfoImplBase { } InstructionCost getExtractWithExtendCost(unsigned Opcode, Type *Dst, - VectorType *VecTy, - unsigned Index) const { + VectorType *VecTy, unsigned Index, + TTI::TargetCostKind CostKind) const { return 1; } diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h index db5fb2f7f1a54..0ef6bf5d45f4d 100644 --- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h +++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h @@ -1333,9 +1333,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase { } InstructionCost getExtractWithExtendCost(unsigned Opcode, Type *Dst, - VectorType *VecTy, - unsigned Index) const { - TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; + VectorType *VecTy, unsigned Index, + TTI::TargetCostKind CostKind) const { return thisT()->getVectorInstrCost(Instruction::ExtractElement, VecTy, CostKind, Index, nullptr, nullptr) + thisT()->getCastInstrCost(Opcode, Dst, VecTy->getElementType(), diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp index 3f97484fb2fa3..981087372b9dd 100644 --- a/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -1049,9 +1049,10 @@ InstructionCost TargetTransformInfo::getCastInstrCost( } InstructionCost TargetTransformInfo::getExtractWithExtendCost( - unsigned Opcode, Type *Dst, VectorType *VecTy, unsigned Index) const { + unsigned Opcode, Type *Dst, VectorType *VecTy, unsigned Index, + TTI::TargetCostKind CostKind) const { InstructionCost Cost = - TTIImpl->getExtractWithExtendCost(Opcode, Dst, VecTy, Index); + TTIImpl->getExtractWithExtendCost(Opcode, Dst, VecTy, Index, CostKind); assert(Cost >= 0 && "TTI should not produce negative costs!"); return Cost; } diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index 59291c02e6555..a20f1c104834d 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -3557,10 +3557,10 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I)); } -InstructionCost AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode, - Type *Dst, - VectorType *VecTy, - unsigned Index) const { +InstructionCost +AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode, Type *Dst, + VectorType *VecTy, unsigned Index, + TTI::TargetCostKind CostKind) const { // Make sure we were given a valid extend opcode. assert((Opcode == Instruction::SExt || Opcode == Instruction::ZExt) && @@ -3575,7 +3575,6 @@ InstructionCost AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode, // Get the cost for the extract. We compute the cost (if any) for the extend // below. - TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; InstructionCost Cost = getVectorInstrCost(Instruction::ExtractElement, VecTy, CostKind, Index, nullptr, nullptr); diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h index 7da2820bee323..782fd5bc2003c 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h @@ -192,8 +192,8 @@ class AArch64TTIImpl : public BasicTTIImplBase { const Instruction *I = nullptr) const; InstructionCost getExtractWithExtendCost(unsigned Opcode, Type *Dst, - VectorType *VecTy, - unsigned Index) const; + VectorType *VecTy, unsigned Index, + TTI::TargetCostKind CostKind) const; InstructionCost getCFInstrCost(unsigned Opcode, TTI::TargetCostKind CostKind, const Instruction *I = nullptr) const; diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index a56728226c039..0ea2212aeeefb 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -5793,7 +5793,7 @@ static InstructionCost getExtractWithExtendCost( TTI.getCastInstrCost(Opcode, Dst, SubTp, TTI::CastContextHint::None, CostKind); } - return TTI.getExtractWithExtendCost(Opcode, Dst, VecTy, Index); + return TTI.getExtractWithExtendCost(Opcode, Dst, VecTy, Index, CostKind); } /// Correctly creates insert_subvector, checking that the index is multiple of @@ -12412,9 +12412,9 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis { all_of(Ext->users(), IsaPred)) { // Use getExtractWithExtendCost() to calculate the cost of // extractelement/ext pair. - Cost -= - TTI.getExtractWithExtendCost(Ext->getOpcode(), Ext->getType(), - EE->getVectorOperandType(), Idx); + Cost -= TTI.getExtractWithExtendCost( + Ext->getOpcode(), Ext->getType(), EE->getVectorOperandType(), + Idx, CostKind); // Add back the cost of s|zext which is subtracted separately. Cost += TTI.getCastInstrCost( Ext->getOpcode(), Ext->getType(), EE->getType(), @@ -13035,7 +13035,8 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef VectorizedVals, // Use getExtractWithExtendCost() to calculate the cost of // extractelement/ext pair. InstructionCost Cost = TTI->getExtractWithExtendCost( - Ext->getOpcode(), Ext->getType(), SrcVecTy, *getExtractIndex(I)); + Ext->getOpcode(), Ext->getType(), SrcVecTy, *getExtractIndex(I), + CostKind); // Subtract the cost of s|zext which is subtracted separately. Cost -= TTI->getCastInstrCost( Ext->getOpcode(), Ext->getType(), I->getType(),