[NVPTX] Add TLI hook for load slice cost and implement it #131847
Conversation
@llvm/pr-subscribers-llvm-selectiondag @llvm/pr-subscribers-backend-nvptx
Author: Alex MacLean (AlexMaclean)
Changes
Add a new getLoadSliceCost target hook which converts information from a LoadSlice::Cost into a scalar value for comparison. Override this for NVPTX to treat CrossRegisterBanksCopies as free in order to prevent harmful load splitting.
Full diff: https://github.com/llvm/llvm-project/pull/131847.diff
5 Files Affected:
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index a3fb4e9a8513b..9b144849fbfdb 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -3133,6 +3133,20 @@ class TargetLoweringBase {
return false;
}
+ virtual unsigned getLoadSliceCost(bool ForCodeSize, unsigned Loads,
+ unsigned CrossRegisterBanksCopies,
+ unsigned Truncates, unsigned ZExts,
+ unsigned Shifts) const {
+ // Assume cross register banks copies are as expensive as loads.
+ unsigned ExpensiveOps = Loads + CrossRegisterBanksCopies;
+
+ // Unless we are optimizing for code size, prioritize expensive operations.
+ if (!ForCodeSize)
+ ExpensiveOps = ExpensiveOps * 20;
+
+ return Truncates + ZExts + Shifts + ExpensiveOps;
+ }
+
/// Return true if the target has a vector blend instruction.
virtual bool hasVectorBlend() const { return false; }
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index a54857e1037e2..624a2b032ccae 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -19753,32 +19753,10 @@ struct LoadedSlice {
return *this;
}
- bool operator==(const Cost &RHS) const {
- return Loads == RHS.Loads && Truncates == RHS.Truncates &&
- CrossRegisterBanksCopies == RHS.CrossRegisterBanksCopies &&
- ZExts == RHS.ZExts && Shift == RHS.Shift;
+ unsigned value(const TargetLowering &TLI) const {
+ return TLI.getLoadSliceCost(ForCodeSize, Loads, CrossRegisterBanksCopies,
+ Truncates, ZExts, Shift);
}
-
- bool operator!=(const Cost &RHS) const { return !(*this == RHS); }
-
- bool operator<(const Cost &RHS) const {
- // Assume cross register banks copies are as expensive as loads.
- // FIXME: Do we want some more target hooks?
- unsigned ExpensiveOpsLHS = Loads + CrossRegisterBanksCopies;
- unsigned ExpensiveOpsRHS = RHS.Loads + RHS.CrossRegisterBanksCopies;
- // Unless we are optimizing for code size, consider the
- // expensive operation first.
- if (!ForCodeSize && ExpensiveOpsLHS != ExpensiveOpsRHS)
- return ExpensiveOpsLHS < ExpensiveOpsRHS;
- return (Truncates + ZExts + Shift + ExpensiveOpsLHS) <
- (RHS.Truncates + RHS.ZExts + RHS.Shift + ExpensiveOpsRHS);
- }
-
- bool operator>(const Cost &RHS) const { return RHS < *this; }
-
- bool operator<=(const Cost &RHS) const { return !(RHS < *this); }
-
- bool operator>=(const Cost &RHS) const { return !(*this < RHS); }
};
// The last instruction that represent the slice. This should be a
@@ -20099,7 +20077,8 @@ static void adjustCostForPairing(SmallVectorImpl<LoadedSlice> &LoadedSlices,
/// FIXME: When the cost model will be mature enough, we can relax
/// constraints (1) and (2).
static bool isSlicingProfitable(SmallVectorImpl<LoadedSlice> &LoadedSlices,
- const APInt &UsedBits, bool ForCodeSize) {
+ const APInt &UsedBits, bool ForCodeSize,
+ const TargetLowering &TLI) {
unsigned NumberOfSlices = LoadedSlices.size();
if (StressLoadSlicing)
return NumberOfSlices > 1;
@@ -20129,7 +20108,7 @@ static bool isSlicingProfitable(SmallVectorImpl<LoadedSlice> &LoadedSlices,
// If the target supports paired load, adjust the cost accordingly.
adjustCostForPairing(LoadedSlices, GlobalSlicingCost);
- return OrigCost > GlobalSlicingCost;
+ return OrigCost.value(TLI) > GlobalSlicingCost.value(TLI);
}
/// If the given load, \p LI, is used only by trunc or trunc(lshr)
@@ -20209,7 +20188,8 @@ bool DAGCombiner::SliceUpLoad(SDNode *N) {
}
// Abort slicing if it does not seem to be profitable.
- if (!isSlicingProfitable(LoadedSlices, UsedBits, ForCodeSize))
+ if (!isSlicingProfitable(LoadedSlices, UsedBits, ForCodeSize,
+ DAG.getTargetLoweringInfo()))
return false;
++SlicedLoads;
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 18ec5c5384488..482822f9425bb 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -4735,6 +4735,23 @@ bool NVPTXTargetLowering::isLegalAddressingMode(const DataLayout &DL,
return true;
}
+unsigned NVPTXTargetLowering::getLoadSliceCost(
+ bool ForCodeSize, unsigned Loads, unsigned CrossRegisterBanksCopies,
+ unsigned Truncates, unsigned ZExts, unsigned Shifts) const {
+
+ // Loads are much more expensive than other operations, and the cost of an
+ // extra load is not offset by the savings from shift/mask operations when
+ // the loaded value is only used as split elements.
+ //
+ // Base TLI treats CrossRegisterBanksCopies as expensive, but these operations
+ // can be optimized in most cases for NVPTX.
+ //
+ CrossRegisterBanksCopies = 0;
+
+ return TargetLoweringBase::getLoadSliceCost(
+ ForCodeSize, Loads, CrossRegisterBanksCopies, Truncates, ZExts, Shifts);
+}
+
//===----------------------------------------------------------------------===//
// NVPTX Inline Assembly Support
//===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index ff0241886223b..95c4de4d68ca5 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -143,6 +143,11 @@ class NVPTXTargetLowering : public TargetLowering {
unsigned AS,
Instruction *I = nullptr) const override;
+ unsigned getLoadSliceCost(bool ForCodeSize, unsigned Loads,
+ unsigned CrossRegisterBanksCopies,
+ unsigned Truncates, unsigned ZExts,
+ unsigned Shifts) const override;
+
bool isTruncateFree(Type *SrcTy, Type *DstTy) const override {
// Truncating 64-bit to 32-bit is free in SASS.
if (!SrcTy->isIntegerTy() || !DstTy->isIntegerTy())
diff --git a/llvm/test/CodeGen/NVPTX/load-slice.ll b/llvm/test/CodeGen/NVPTX/load-slice.ll
new file mode 100644
index 0000000000000..c34f4a27f8d36
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/load-slice.ll
@@ -0,0 +1,54 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s | FileCheck %s
+
+target triple = "nvptx64-unknown-unknown"
+
+;; Verify that 64-bit loads are not split into more 32-bit
+;; loads. Loads are more expensive than shifts/conversions.
+define float @test(ptr %in) {
+;
+; CHECK-LABEL: test(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-NEXT: .reg .f32 %f<8>;
+; CHECK-NEXT: .reg .b64 %rd<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u64 %rd1, [test_param_0];
+; CHECK-NEXT: ld.u64 %rd2, [%rd1];
+; CHECK-NEXT: ld.u64 %rd3, [%rd1+8];
+; CHECK-NEXT: cvt.u32.u64 %r1, %rd2;
+; CHECK-NEXT: cvt.u32.u64 %r2, %rd3;
+; CHECK-NEXT: mov.b32 %f1, %r1;
+; CHECK-NEXT: mov.b32 %f2, %r2;
+; CHECK-NEXT: add.rn.f32 %f3, %f1, %f2;
+; CHECK-NEXT: { .reg .b32 tmp; mov.b64 {tmp, %r3}, %rd2; }
+; CHECK-NEXT: { .reg .b32 tmp; mov.b64 {tmp, %r4}, %rd3; }
+; CHECK-NEXT: mov.b32 %f4, %r3;
+; CHECK-NEXT: mov.b32 %f5, %r4;
+; CHECK-NEXT: add.rn.f32 %f6, %f4, %f5;
+; CHECK-NEXT: add.rn.f32 %f7, %f3, %f6;
+; CHECK-NEXT: st.param.f32 [func_retval0], %f7;
+; CHECK-NEXT: ret;
+ %ptr0 = getelementptr inbounds i64, ptr %in, i64 0
+ %ptr1 = getelementptr inbounds i64, ptr %in, i64 1
+
+ %load0 = load i64, ptr %ptr0, align 8
+ %load1 = load i64, ptr %ptr1, align 8
+ %trunc_lo_0 = trunc i64 %load0 to i32
+ %trunc_lo_1 = trunc i64 %load1 to i32
+ %float_lo_0 = bitcast i32 %trunc_lo_0 to float
+ %float_lo_1 = bitcast i32 %trunc_lo_1 to float
+ %add_lo = fadd float %float_lo_0, %float_lo_1
+
+ %shift0 = lshr i64 %load0, 32
+ %shift1 = lshr i64 %load1, 32
+ %trunc_hi_0 = trunc i64 %shift0 to i32
+ %trunc_hi_1 = trunc i64 %shift1 to i32
+ %float_hi_0 = bitcast i32 %trunc_hi_0 to float
+ %float_hi_1 = bitcast i32 %trunc_hi_1 to float
+ %add_hi = fadd float %float_hi_0, %float_hi_1
+
+ %res = fadd float %add_lo, %add_hi
+ ret float %res
+}
// Unless we are optimizing for code size, prioritize expensive operations.
if (!ForCodeSize)
  ExpensiveOps = ExpensiveOps * 20;
Edit suggestion: "prioritize expensive operations." -> "prefer avoiding expensive ops" or "prefer less expensive ops".
Where does the magic cost scaling factor of 20 come from? I can't match it to the logic this weight function replaces.
It would be prudent to make the new defaults behave close to how they worked before the change. This is in the code path that will likely affect almost everything. We don't want any surprises.
The 20 is quite arbitrary. In the previous implementation we first compared expensive ops and only checked inexpensive ops if expensive ops were equal. This meant that expensive ops always trumped inexpensive ops. I don't think that previous logic really made sense as there must be some number of inexpensive ops that is more costly than an expensive op. I chose 20 because in practice I don't think there are going to be any cases where there are 20 inexpensive ops, so this should essentially get us the same behavior. I agree this is a bit weird though, I'm happy to try out any alternatives you can suggest.
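For reference, here is a small standalone C++ sketch (not code from this PR; the struct and helper names are invented for illustration) contrasting the old two-step comparison with the new weighted scalar. With a weight of 20, the two orderings only diverge once the side with fewer expensive ops carries at least 20 more inexpensive ops.

#include <cassert>

// Hypothetical stand-in for LoadedSlice::Cost, reduced to the two buckets
// that matter for the comparison.
struct SliceCost {
  unsigned Expensive;   // Loads + CrossRegisterBanksCopies
  unsigned Inexpensive; // Truncates + ZExts + Shifts
};

// Old behavior: compare expensive ops first, fall back to the combined total.
static bool lessOld(const SliceCost &L, const SliceCost &R) {
  if (L.Expensive != R.Expensive)
    return L.Expensive < R.Expensive;
  return L.Expensive + L.Inexpensive < R.Expensive + R.Inexpensive;
}

// New behavior: collapse the cost to a single scalar with a weight of 20.
static unsigned value(const SliceCost &C) {
  return 20 * C.Expensive + C.Inexpensive;
}

int main() {
  // Typical case: both orderings agree that an extra expensive op (a load)
  // outweighs a few extra cheap ops.
  SliceCost A{/*Expensive=*/1, /*Inexpensive=*/5};
  SliceCost B{/*Expensive=*/2, /*Inexpensive=*/0};
  assert(lessOld(A, B) && value(A) < value(B));

  // Divergence point: only when the cheaper-in-expensive-ops side has at
  // least 20 more inexpensive ops do the two orderings disagree.
  SliceCost C{/*Expensive=*/1, /*Inexpensive=*/25};
  SliceCost D{/*Expensive=*/2, /*Inexpensive=*/0};
  assert(lessOld(C, D));       // old ordering: C is cheaper
  assert(value(C) > value(D)); // new cost: C is costlier (45 vs 40)
  return 0;
}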
Cost heuristics are prone to cause unexpected surprises, and I do not see a good way to validate the currently selected value without exposing it to the real world use cases. Perhaps we can make that scaling factor tunable as an escape hatch in case the current value produces suboptimal results for some users.
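For illustration, one way such an escape hatch could look, assuming the default implementation moves out of TargetLowering.h so it can reference a cl::opt; the flag name dag-load-slice-expensive-op-weight and its placement are hypothetical, not part of this PR:

#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/Support/CommandLine.h"

using namespace llvm;

// Hypothetical tuning knob; the name and default are illustrative only.
static cl::opt<unsigned> LoadSliceExpensiveOpWeight(
    "dag-load-slice-expensive-op-weight", cl::Hidden, cl::init(20),
    cl::desc("Weight of loads and cross-register-bank copies relative to "
             "truncates/zexts/shifts in the load-slicing cost model"));

unsigned TargetLoweringBase::getLoadSliceCost(
    bool ForCodeSize, unsigned Loads, unsigned CrossRegisterBanksCopies,
    unsigned Truncates, unsigned ZExts, unsigned Shifts) const {
  // Assume cross register banks copies are as expensive as loads.
  unsigned ExpensiveOps = Loads + CrossRegisterBanksCopies;

  // Unless optimizing for code size, weight expensive operations so they
  // dominate the cheap ones, with the factor exposed as a flag.
  if (!ForCodeSize)
    ExpensiveOps *= LoadSliceExpensiveOpWeight;

  return Truncates + ZExts + Shifts + ExpensiveOps;
}

A target override like the NVPTX one would be unaffected, since it still funnels through the base implementation.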
; CHECK-NEXT: cvt.u32.u64 %r1, %rd2;
; CHECK-NEXT: cvt.u32.u64 %r2, %rd3;
; CHECK-NEXT: mov.b32 %f1, %r1;
; CHECK-NEXT: mov.b32 %f2, %r2;
; CHECK-NEXT: add.rn.f32 %f3, %f1, %f2;
; CHECK-NEXT: { .reg .b32 tmp; mov.b64 {tmp, %r3}, %rd2; }
; CHECK-NEXT: { .reg .b32 tmp; mov.b64 {tmp, %r4}, %rd3; }
; CHECK-NEXT: mov.b32 %f4, %r3;
; CHECK-NEXT: mov.b32 %f5, %r4;
Future cleanup opportunity: We seem to be using too many instructions to do effectively nothing at the SASS level. FP operations accept .b32 registers as inputs, so we can skip the moves between .b32 and .f32 registers and just do add.rn.f32 %f3, %r1, %r2; directly. That also allows splitting 64-bit values in one mov operation.
https://godbolt.org/z/1bb5dMGsa
mov.b64 {%r1, %r3}, %rd2;
mov.b64 {%r2, %r4}, %rd3;
add.rn.f32 %f3, %r1, %r2;
add.rn.f32 %f6, %r3, %r4;
add.rn.f32 %f7, %f3, %f6;
Or, alternatively, split .b64 -> {.f32, .f32}. That would avoid having to deal with FP ops accepting integers as inputs: https://godbolt.org/z/64d9of9d6
mov.b64 {%f1, %f3}, %rd2;
mov.b64 {%f2, %f4}, %rd3;
add.rn.f32 %f5, %f1, %f2;
add.rn.f32 %f6, %f3, %f4;
add.rn.f32 %f7, %f5, %f6;
Yea, I completely agree we're producing pretty unsightly PTX here, even if the SASS is efficient.