[NVPTX] Add TLI hook for load slice cost and implement it #131847

Open

AlexMaclean wants to merge 2 commits into main

Conversation

AlexMaclean (Member)

Add a new getLoadSliceCost target hook which converts information from a LoadSlice::Cost into a scalar value for comparison. Override this for NVPTX to treat CrossRegisterBanksCopies as free in order to prevent harmful load splitting.

AlexMaclean self-assigned this on Mar 18, 2025
llvmbot added the backend:NVPTX and llvm:SelectionDAG (SelectionDAGISel as well) labels on Mar 18, 2025

llvmbot commented Mar 18, 2025

@llvm/pr-subscribers-llvm-selectiondag

@llvm/pr-subscribers-backend-nvptx

Author: Alex MacLean (AlexMaclean)

Changes

Add a new getLoadSliceCost target hook which converts information from a LoadSlice::Cost into a scalar value for comparison. Override this for NVPTX to treat CrossRegisterBanksCopies as free in order to prevent harmful load splitting.


Full diff: https://github.com/llvm/llvm-project/pull/131847.diff

5 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+14)
  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+8-28)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+17)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+5)
  • (added) llvm/test/CodeGen/NVPTX/load-slice.ll (+54)
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index a3fb4e9a8513b..9b144849fbfdb 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -3133,6 +3133,20 @@ class TargetLoweringBase {
     return false;
   }
 
+  virtual unsigned getLoadSliceCost(bool ForCodeSize, unsigned Loads,
+                                    unsigned CrossRegisterBanksCopies,
+                                    unsigned Truncates, unsigned ZExts,
+                                    unsigned Shifts) const {
+    // Assume cross register banks copies are as expensive as loads.
+    unsigned ExpensiveOps = Loads + CrossRegisterBanksCopies;
+
+    // Unless we are optimizing for code size, prioritize expensive operations.
+    if (!ForCodeSize)
+      ExpensiveOps = ExpensiveOps * 20;
+
+    return Truncates + ZExts + Shifts + ExpensiveOps;
+  }
+
   /// Return true if the target has a vector blend instruction.
   virtual bool hasVectorBlend() const { return false; }
 
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index a54857e1037e2..624a2b032ccae 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -19753,32 +19753,10 @@ struct LoadedSlice {
       return *this;
     }
 
-    bool operator==(const Cost &RHS) const {
-      return Loads == RHS.Loads && Truncates == RHS.Truncates &&
-             CrossRegisterBanksCopies == RHS.CrossRegisterBanksCopies &&
-             ZExts == RHS.ZExts && Shift == RHS.Shift;
+    unsigned value(const TargetLowering &TLI) const {
+      return TLI.getLoadSliceCost(ForCodeSize, Loads, CrossRegisterBanksCopies,
+                                  Truncates, ZExts, Shift);
     }
-
-    bool operator!=(const Cost &RHS) const { return !(*this == RHS); }
-
-    bool operator<(const Cost &RHS) const {
-      // Assume cross register banks copies are as expensive as loads.
-      // FIXME: Do we want some more target hooks?
-      unsigned ExpensiveOpsLHS = Loads + CrossRegisterBanksCopies;
-      unsigned ExpensiveOpsRHS = RHS.Loads + RHS.CrossRegisterBanksCopies;
-      // Unless we are optimizing for code size, consider the
-      // expensive operation first.
-      if (!ForCodeSize && ExpensiveOpsLHS != ExpensiveOpsRHS)
-        return ExpensiveOpsLHS < ExpensiveOpsRHS;
-      return (Truncates + ZExts + Shift + ExpensiveOpsLHS) <
-             (RHS.Truncates + RHS.ZExts + RHS.Shift + ExpensiveOpsRHS);
-    }
-
-    bool operator>(const Cost &RHS) const { return RHS < *this; }
-
-    bool operator<=(const Cost &RHS) const { return !(RHS < *this); }
-
-    bool operator>=(const Cost &RHS) const { return !(*this < RHS); }
   };
 
   // The last instruction that represent the slice. This should be a
