Skip to content

[LLVM][CostModel][AArch64] Remove magic numbers from f16 vector compares. #135795

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

paulwalker-arm
Copy link
Collaborator

The PR also extends the code to cover bfloat vector compares that are also promoted to float.

NOTE: There is a bail out for the compares that are scalarised that will be removed by #135398.

…res.

The PR also extends the code to cover bfloat vector compares that are
also promoted to float.

NOTE: There is a bail out for the compares that are scalarised that
will be removed by llvm#135398.
@llvmbot
Copy link
Member

llvmbot commented Apr 15, 2025

@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-backend-aarch64

Author: Paul Walker (paulwalker-arm)

Changes

The PR also extends the code to cover bfloat vector compares that are also promoted to float.

NOTE: There is a bail out for the compares that are scalarised that will be removed by #135398.


Full diff: https://github.com/llvm/llvm-project/pull/135795.diff

2 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+28-4)
  • (modified) llvm/test/Analysis/CostModel/AArch64/vector-select.ll (+8-8)
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 2b9d32f9208fe..f79b8277b4cd1 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -4236,10 +4236,34 @@ InstructionCost AArch64TTIImpl::getCmpSelInstrCost(
   }
 
   if (isa<FixedVectorType>(ValTy) && ISD == ISD::SETCC) {
-    auto LT = getTypeLegalizationCost(ValTy);
-    // Cost v4f16 FCmp without FP16 support via converting to v4f32 and back.
-    if (LT.second == MVT::v4f16 && !ST->hasFullFP16())
-      return LT.first * 4; // fcvtl + fcvtl + fcmp + xtn
+    Type *ValScalarTy = ValTy->getScalarType();
+    if ((ValScalarTy->isHalfTy() && !ST->hasFullFP16()) ||
+        ValScalarTy->isBFloatTy()) {
+      auto *ValVTy = cast<FixedVectorType>(ValTy);
+
+      // FIXME: We currently scalarise these.
+      if (ValVTy->getNumElements() > 4)
+        return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred,
+                                         CostKind, Op1Info, Op2Info, I);
+
+      // Without dedicated instructions we promote [b]f16 compares to f32.
+      auto *PromotedTy =
+          VectorType::get(Type::getFloatTy(ValTy->getContext()), ValVTy);
+
+      InstructionCost Cost = 0;
+      // Promte operands to float vectors.
+      Cost += 2 * getCastInstrCost(Instruction::FPExt, PromotedTy, ValTy,
+                                   TTI::CastContextHint::None, CostKind);
+      // Compare float vectors.
+      Cost += getCmpSelInstrCost(Opcode, PromotedTy, CondTy, VecPred, CostKind,
+                                 Op1Info, Op2Info);
+      // During codegen we'll truncate the vector result from i32 to i16.
+      Cost +=
+          getCastInstrCost(Instruction::Trunc, VectorType::getInteger(ValVTy),
+                           VectorType::getInteger(PromotedTy),
+                           TTI::CastContextHint::None, CostKind);
+      return Cost;
+    }
   }
 
   // Treat the icmp in icmp(and, 0) as free, as we can make use of ands.
