Skip to content

[LoongArch] Lower build_vector to broadcast load if possible #135896

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 21, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1876,6 +1876,51 @@ static bool isConstantOrUndefBUILD_VECTOR(const BuildVectorSDNode *Op) {
return false;
}

// Lower BUILD_VECTOR as broadcast load (if possible).
// For example:
//   %a = load i8, ptr %ptr
//   %b = build_vector %a, %a, %a, %a
// is lowered to:
//   (VLDREPL_B $a0, 0)
//
// \param BVOp the BUILD_VECTOR node to inspect.
// \param DL   debug location attached to the new node.
// \param DAG  the DAG in which the replacement node is created.
// \returns the LoongArchISD::VLDREPL node on success, or an empty SDValue if
//          the pattern does not match (the caller then tries other lowerings).
static SDValue lowerBUILD_VECTORAsBroadCastLoad(BuildVectorSDNode *BVOp,
                                                const SDLoc &DL,
                                                SelectionDAG &DAG) {
  MVT VT = BVOp->getSimpleValueType(0);

  assert((VT.is128BitVector() || VT.is256BitVector() || VT.is512BitVector()) &&
         "Unsupported vector type for broadcast.");

  const unsigned NumOps = BVOp->getNumOperands();
  if (NumOps == 0)
    return SDValue();

  // Every lane must be fed by the very same load value.
  SDValue IdentitySrc = BVOp->getOperand(0);
  if (IdentitySrc.getOpcode() != ISD::LOAD)
    return SDValue();
  for (unsigned I = 1; I != NumOps; ++I)
    if (BVOp->getOperand(I) != IdentitySrc)
      return SDValue();

  // The load is replaced by the broadcast, so nothing else may depend on it.
  if (!BVOp->isOnlyUserOf(IdentitySrc.getNode()))
    return SDValue();

  auto *LN = cast<LoadSDNode>(IdentitySrc);

  // The broadcast load reads one element of VT's scalar width from memory.
  // Folding a sign/zero-extending load, or a load whose in-memory width differs
  // from the lane width (e.g. an i8 load splatted into i32 lanes), would read
  // the wrong bytes, so such loads are left to the generic lowering.
  const ISD::LoadExtType ExtType = LN->getExtensionType();
  if (ExtType != ISD::NON_EXTLOAD && ExtType != ISD::EXTLOAD)
    return SDValue();
  if (VT.getScalarSizeInBits() != LN->getMemoryVT().getScalarSizeInBits())
    return SDValue();

  // An indexed load additionally produces the updated base pointer.
  SDVTList Tys =
      LN->isIndexed()
          ? DAG.getVTList(VT, LN->getBasePtr().getValueType(), MVT::Other)
          : DAG.getVTList(VT, MVT::Other);
  SDValue Ops[] = {LN->getChain(), LN->getBasePtr(), LN->getOffset()};
  SDValue BCast = DAG.getNode(LoongArchISD::VLDREPL, DL, Tys, Ops);
  // Redirect users of the old load's result #1 to the matching result of the
  // broadcast node.
  DAG.ReplaceAllUsesOfValueWith(SDValue(LN, 1), BCast.getValue(1));
  return BCast;
}

SDValue LoongArchTargetLowering::lowerBUILD_VECTOR(SDValue Op,
SelectionDAG &DAG) const {
BuildVectorSDNode *Node = cast<BuildVectorSDNode>(Op);
Expand All @@ -1891,6 +1936,9 @@ SDValue LoongArchTargetLowering::lowerBUILD_VECTOR(SDValue Op,
(!Subtarget.hasExtLASX() || !Is256Vec))
return SDValue();

if (SDValue Result = lowerBUILD_VECTORAsBroadCastLoad(Node, DL, DAG))
return Result;

if (Node->isConstantSplat(SplatValue, SplatUndef, SplatBitSize, HasAnyUndefs,
/*MinSplatBits=*/8) &&
SplatBitSize <= 64) {
Expand Down Expand Up @@ -5326,6 +5374,7 @@ const char *LoongArchTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(VSRLI)
NODE_NAME_CASE(VBSLL)
NODE_NAME_CASE(VBSRL)
NODE_NAME_CASE(VLDREPL)
}
#undef NODE_NAME_CASE
return nullptr;
Expand Down
5 changes: 4 additions & 1 deletion llvm/lib/Target/LoongArch/LoongArchISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,10 @@ enum NodeType : unsigned {

// Vector byte logical left / right shift
VBSLL,
VBSRL
VBSRL,

// Scalar load broadcast to vector
VLDREPL

// Intrinsic operations end =============================================
};
Expand Down
9 changes: 6 additions & 3 deletions llvm/lib/Target/LoongArch/LoongArchInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,8 @@ def simm8_lsl # I : Operand<GRLenVT> {
}
}

