[NVPTX] Add TLI hook for load slice cost and implement it #131847

Open · wants to merge 2 commits into base: main
14 changes: 14 additions & 0 deletions llvm/include/llvm/CodeGen/TargetLowering.h
@@ -3133,6 +3133,20 @@ class TargetLoweringBase {
return false;
}

virtual unsigned getLoadSliceCost(bool ForCodeSize, unsigned Loads,
unsigned CrossRegisterBanksCopies,
unsigned Truncates, unsigned ZExts,
unsigned Shifts) const {
// Assume cross register banks copies are as expensive as loads.
unsigned ExpensiveOps = Loads + CrossRegisterBanksCopies;

// Unless we are optimizing for code size, prioritize expensive operations.
if (!ForCodeSize)
ExpensiveOps = ExpensiveOps * 20;
Member commented on lines +3143 to +3145:

Edit suggestion: "prioritize expensive operations." -> "prefer avoiding expensive ops" or "prefer less expensive ops".

Where does the magic cost scaling factor of 20 come from? I can't match it to the logic this weight function replaces.
It would be prudent to make the new defaults behave close to how they worked before the change. This is in the code path that will likely affect almost everything. We don't want any surprises.

Member Author replied:

The 20 is quite arbitrary. In the previous implementation, we first compared expensive ops and only checked inexpensive ops if the expensive ops were equal. This meant that expensive ops always trumped inexpensive ops. I don't think that previous logic really made sense, as there must be some number of inexpensive ops that is more costly than an expensive op. I chose 20 because in practice I don't think there are going to be any cases where there are 20 inexpensive ops, so this should essentially get us the same behavior. I agree this is a bit weird, though; I'm happy to try out any alternatives you can suggest.
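A small worked example (operation counts made up for illustration, not taken from the patch) of how the factor of 20 preserves the old "compare expensive ops first" outcome as long as a candidate has fewer than 20 inexpensive ops:

    #include <cassert>

    int main() {
      // Two hypothetical slicing candidates.
      // Candidate A: 1 load and 5 cheap ops; candidate B: 2 loads and 0 cheap ops.
      // Old logic: A wins outright because it has fewer expensive ops (1 < 2).
      unsigned CostA = 20 * 1 + 5; // weighted sum = 25
      unsigned CostB = 20 * 2 + 0; // weighted sum = 40
      // New logic: A still wins, since 25 < 40. The two orderings can only
      // diverge once a candidate accumulates 20 or more cheap ops.
      assert(CostA < CostB);
      return 0;
    }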

Member replied:

Cost heuristics are prone to causing unexpected surprises, and I do not see a good way to validate the currently selected value without exposing it to real-world use cases. Perhaps we can make that scaling factor tunable as an escape hatch in case the current value produces suboptimal results for some users.
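A minimal sketch of what such an escape hatch could look like, assuming the default implementation were moved out of line (e.g. into TargetLoweringBase.cpp) so it can reference a command-line option; the flag name and wording are hypothetical and not part of this PR:

    #include "llvm/CodeGen/TargetLowering.h"
    #include "llvm/Support/CommandLine.h"
    using namespace llvm;

    // Hypothetical tunable: expose the scaling factor so it can be adjusted
    // without recompiling if the default of 20 proves suboptimal for some users.
    static cl::opt<unsigned> LoadSliceExpensiveOpWeight(
        "load-slice-expensive-op-weight", cl::Hidden, cl::init(20),
        cl::desc("Weight applied to loads and cross-register-bank copies when "
                 "costing a load-slicing candidate"));

    unsigned TargetLoweringBase::getLoadSliceCost(
        bool ForCodeSize, unsigned Loads, unsigned CrossRegisterBanksCopies,
        unsigned Truncates, unsigned ZExts, unsigned Shifts) const {
      // Assume cross register banks copies are as expensive as loads.
      unsigned ExpensiveOps = Loads + CrossRegisterBanksCopies;
      // Unless we are optimizing for code size, weight expensive operations by
      // the tunable factor instead of a hard-coded 20.
      if (!ForCodeSize)
        ExpensiveOps *= LoadSliceExpensiveOpWeight;
      return Truncates + ZExts + Shifts + ExpensiveOps;
    }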


return Truncates + ZExts + Shifts + ExpensiveOps;
}

/// Return true if the target has a vector blend instruction.
virtual bool hasVectorBlend() const { return false; }

36 changes: 8 additions & 28 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -19753,32 +19753,10 @@ struct LoadedSlice {
return *this;
}

-bool operator==(const Cost &RHS) const {
-return Loads == RHS.Loads && Truncates == RHS.Truncates &&
-CrossRegisterBanksCopies == RHS.CrossRegisterBanksCopies &&
-ZExts == RHS.ZExts && Shift == RHS.Shift;
+unsigned value(const TargetLowering &TLI) const {
+return TLI.getLoadSliceCost(ForCodeSize, Loads, CrossRegisterBanksCopies,
+Truncates, ZExts, Shift);
}

-bool operator!=(const Cost &RHS) const { return !(*this == RHS); }
-
-bool operator<(const Cost &RHS) const {
-// Assume cross register banks copies are as expensive as loads.
-// FIXME: Do we want some more target hooks?
-unsigned ExpensiveOpsLHS = Loads + CrossRegisterBanksCopies;
-unsigned ExpensiveOpsRHS = RHS.Loads + RHS.CrossRegisterBanksCopies;
-// Unless we are optimizing for code size, consider the
-// expensive operation first.
-if (!ForCodeSize && ExpensiveOpsLHS != ExpensiveOpsRHS)
-return ExpensiveOpsLHS < ExpensiveOpsRHS;
-return (Truncates + ZExts + Shift + ExpensiveOpsLHS) <
-(RHS.Truncates + RHS.ZExts + RHS.Shift + ExpensiveOpsRHS);
-}
-
-bool operator>(const Cost &RHS) const { return RHS < *this; }
-
-bool operator<=(const Cost &RHS) const { return !(RHS < *this); }
-
-bool operator>=(const Cost &RHS) const { return !(*this < RHS); }
};

// The last instruction that represent the slice. This should be a
@@ -20099,7 +20077,8 @@ static void adjustCostForPairing(SmallVectorImpl<LoadedSlice> &LoadedSlices,
/// FIXME: When the cost model will be mature enough, we can relax
/// constraints (1) and (2).
static bool isSlicingProfitable(SmallVectorImpl<LoadedSlice> &LoadedSlices,
-const APInt &UsedBits, bool ForCodeSize) {
+const APInt &UsedBits, bool ForCodeSize,
+const TargetLowering &TLI) {
unsigned NumberOfSlices = LoadedSlices.size();
if (StressLoadSlicing)
return NumberOfSlices > 1;
Expand Down Expand Up @@ -20129,7 +20108,7 @@ static bool isSlicingProfitable(SmallVectorImpl<LoadedSlice> &LoadedSlices,

// If the target supports paired load, adjust the cost accordingly.
adjustCostForPairing(LoadedSlices, GlobalSlicingCost);
-return OrigCost > GlobalSlicingCost;
+return OrigCost.value(TLI) > GlobalSlicingCost.value(TLI);
}

/// If the given load, \p LI, is used only by trunc or trunc(lshr)
@@ -20209,7 +20188,8 @@ bool DAGCombiner::SliceUpLoad(SDNode *N) {
}

// Abort slicing if it does not seem to be profitable.
-if (!isSlicingProfitable(LoadedSlices, UsedBits, ForCodeSize))
+if (!isSlicingProfitable(LoadedSlices, UsedBits, ForCodeSize,
+DAG.getTargetLoweringInfo()))
return false;

++SlicedLoads;
17 changes: 17 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -4735,6 +4735,23 @@ bool NVPTXTargetLowering::isLegalAddressingMode(const DataLayout &DL,
return true;
}

unsigned NVPTXTargetLowering::getLoadSliceCost(
bool ForCodeSize, unsigned Loads, unsigned CrossRegisterBanksCopies,
unsigned Truncates, unsigned ZExts, unsigned Shifts) const {

// Loads are much more expensive than other operations, and the cost of an
// extra load is not offset by savings from shift/mask operations when the
// loaded value is consumed as split elements.
//
// The base TLI treats CrossRegisterBanksCopies as expensive, but these
// operations can be optimized in most cases for NVPTX.
//
CrossRegisterBanksCopies = 0;

return TargetLoweringBase::getLoadSliceCost(
ForCodeSize, Loads, CrossRegisterBanksCopies, Truncates, ZExts, Shifts);
}

//===----------------------------------------------------------------------===//
// NVPTX Inline Assembly Support
//===----------------------------------------------------------------------===//
5 changes: 5 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -143,6 +143,11 @@ class NVPTXTargetLowering : public TargetLowering {
unsigned AS,
Instruction *I = nullptr) const override;

unsigned getLoadSliceCost(bool ForCodeSize, unsigned Loads,
unsigned CrossRegisterBanksCopies,
unsigned Truncates, unsigned ZExts,
unsigned Shifts) const override;

bool isTruncateFree(Type *SrcTy, Type *DstTy) const override {
// Truncating 64-bit to 32-bit is free in SASS.
if (!SrcTy->isIntegerTy() || !DstTy->isIntegerTy())
54 changes: 54 additions & 0 deletions llvm/test/CodeGen/NVPTX/load-slice.ll
@@ -0,0 +1,54 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc < %s | FileCheck %s

target triple = "nvptx64-unknown-unknown"

;; Verify that 64-bit loads are not split into more 32-bit
;; loads. Loads are more expensive than shifts/conversions.
define float @test(ptr %in) {
;
; CHECK-LABEL: test(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<5>;
; CHECK-NEXT: .reg .f32 %f<8>;
; CHECK-NEXT: .reg .b64 %rd<4>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u64 %rd1, [test_param_0];
; CHECK-NEXT: ld.u64 %rd2, [%rd1];
; CHECK-NEXT: ld.u64 %rd3, [%rd1+8];
; CHECK-NEXT: cvt.u32.u64 %r1, %rd2;
; CHECK-NEXT: cvt.u32.u64 %r2, %rd3;
; CHECK-NEXT: mov.b32 %f1, %r1;
; CHECK-NEXT: mov.b32 %f2, %r2;
; CHECK-NEXT: add.rn.f32 %f3, %f1, %f2;
; CHECK-NEXT: { .reg .b32 tmp; mov.b64 {tmp, %r3}, %rd2; }
; CHECK-NEXT: { .reg .b32 tmp; mov.b64 {tmp, %r4}, %rd3; }
; CHECK-NEXT: mov.b32 %f4, %r3;
; CHECK-NEXT: mov.b32 %f5, %r4;
Member commented on lines +20 to +28:

Future cleanup opportunity: We seem to be using too many instructions to do effectively nothing at the SASS level. FP operations accept .b32 registers as inputs, so we can skip the moves between .b32 and .f32 registers and just do add.rn.f32 %f3, %r1, %r2. That also allows splitting 64-bit values with a single mov operation.

https://godbolt.org/z/1bb5dMGsa

    mov.b64 {%r1, %r3}, %rd2;
    mov.b64 {%r2, %r4}, %rd3;
    add.rn.f32 %f3, %r1, %r2;
    add.rn.f32 %f6, %r3, %r4;
    add.rn.f32 %f7, %f3, %f6;

Or, alternatively, split .b64 -> {.f32, .f32}. That would avoid having to deal with FP ops accepting integers as inputs: https://godbolt.org/z/64d9of9d6

    mov.b64 {%f1, %f3}, %rd2;
    mov.b64 {%f2, %f4}, %rd3;
    add.rn.f32 %f5, %f1, %f2;
    add.rn.f32 %f6, %f3, %f4;
    add.rn.f32 %f7, %f5, %f6;

Member Author replied:

Yea, I completely agree we're producing pretty unsightly PTX here, even if the SASS is efficient.

; CHECK-NEXT: add.rn.f32 %f6, %f4, %f5;
; CHECK-NEXT: add.rn.f32 %f7, %f3, %f6;
; CHECK-NEXT: st.param.f32 [func_retval0], %f7;
; CHECK-NEXT: ret;
%ptr0 = getelementptr inbounds i64, ptr %in, i64 0
%ptr1 = getelementptr inbounds i64, ptr %in, i64 1

%load0 = load i64, ptr %ptr0, align 8
%load1 = load i64, ptr %ptr1, align 8
%trunc_lo_0 = trunc i64 %load0 to i32
%trunc_lo_1 = trunc i64 %load1 to i32
%float_lo_0 = bitcast i32 %trunc_lo_0 to float
%float_lo_1 = bitcast i32 %trunc_lo_1 to float
%add_lo = fadd float %float_lo_0, %float_lo_1

%shift0 = lshr i64 %load0, 32
%shift1 = lshr i64 %load1, 32
%trunc_hi_0 = trunc i64 %shift0 to i32
%trunc_hi_1 = trunc i64 %shift1 to i32
%float_hi_0 = bitcast i32 %trunc_hi_0 to float
%float_hi_1 = bitcast i32 %trunc_hi_1 to float
%add_hi = fadd float %float_hi_0, %float_hi_1

%res = fadd float %add_lo, %add_hi
ret float %res
}