@@ -20099,7 +20077,8 @@ static void adjustCostForPairing(SmallVectorImpl<LoadedSlice> &LoadedSlices,
 /// FIXME: When the cost model will be mature enough, we can relax
 /// constraints (1) and (2).
 static bool isSlicingProfitable(SmallVectorImpl<LoadedSlice> &LoadedSlices,
-                                const APInt &UsedBits, bool ForCodeSize) {
+                                const APInt &UsedBits, bool ForCodeSize,
+                                const TargetLowering &TLI) {
   unsigned NumberOfSlices = LoadedSlices.size();
   if (StressLoadSlicing)
     return NumberOfSlices > 1;
@@ -20129,7 +20108,7 @@ static bool isSlicingProfitable(SmallVectorImpl<LoadedSlice> &LoadedSlices,
 
   // If the target supports paired load, adjust the cost accordingly.
   adjustCostForPairing(LoadedSlices, GlobalSlicingCost);
-  return OrigCost > GlobalSlicingCost;
+  return OrigCost.value(TLI) > GlobalSlicingCost.value(TLI);
 }
 
 /// If the given load, \p LI, is used only by trunc or trunc(lshr)
@@ -20209,7 +20188,8 @@ bool DAGCombiner::SliceUpLoad(SDNode *N) {
   }
 
   // Abort slicing if it does not seem to be profitable.
-  if (!isSlicingProfitable(LoadedSlices, UsedBits, ForCodeSize))
+  if (!isSlicingProfitable(LoadedSlices, UsedBits, ForCodeSize,
+                           DAG.getTargetLoweringInfo()))
     return false;
 
   ++SlicedLoads;
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 18ec5c5384488..482822f9425bb 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -4735,6 +4735,23 @@ bool NVPTXTargetLowering::isLegalAddressingMode(const DataLayout &DL,
   return true;
 }
 
+unsigned NVPTXTargetLowering::getLoadSliceCost(
+    bool ForCodeSize, unsigned Loads, unsigned CrossRegisterBanksCopies,
+    unsigned Truncates, unsigned ZExts, unsigned Shifts) const {
+
+  // Loads are much more expensive than other operations, and the cost of extra
+  // load is not offset by savings from shift/mask if the usage of the load is
+  // as split elements.
+  //
+  // Base TLI treats CrossRegisterBanksCopies as expensive, but these operations
+  // can be optimized in most cases for NVPTX.
+  //
+  CrossRegisterBanksCopies = 0;
+
+  return TargetLoweringBase::getLoadSliceCost(
+      ForCodeSize, Loads, CrossRegisterBanksCopies, Truncates, ZExts, Shifts);
+}
+
 //===----------------------------------------------------------------------===//
 //                         NVPTX Inline Assembly Support
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index ff0241886223b..95c4de4d68ca5 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -143,6 +143,11 @@ class NVPTXTargetLowering : public TargetLowering {
                              unsigned AS,
                              Instruction *I = nullptr) const override;
 
+  unsigned getLoadSliceCost(bool ForCodeSize, unsigned Loads,
+                            unsigned CrossRegisterBanksCopies,
+                            unsigned Truncates, unsigned ZExts,
+                            unsigned Shifts) const override;
+
   bool isTruncateFree(Type *SrcTy, Type *DstTy) const override {
     // Truncating 64-bit to 32-bit is free in SASS.
     if (!SrcTy->isIntegerTy() || !DstTy->isIntegerTy())
diff --git a/llvm/test/CodeGen/NVPTX/load-slice.ll b/llvm/test/CodeGen/NVPTX/load-slice.ll
new file mode 100644
index 0000000000000..c34f4a27f8d36
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/load-slice.ll
@@ -0,0 +1,54 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s | FileCheck %s
+
+target triple = "nvptx64-unknown-unknown"
+
+;; Verify that 64-bit loads are not split into more 32-bit
+;; loads. Loads are more expensive than shifts/conversions.
+define float @test(ptr %in) {
+;
+; CHECK-LABEL: test(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-NEXT:    .reg .f32 %f<8>;
+; CHECK-NEXT:    .reg .b64 %rd<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u64 %rd1, [test_param_0];
+; CHECK-NEXT:    ld.u64 %rd2, [%rd1];
+; CHECK-NEXT:    ld.u64 %rd3, [%rd1+8];
+; CHECK-NEXT:    cvt.u32.u64 %r1, %rd2;
+; CHECK-NEXT:    cvt.u32.u64 %r2, %rd3;
+; CHECK-NEXT:    mov.b32 %f1, %r1;
+; CHECK-NEXT:    mov.b32 %f2, %r2;
+; CHECK-NEXT:    add.rn.f32 %f3, %f1, %f2;
+; CHECK-NEXT:    { .reg .b32 tmp; mov.b64 {tmp, %r3}, %rd2; }
+; CHECK-NEXT:    { .reg .b32 tmp; mov.b64 {tmp, %r4}, %rd3; }
+; CHECK-NEXT:    mov.b32 %f4, %r3;
+; CHECK-NEXT:    mov.b32 %f5, %r4;
+; CHECK-NEXT:    add.rn.f32 %f6, %f4, %f5;
+; CHECK-NEXT:    add.rn.f32 %f7, %f3, %f6;
+; CHECK-NEXT:    st.param.f32 [func_retval0], %f7;
+; CHECK-NEXT:    ret;
+  %ptr0 = getelementptr inbounds i64, ptr %in, i64 0
+  %ptr1 = getelementptr inbounds i64, ptr %in, i64 1
+
+  %load0 = load i64, ptr %ptr0, align 8
+  %load1 = load i64, ptr %ptr1, align 8
+  %trunc_lo_0 = trunc i64 %load0 to i32
+  %trunc_lo_1 = trunc i64 %load1 to i32
+  %float_lo_0 = bitcast i32 %trunc_lo_0 to float
+  %float_lo_1 = bitcast i32 %trunc_lo_1 to float
+  %add_lo = fadd float %float_lo_0, %float_lo_1
+
+  %shift0 = lshr i64 %load0, 32
+  %shift1 = lshr i64 %load1, 32
+  %trunc_hi_0 = trunc i64 %shift0 to i32
+  %trunc_hi_1 = trunc i64 %shift1 to i32
+  %float_hi_0 = bitcast i32 %trunc_hi_0 to float
+  %float_hi_1 = bitcast i32 %trunc_hi_1 to float
+  %add_hi = fadd float %float_hi_0, %float_hi_1
+
+  %res = fadd float %add_lo, %add_hi
+  ret float %res
+}

Comment on lines +3143 to +3145
// Unless we are optimizing for code size, prioritize expensive operations.
if (!ForCodeSize)
ExpensiveOps = ExpensiveOps * 20;
Member

Edit suggestion: "prioritize expensive operations." -> "prefer avoiding expensive ops" or "prefer less expensive ops".

Where does the magic cost scaling factor of 20 come from? I can't match it to the logic this weight function replaces.
It would be prudent to make the new defaults behave close to how they worked before the change. This is in the code path that will likely affect almost everything. We don't want any surprises.

Member Author

The 20 is quite arbitrary. In the previous implementation we compared expensive ops first and only looked at inexpensive ops when the expensive ops were equal, so expensive ops always trumped inexpensive ones. I don't think that logic really made sense, since there must be some number of inexpensive ops that is more costly than one expensive op. I chose 20 because in practice I don't expect any case to have 20 inexpensive ops, so this should give essentially the same behavior. I agree this is a bit weird, though; I'm happy to try out any alternatives you can suggest.

Member

Cost heuristics are prone to causing unexpected surprises, and I do not see a good way to validate the currently selected value without exposing it to real-world use cases. Perhaps we can make the scaling factor tunable as an escape hatch in case the current value produces suboptimal results for some users.

Comment on lines +20 to +28
; CHECK-NEXT: cvt.u32.u64 %r1, %rd2;
; CHECK-NEXT: cvt.u32.u64 %r2, %rd3;
; CHECK-NEXT: mov.b32 %f1, %r1;
; CHECK-NEXT: mov.b32 %f2, %r2;
; CHECK-NEXT: add.rn.f32 %f3, %f1, %f2;
; CHECK-NEXT: { .reg .b32 tmp; mov.b64 {tmp, %r3}, %rd2; }
; CHECK-NEXT: { .reg .b32 tmp; mov.b64 {tmp, %r4}, %rd3; }
; CHECK-NEXT: mov.b32 %f4, %r3;
; CHECK-NEXT: mov.b32 %f5, %r4;
Member

Future cleanup opportunity: we seem to be using too many instructions to do effectively nothing at the SASS level. FP operations accept b32 registers as inputs, so we can skip the moves between .b32 and .f32 registers and just do add.rn.f32 %f3, %r1, %r2;. That would also allow splitting a 64-bit value with a single mov operation.

https://godbolt.org/z/1bb5dMGsa

    mov.b64 {%r1, %r3}, %rd2;
    mov.b64 {%r2, %r4}, %rd3;
    add.rn.f32 %f3, %r1, %r2;
    add.rn.f32 %f6, %r3, %r4;
    add.rn.f32 %f7, %f3, %f6;

Or, alternatively, split .b64 -> {.f32, .f32}, which would avoid having to deal with FP ops accepting integers as inputs: https://godbolt.org/z/64d9of9d6

    mov.b64 {%f1, %f3}, %rd2;
    mov.b64 {%f2, %f4}, %rd3;
    add.rn.f32 %f5, %f1, %f2;
    add.rn.f32 %f6, %f3, %f4;
    add.rn.f32 %f7, %f5, %f6;

Member Author

Yeah, I completely agree we're producing pretty unsightly PTX here, even if the SASS is efficient.