def simm9_lsl3 : Operand<GRLenVT> {
def simm9_lsl3 : Operand<GRLenVT>,
ImmLeaf<GRLenVT, [{return isShiftedInt<9,3>(Imm);}]> {
let ParserMatchClass = SImmAsmOperand<9, "lsl3">;
let EncoderMethod = "getImmOpValueAsr<3>";
let DecoderMethod = "decodeSImmOperand<9, 3>";
Expand All @@ -317,13 +318,15 @@ def simm10 : Operand<GRLenVT> {
let ParserMatchClass = SImmAsmOperand<10>;
}

def simm10_lsl2 : Operand<GRLenVT> {
def simm10_lsl2 : Operand<GRLenVT>,
ImmLeaf<GRLenVT, [{return isShiftedInt<10,2>(Imm);}]> {
let ParserMatchClass = SImmAsmOperand<10, "lsl2">;
let EncoderMethod = "getImmOpValueAsr<2>";
let DecoderMethod = "decodeSImmOperand<10, 2>";
}

def simm11_lsl1 : Operand<GRLenVT> {
def simm11_lsl1 : Operand<GRLenVT>,
ImmLeaf<GRLenVT, [{return isShiftedInt<11,1>(Imm);}]> {
let ParserMatchClass = SImmAsmOperand<11, "lsl1">;
let EncoderMethod = "getImmOpValueAsr<1>";
let DecoderMethod = "decodeSImmOperand<11, 1>";
Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -2165,6 +2165,7 @@ def : Pat<(int_loongarch_lasx_xvld GPR:$rj, timm:$imm),
def : Pat<(int_loongarch_lasx_xvldx GPR:$rj, GPR:$rk),
(XVLDX GPR:$rj, GPR:$rk)>;

// xvldrepl
def : Pat<(int_loongarch_lasx_xvldrepl_b GPR:$rj, timm:$imm),
(XVLDREPL_B GPR:$rj, (to_valid_timm timm:$imm))>;
def : Pat<(int_loongarch_lasx_xvldrepl_h GPR:$rj, timm:$imm),
Expand All @@ -2174,6 +2175,13 @@ def : Pat<(int_loongarch_lasx_xvldrepl_w GPR:$rj, timm:$imm),
def : Pat<(int_loongarch_lasx_xvldrepl_d GPR:$rj, timm:$imm),
(XVLDREPL_D GPR:$rj, (to_valid_timm timm:$imm))>;

defm : VldreplPat<v32i8, XVLDREPL_B, simm12_addlike>;
defm : VldreplPat<v16i16, XVLDREPL_H, simm11_lsl1>;
defm : VldreplPat<v8i32, XVLDREPL_W, simm10_lsl2>;
defm : VldreplPat<v4i64, XVLDREPL_D, simm9_lsl3>;
defm : VldreplPat<v8f32, XVLDREPL_W, simm10_lsl2>;
defm : VldreplPat<v4f64, XVLDREPL_D, simm9_lsl3>;

// store
def : Pat<(int_loongarch_lasx_xvst LASX256:$xd, GPR:$rj, timm:$imm),
(XVST LASX256:$xd, GPR:$rj, (to_valid_timm timm:$imm))>;
Expand Down
21 changes: 21 additions & 0 deletions llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def SDT_LoongArchV1RUimm: SDTypeProfile<1, 2, [SDTCisVec<0>,
def SDT_LoongArchVreplgr2vr : SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisVec<0>, SDTCisInt<1>]>;
def SDT_LoongArchVFRECIPE : SDTypeProfile<1, 1, [SDTCisFP<0>, SDTCisVec<0>, SDTCisSameAs<0, 1>]>;
def SDT_LoongArchVFRSQRTE : SDTypeProfile<1, 1, [SDTCisFP<0>, SDTCisVec<0>, SDTCisSameAs<0, 1>]>;
def SDT_LoongArchVLDREPL : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisPtrTy<1>]>;

// Target nodes.
def loongarch_vreplve : SDNode<"LoongArchISD::VREPLVE", SDT_LoongArchVreplve>;
Expand Down Expand Up @@ -64,6 +65,10 @@ def loongarch_vsrli : SDNode<"LoongArchISD::VSRLI", SDT_LoongArchV1RUimm>;
def loongarch_vbsll : SDNode<"LoongArchISD::VBSLL", SDT_LoongArchV1RUimm>;
def loongarch_vbsrl : SDNode<"LoongArchISD::VBSRL", SDT_LoongArchV1RUimm>;

// Target-specific DAG node for LoongArchISD::VLDREPL (scalar load broadcast
// to vector). It uses the SDT_LoongArchVLDREPL type profile and is flagged as
// carrying a chain, possibly loading from memory, and having a memory operand.
def loongarch_vldrepl
: SDNode<"LoongArchISD::VLDREPL",
SDT_LoongArchVLDREPL, [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;

def immZExt1 : ImmLeaf<i64, [{return isUInt<1>(Imm);}]>;
def immZExt2 : ImmLeaf<i64, [{return isUInt<2>(Imm);}]>;
def immZExt3 : ImmLeaf<i64, [{return isUInt<3>(Imm);}]>;
Expand Down Expand Up @@ -1433,6 +1438,14 @@ multiclass PatCCVrVrF<CondCode CC, string Inst> {
(!cast<LAInst>(Inst#"_D") LSX128:$vj, LSX128:$vk)>;
}

// Selection patterns that map a loongarch_vldrepl node of vector type `vt`
// onto the broadcast-load instruction `Inst`.
//   vt      - result vector type of the node being matched.
//   Inst    - instruction to emit.
//   ImmOpnd - operand class used to match the immediate offset (presumably it
//             restricts the offset to what `Inst` can encode; confirm against
//             the operand definitions).
multiclass VldreplPat<ValueType vt, LAInst Inst, Operand ImmOpnd> {
// Plain base address: emit the instruction with a zero offset.
def : Pat<(vt(loongarch_vldrepl BaseAddr:$rj)), (Inst BaseAddr:$rj, 0)>;
// Address matched by AddrConstant on a GPR and an immediate: the immediate
// becomes the instruction's offset operand.
def : Pat<(vt(loongarch_vldrepl(AddrConstant GPR:$rj, ImmOpnd:$imm))),
(Inst GPR:$rj, ImmOpnd:$imm)>;
// Address matched by AddLike on a base address and an immediate: the
// immediate likewise becomes the offset operand.
def : Pat<(vt(loongarch_vldrepl(AddLike BaseAddr:$rj, ImmOpnd:$imm))),
(Inst BaseAddr:$rj, ImmOpnd:$imm)>;
}

let Predicates = [HasExtLSX] in {

// VADD_{B/H/W/D}
Expand Down Expand Up @@ -2342,6 +2355,7 @@ def : Pat<(int_loongarch_lsx_vld GPR:$rj, timm:$imm),
def : Pat<(int_loongarch_lsx_vldx GPR:$rj, GPR:$rk),
(VLDX GPR:$rj, GPR:$rk)>;

// vldrepl
def : Pat<(int_loongarch_lsx_vldrepl_b GPR:$rj, timm:$imm),
(VLDREPL_B GPR:$rj, (to_valid_timm timm:$imm))>;
def : Pat<(int_loongarch_lsx_vldrepl_h GPR:$rj, timm:$imm),
Expand All @@ -2351,6 +2365,13 @@ def : Pat<(int_loongarch_lsx_vldrepl_w GPR:$rj, timm:$imm),
def : Pat<(int_loongarch_lsx_vldrepl_d GPR:$rj, timm:$imm),
(VLDREPL_D GPR:$rj, (to_valid_timm timm:$imm))>;

defm : VldreplPat<v16i8, VLDREPL_B, simm12_addlike>;
defm : VldreplPat<v8i16, VLDREPL_H, simm11_lsl1>;
defm : VldreplPat<v4i32, VLDREPL_W, simm10_lsl2>;
defm : VldreplPat<v2i64, VLDREPL_D, simm9_lsl3>;
defm : VldreplPat<v4f32, VLDREPL_W, simm10_lsl2>;
defm : VldreplPat<v2f64, VLDREPL_D, simm9_lsl3>;

// store
def : Pat<(int_loongarch_lsx_vst LSX128:$vd, GPR:$rj, timm:$imm),
(VST LSX128:$vd, GPR:$rj, (to_valid_timm timm:$imm))>;
Expand Down
40 changes: 14 additions & 26 deletions llvm/test/CodeGen/LoongArch/lasx/broadcast-load.ll
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ define <4 x i64> @should_not_be_optimized(ptr %ptr, ptr %dst) {
define <4 x i64> @xvldrepl_d_unaligned_offset(ptr %ptr) {
; CHECK-LABEL: xvldrepl_d_unaligned_offset:
; CHECK: # %bb.0:
; CHECK-NEXT: ld.d $a0, $a0, 4
; CHECK-NEXT: xvreplgr2vr.d $xr0, $a0
; CHECK-NEXT: addi.d $a0, $a0, 4
; CHECK-NEXT: xvldrepl.d $xr0, $a0, 0
; CHECK-NEXT: ret
%p = getelementptr i32, ptr %ptr, i32 1
%tmp = load i64, ptr %p
Expand All @@ -34,8 +34,7 @@ define <4 x i64> @xvldrepl_d_unaligned_offset(ptr %ptr) {
define <32 x i8> @xvldrepl_b(ptr %ptr) {
; CHECK-LABEL: xvldrepl_b:
; CHECK: # %bb.0:
; CHECK-NEXT: ld.b $a0, $a0, 0
; CHECK-NEXT: xvreplgr2vr.b $xr0, $a0
; CHECK-NEXT: xvldrepl.b $xr0, $a0, 0
; CHECK-NEXT: ret
%tmp = load i8, ptr %ptr
%tmp1 = insertelement <32 x i8> zeroinitializer, i8 %tmp, i32 0
Expand All @@ -46,8 +45,7 @@ define <32 x i8> @xvldrepl_b(ptr %ptr) {
define <32 x i8> @xvldrepl_b_offset(ptr %ptr) {
; CHECK-LABEL: xvldrepl_b_offset:
; CHECK: # %bb.0:
; CHECK-NEXT: ld.b $a0, $a0, 33
; CHECK-NEXT: xvreplgr2vr.b $xr0, $a0
; CHECK-NEXT: xvldrepl.b $xr0, $a0, 33
; CHECK-NEXT: ret
%p = getelementptr i8, ptr %ptr, i64 33
%tmp = load i8, ptr %p
Expand All @@ -60,8 +58,7 @@ define <32 x i8> @xvldrepl_b_offset(ptr %ptr) {
define <16 x i16> @xvldrepl_h(ptr %ptr) {
; CHECK-LABEL: xvldrepl_h:
; CHECK: # %bb.0:
; CHECK-NEXT: ld.h $a0, $a0, 0
; CHECK-NEXT: xvreplgr2vr.h $xr0, $a0
; CHECK-NEXT: xvldrepl.h $xr0, $a0, 0
; CHECK-NEXT: ret
%tmp = load i16, ptr %ptr
%tmp1 = insertelement <16 x i16> zeroinitializer, i16 %tmp, i32 0
Expand All @@ -72,8 +69,7 @@ define <16 x i16> @xvldrepl_h(ptr %ptr) {
define <16 x i16> @xvldrepl_h_offset(ptr %ptr) {
; CHECK-LABEL: xvldrepl_h_offset:
; CHECK: # %bb.0:
; CHECK-NEXT: ld.h $a0, $a0, 66
; CHECK-NEXT: xvreplgr2vr.h $xr0, $a0
; CHECK-NEXT: xvldrepl.h $xr0, $a0, 66
; CHECK-NEXT: ret
%p = getelementptr i16, ptr %ptr, i64 33
%tmp = load i16, ptr %p
Expand All @@ -85,8 +81,7 @@ define <16 x i16> @xvldrepl_h_offset(ptr %ptr) {
define <8 x i32> @xvldrepl_w(ptr %ptr) {
; CHECK-LABEL: xvldrepl_w:
; CHECK: # %bb.0:
; CHECK-NEXT: ld.w $a0, $a0, 0
; CHECK-NEXT: xvreplgr2vr.w $xr0, $a0
; CHECK-NEXT: xvldrepl.w $xr0, $a0, 0
; CHECK-NEXT: ret
%tmp = load i32, ptr %ptr
%tmp1 = insertelement <8 x i32> zeroinitializer, i32 %tmp, i32 0
Expand All @@ -97,8 +92,7 @@ define <8 x i32> @xvldrepl_w(ptr %ptr) {
define <8 x i32> @xvldrepl_w_offset(ptr %ptr) {
; CHECK-LABEL: xvldrepl_w_offset:
; CHECK: # %bb.0:
; CHECK-NEXT: ld.w $a0, $a0, 132
; CHECK-NEXT: xvreplgr2vr.w $xr0, $a0
; CHECK-NEXT: xvldrepl.w $xr0, $a0, 132
; CHECK-NEXT: ret
%p = getelementptr i32, ptr %ptr, i64 33
%tmp = load i32, ptr %p
Expand All @@ -111,8 +105,7 @@ define <8 x i32> @xvldrepl_w_offset(ptr %ptr) {
define <4 x i64> @xvldrepl_d(ptr %ptr) {
; CHECK-LABEL: xvldrepl_d:
; CHECK: # %bb.0:
; CHECK-NEXT: ld.d $a0, $a0, 0
; CHECK-NEXT: xvreplgr2vr.d $xr0, $a0
; CHECK-NEXT: xvldrepl.d $xr0, $a0, 0
; CHECK-NEXT: ret
%tmp = load i64, ptr %ptr
%tmp1 = insertelement <4 x i64> zeroinitializer, i64 %tmp, i32 0
Expand All @@ -123,8 +116,7 @@ define <4 x i64> @xvldrepl_d(ptr %ptr) {
define <4 x i64> @xvldrepl_d_offset(ptr %ptr) {
; CHECK-LABEL: xvldrepl_d_offset:
; CHECK: # %bb.0:
; CHECK-NEXT: ld.d $a0, $a0, 264
; CHECK-NEXT: xvreplgr2vr.d $xr0, $a0
; CHECK-NEXT: xvldrepl.d $xr0, $a0, 264
; CHECK-NEXT: ret
%p = getelementptr i64, ptr %ptr, i64 33
%tmp = load i64, ptr %p
Expand All @@ -136,8 +128,7 @@ define <4 x i64> @xvldrepl_d_offset(ptr %ptr) {
define <8 x float> @vldrepl_w_flt(ptr %ptr) {
; CHECK-LABEL: vldrepl_w_flt:
; CHECK: # %bb.0:
; CHECK-NEXT: fld.s $fa0, $a0, 0
; CHECK-NEXT: xvreplve0.w $xr0, $xr0
; CHECK-NEXT: xvldrepl.w $xr0, $a0, 0
; CHECK-NEXT: ret
%tmp = load float, ptr %ptr
%tmp1 = insertelement <8 x float> zeroinitializer, float %tmp, i32 0
Expand All @@ -148,8 +139,7 @@ define <8 x float> @vldrepl_w_flt(ptr %ptr) {
define <8 x float> @vldrepl_w_flt_offset(ptr %ptr) {
; CHECK-LABEL: vldrepl_w_flt_offset:
; CHECK: # %bb.0:
; CHECK-NEXT: fld.s $fa0, $a0, 264
; CHECK-NEXT: xvreplve0.w $xr0, $xr0
; CHECK-NEXT: xvldrepl.w $xr0, $a0, 264
; CHECK-NEXT: ret
%p = getelementptr i64, ptr %ptr, i64 33
%tmp = load float, ptr %p
Expand All @@ -161,8 +151,7 @@ define <8 x float> @vldrepl_w_flt_offset(ptr %ptr) {
define <4 x double> @vldrepl_d_dbl(ptr %ptr) {
; CHECK-LABEL: vldrepl_d_dbl:
; CHECK: # %bb.0:
; CHECK-NEXT: fld.d $fa0, $a0, 0
; CHECK-NEXT: xvreplve0.d $xr0, $xr0
; CHECK-NEXT: xvldrepl.d $xr0, $a0, 0
; CHECK-NEXT: ret
%tmp = load double, ptr %ptr
%tmp1 = insertelement <4 x double> zeroinitializer, double %tmp, i32 0
Expand All @@ -173,8 +162,7 @@ define <4 x double> @vldrepl_d_dbl(ptr %ptr) {
define <4 x double> @vldrepl_d_dbl_offset(ptr %ptr) {
; CHECK-LABEL: vldrepl_d_dbl_offset:
; CHECK: # %bb.0:
; CHECK-NEXT: fld.d $fa0, $a0, 264
; CHECK-NEXT: xvreplve0.d $xr0, $xr0
; CHECK-NEXT: xvldrepl.d $xr0, $a0, 264
; CHECK-NEXT: ret
%p = getelementptr i64, ptr %ptr, i64 33
%tmp = load double, ptr %p
Expand Down
Loading
Loading