diff --git a/llvm/test/Analysis/CostModel/AArch64/vector-select.ll b/llvm/test/Analysis/CostModel/AArch64/vector-select.ll
index c2256159a8ee2..e66f94dd54f21 100644
--- a/llvm/test/Analysis/CostModel/AArch64/vector-select.ll
+++ b/llvm/test/Analysis/CostModel/AArch64/vector-select.ll
@@ -168,7 +168,7 @@ define <2 x double> @v2f64_select_ogt(<2 x double> %a, <2 x double> %b, <2 x dou
 
 define <4 x bfloat> @v4bf16_select_ogt(<4 x bfloat> %a, <4 x bfloat> %b, <4 x bfloat> %c) {
 ; COST-LABEL: 'v4bf16_select_ogt'
-; COST-NEXT:  Cost Model: Found costs of 1 for: %cmp.1 = fcmp ogt <4 x bfloat> %a, %b
+; COST-NEXT:  Cost Model: Found costs of RThru:4 CodeSize:1 Lat:1 SizeLat:1 for: %cmp.1 = fcmp ogt <4 x bfloat> %a, %b
 ; COST-NEXT:  Cost Model: Found costs of RThru:10 CodeSize:1 Lat:1 SizeLat:1 for: %s.1 = select <4 x i1> %cmp.1, <4 x bfloat> %a, <4 x bfloat> %c
 ; COST-NEXT:  Cost Model: Found costs of RThru:0 CodeSize:1 Lat:1 SizeLat:1 for: ret <4 x bfloat> %s.1
 ;
@@ -255,7 +255,7 @@ define <2 x double> @v2f64_select_oge(<2 x double> %a, <2 x double> %b, <2 x dou
 
 define <4 x bfloat> @v4bf16_select_oge(<4 x bfloat> %a, <4 x bfloat> %b, <4 x bfloat> %c) {
 ; COST-LABEL: 'v4bf16_select_oge'
-; COST-NEXT:  Cost Model: Found costs of 1 for: %cmp.1 = fcmp oge <4 x bfloat> %a, %b
+; COST-NEXT:  Cost Model: Found costs of RThru:4 CodeSize:1 Lat:1 SizeLat:1 for: %cmp.1 = fcmp oge <4 x bfloat> %a, %b
 ; COST-NEXT:  Cost Model: Found costs of RThru:10 CodeSize:1 Lat:1 SizeLat:1 for: %s.1 = select <4 x i1> %cmp.1, <4 x bfloat> %a, <4 x bfloat> %c
 ; COST-NEXT:  Cost Model: Found costs of RThru:0 CodeSize:1 Lat:1 SizeLat:1 for: ret <4 x bfloat> %s.1
 ;
@@ -342,7 +342,7 @@ define <2 x double> @v2f64_select_olt(<2 x double> %a, <2 x double> %b, <2 x dou
 
 define <4 x bfloat> @v4bf16_select_olt(<4 x bfloat> %a, <4 x bfloat> %b, <4 x bfloat> %c) {
 ; COST-LABEL: 'v4bf16_select_olt'
-; COST-NEXT:  Cost Model: Found costs of 1 for: %cmp.1 = fcmp olt <4 x bfloat> %a, %b
+; COST-NEXT:  Cost Model: Found costs of RThru:4 CodeSize:1 Lat:1 SizeLat:1 for: %cmp.1 = fcmp olt <4 x bfloat> %a, %b
 ; COST-NEXT:  Cost Model: Found costs of RThru:10 CodeSize:1 Lat:1 SizeLat:1 for: %s.1 = select <4 x i1> %cmp.1, <4 x bfloat> %a, <4 x bfloat> %c
 ; COST-NEXT:  Cost Model: Found costs of RThru:0 CodeSize:1 Lat:1 SizeLat:1 for: ret <4 x bfloat> %s.1
 ;
@@ -429,7 +429,7 @@ define <2 x double> @v2f64_select_ole(<2 x double> %a, <2 x double> %b, <2 x dou
 
 define <4 x bfloat> @v4bf16_select_ole(<4 x bfloat> %a, <4 x bfloat> %b, <4 x bfloat> %c) {
 ; COST-LABEL: 'v4bf16_select_ole'
-; COST-NEXT:  Cost Model: Found costs of 1 for: %cmp.1 = fcmp ole <4 x bfloat> %a, %b
+; COST-NEXT:  Cost Model: Found costs of RThru:4 CodeSize:1 Lat:1 SizeLat:1 for: %cmp.1 = fcmp ole <4 x bfloat> %a, %b
 ; COST-NEXT:  Cost Model: Found costs of RThru:10 CodeSize:1 Lat:1 SizeLat:1 for: %s.1 = select <4 x i1> %cmp.1, <4 x bfloat> %a, <4 x bfloat> %c
 ; COST-NEXT:  Cost Model: Found costs of RThru:0 CodeSize:1 Lat:1 SizeLat:1 for: ret <4 x bfloat> %s.1
 ;
@@ -516,7 +516,7 @@ define <2 x double> @v2f64_select_oeq(<2 x double> %a, <2 x double> %b, <2 x dou
 
 define <4 x bfloat> @v4bf16_select_oeq(<4 x bfloat> %a, <4 x bfloat> %b, <4 x bfloat> %c) {
 ; COST-LABEL: 'v4bf16_select_oeq'
-; COST-NEXT:  Cost Model: Found costs of 1 for: %cmp.1 = fcmp oeq <4 x bfloat> %a, %b
+; COST-NEXT:  Cost Model: Found costs of RThru:4 CodeSize:1 Lat:1 SizeLat:1 for: %cmp.1 = fcmp oeq <4 x bfloat> %a, %b
 ; COST-NEXT:  Cost Model: Found costs of RThru:10 CodeSize:1 Lat:1 SizeLat:1 for: %s.1 = select <4 x i1> %cmp.1, <4 x bfloat> %a, <4 x bfloat> %c
 ; COST-NEXT:  Cost Model: Found costs of RThru:0 CodeSize:1 Lat:1 SizeLat:1 for: ret <4 x bfloat> %s.1
 ;
@@ -603,7 +603,7 @@ define <2 x double> @v2f64_select_one(<2 x double> %a, <2 x double> %b, <2 x dou
 
 define <4 x bfloat> @v4bf16_select_one(<4 x bfloat> %a, <4 x bfloat> %b, <4 x bfloat> %c) {
 ; COST-LABEL: 'v4bf16_select_one'
-; COST-NEXT:  Cost Model: Found costs of 1 for: %cmp.1 = fcmp one <4 x bfloat> %a, %b
+; COST-NEXT:  Cost Model: Found costs of RThru:4 CodeSize:1 Lat:1 SizeLat:1 for: %cmp.1 = fcmp one <4 x bfloat> %a, %b
 ; COST-NEXT:  Cost Model: Found costs of RThru:10 CodeSize:1 Lat:1 SizeLat:1 for: %s.1 = select <4 x i1> %cmp.1, <4 x bfloat> %a, <4 x bfloat> %c
 ; COST-NEXT:  Cost Model: Found costs of RThru:0 CodeSize:1 Lat:1 SizeLat:1 for: ret <4 x bfloat> %s.1
 ;
@@ -690,7 +690,7 @@ define <2 x double> @v2f64_select_une(<2 x double> %a, <2 x double> %b, <2 x dou
 
 define <4 x bfloat> @v4bf16_select_une(<4 x bfloat> %a, <4 x bfloat> %b, <4 x bfloat> %c) {
 ; COST-LABEL: 'v4bf16_select_une'
-; COST-NEXT:  Cost Model: Found costs of 1 for: %cmp.1 = fcmp une <4 x bfloat> %a, %b
+; COST-NEXT:  Cost Model: Found costs of RThru:4 CodeSize:1 Lat:1 SizeLat:1 for: %cmp.1 = fcmp une <4 x bfloat> %a, %b
 ; COST-NEXT:  Cost Model: Found costs of RThru:10 CodeSize:1 Lat:1 SizeLat:1 for: %s.1 = select <4 x i1> %cmp.1, <4 x bfloat> %a, <4 x bfloat> %c
 ; COST-NEXT:  Cost Model: Found costs of RThru:0 CodeSize:1 Lat:1 SizeLat:1 for: ret <4 x bfloat> %s.1
 ;
@@ -777,7 +777,7 @@ define <2 x double> @v2f64_select_ord(<2 x double> %a, <2 x double> %b, <2 x dou
 
 define <4 x bfloat> @v4bf16_select_ord(<4 x bfloat> %a, <4 x bfloat> %b, <4 x bfloat> %c) {
 ; COST-LABEL: 'v4bf16_select_ord'
-; COST-NEXT:  Cost Model: Found costs of 1 for: %cmp.1 = fcmp ord <4 x bfloat> %a, %b
+; COST-NEXT:  Cost Model: Found costs of RThru:4 CodeSize:1 Lat:1 SizeLat:1 for: %cmp.1 = fcmp ord <4 x bfloat> %a, %b
 ; COST-NEXT:  Cost Model: Found costs of RThru:10 CodeSize:1 Lat:1 SizeLat:1 for: %s.1 = select <4 x i1> %cmp.1, <4 x bfloat> %a, <4 x bfloat> %c
 ; COST-NEXT:  Cost Model: Found costs of RThru:0 CodeSize:1 Lat:1 SizeLat:1 for: ret <4 x bfloat> %s.1
 ;

Copy link
Collaborator

@davemgreen davemgreen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks - LGTM

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants