-
Notifications
You must be signed in to change notification settings - Fork 13.3k
[LoongArch] Lower build_vector to broadcast load if possible #135896
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-backend-loongarch Author: None (tangaac) ChangesFull diff: https://github.com/llvm/llvm-project/pull/135896.diff 5 Files Affected:
diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
index 002d88cbeeba3..8c5e095dea039 100644
--- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
+++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
@@ -1721,6 +1721,47 @@ static bool isConstantOrUndefBUILD_VECTOR(const BuildVectorSDNode *Op) {
return false;
}
+// Lower BUILD_VECTOR as broadcast load (if possible).
+// For example:
+// %a = load i8, ptr %ptr
+// %b = build_vector %a, %a, %a, %a
+// is lowered to :
+// (VLDREPL_B $a0, 0)
+static SDValue lowerBUILD_VECTORAsBroadCastLoad(BuildVectorSDNode *BVOp,
+ const SDLoc &DL,
+ SelectionDAG &DAG) {
+ MVT VT = BVOp->getSimpleValueType(0);
+ int NumOps = BVOp->getNumOperands();
+
+ assert((VT.is128BitVector() || VT.is256BitVector() || VT.is512BitVector()) &&
+ "Unsupported vector type for broadcast.");
+
+ SDValue IdentitySrc;
+ bool IsIdeneity = true;
+
+ for (int i = 0; i != NumOps; i++) {
+ SDValue Op = BVOp->getOperand(i);
+ if (Op.getOpcode() != ISD::LOAD || (IdentitySrc && Op != IdentitySrc)) {
+ IsIdeneity = false;
+ break;
+ }
+ IdentitySrc = BVOp->getOperand(0);
+ }
+
+ if (IsIdeneity) {
+ auto *LN = cast<LoadSDNode>(IdentitySrc);
+ SDVTList Tys =
+ LN->isIndexed()
+ ? DAG.getVTList(VT, LN->getBasePtr().getValueType(), MVT::Other)
+ : DAG.getVTList(VT, MVT::Other);
+ SDValue Ops[] = {LN->getChain(), LN->getBasePtr(), LN->getOffset()};
+ SDValue BCast = DAG.getNode(LoongArchISD::VLDREPL, DL, Tys, Ops);
+ DAG.ReplaceAllUsesOfValueWith(SDValue(LN, 1), BCast.getValue(1));
+ return BCast;
+ }
+ return SDValue();
+}
+
SDValue LoongArchTargetLowering::lowerBUILD_VECTOR(SDValue Op,
SelectionDAG &DAG) const {
BuildVectorSDNode *Node = cast<BuildVectorSDNode>(Op);
@@ -1736,6 +1777,9 @@ SDValue LoongArchTargetLowering::lowerBUILD_VECTOR(SDValue Op,
(!Subtarget.hasExtLASX() || !Is256Vec))
return SDValue();
+ if (SDValue Result = lowerBUILD_VECTORAsBroadCastLoad(Node, DL, DAG))
+ return Result;
+
if (Node->isConstantSplat(SplatValue, SplatUndef, SplatBitSize, HasAnyUndefs,
/*MinSplatBits=*/8) &&
SplatBitSize <= 64) {
@@ -5171,6 +5215,7 @@ const char *LoongArchTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(VSRLI)
NODE_NAME_CASE(VBSLL)
NODE_NAME_CASE(VBSRL)
+ NODE_NAME_CASE(VLDREPL)
}
#undef NODE_NAME_CASE
return nullptr;
diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.h b/llvm/lib/Target/LoongArch/LoongArchISelLowering.h
index 52d88b9b24a6b..71243a4f0d708 100644
--- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.h
+++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.h
@@ -155,7 +155,10 @@ enum NodeType : unsigned {
// Vector byte logicial left / right shift
VBSLL,
- VBSRL
+ VBSRL,
+
+ // Scalar load broadcast to vector
+ VLDREPL
// Intrinsic operations end =============================================
};
diff --git a/llvm/lib/Target/LoongArch/LoongArchInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchInstrInfo.td
index e4feaa600c57d..775d9289af7c4 100644
--- a/llvm/lib/Target/LoongArch/LoongArchInstrInfo.td
+++ b/llvm/lib/Target/LoongArch/LoongArchInstrInfo.td
@@ -307,7 +307,8 @@ def simm8_lsl # I : Operand<GRLenVT> {
}
}
-def simm9_lsl3 : Operand<GRLenVT> {
+def simm9_lsl3 : Operand<GRLenVT>,
+ ImmLeaf<GRLenVT, [{return isShiftedInt<9,3>(Imm);}]> {
let ParserMatchClass = SImmAsmOperand<9, "lsl3">;
let EncoderMethod = "getImmOpValueAsr<3>";
let DecoderMethod = "decodeSImmOperand<9, 3>";
@@ -317,13 +318,15 @@ def simm10 : Operand<GRLenVT> {
let ParserMatchClass = SImmAsmOperand<10>;
}
-def simm10_lsl2 : Operand<GRLenVT> {
+def simm10_lsl2 : Operand<GRLenVT>,
+ ImmLeaf<GRLenVT, [{return isShiftedInt<10,2>(Imm);}]> {
let ParserMatchClass = SImmAsmOperand<10, "lsl2">;
let EncoderMethod = "getImmOpValueAsr<2>";
let DecoderMethod = "decodeSImmOperand<10, 2>";
}
-def simm11_lsl1 : Operand<GRLenVT> {
+def simm11_lsl1 : Operand<GRLenVT>,
+ ImmLeaf<GRLenVT, [{return isShiftedInt<11,1>(Imm);}]> {
let ParserMatchClass = SImmAsmOperand<11, "lsl1">;
let EncoderMethod = "getImmOpValueAsr<1>";
let DecoderMethod = "decodeSImmOperand<11, 1>";
diff --git a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
index d6d532cddb594..54fad8421378b 100644
--- a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
+++ b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
@@ -2161,6 +2161,7 @@ def : Pat<(int_loongarch_lasx_xvld GPR:$rj, timm:$imm),
def : Pat<(int_loongarch_lasx_xvldx GPR:$rj, GPR:$rk),
(XVLDX GPR:$rj, GPR:$rk)>;
+// xvldrepl
def : Pat<(int_loongarch_lasx_xvldrepl_b GPR:$rj, timm:$imm),
(XVLDREPL_B GPR:$rj, (to_valid_timm timm:$imm))>;
def : Pat<(int_loongarch_lasx_xvldrepl_h GPR:$rj, timm:$imm),
@@ -2170,6 +2171,11 @@ def : Pat<(int_loongarch_lasx_xvldrepl_w GPR:$rj, timm:$imm),
def : Pat<(int_loongarch_lasx_xvldrepl_d GPR:$rj, timm:$imm),
(XVLDREPL_D GPR:$rj, (to_valid_timm timm:$imm))>;
+defm : VldreplPat<v32i8, XVLDREPL_B, simm12_addlike>;
+defm : VldreplPat<v16i16, XVLDREPL_H, simm11_lsl1>;
+defm : VldreplPat<v8i32, XVLDREPL_W, simm10_lsl2>;
+defm : VldreplPat<v4i64, XVLDREPL_D, simm9_lsl3>;
+
// store
def : Pat<(int_loongarch_lasx_xvst LASX256:$xd, GPR:$rj, timm:$imm),
(XVST LASX256:$xd, GPR:$rj, (to_valid_timm timm:$imm))>;
diff --git a/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
index b0d880749bf92..2b44361df29ba 100644
--- a/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
+++ b/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
@@ -26,6 +26,7 @@ def SDT_LoongArchV1RUimm: SDTypeProfile<1, 2, [SDTCisVec<0>,
def SDT_LoongArchVreplgr2vr : SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisVec<0>, SDTCisInt<1>]>;
def SDT_LoongArchVFRECIPE : SDTypeProfile<1, 1, [SDTCisFP<0>, SDTCisVec<0>, SDTCisSameAs<0, 1>]>;
def SDT_LoongArchVFRSQRTE : SDTypeProfile<1, 1, [SDTCisFP<0>, SDTCisVec<0>, SDTCisSameAs<0, 1>]>;
+def SDT_LoongArchVLDREPL : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisPtrTy<1>]>;
// Target nodes.
def loongarch_vreplve : SDNode<"LoongArchISD::VREPLVE", SDT_LoongArchVreplve>;
@@ -64,6 +65,10 @@ def loongarch_vsrli : SDNode<"LoongArchISD::VSRLI", SDT_LoongArchV1RUimm>;
def loongarch_vbsll : SDNode<"LoongArchISD::VBSLL", SDT_LoongArchV1RUimm>;
def loongarch_vbsrl : SDNode<"LoongArchISD::VBSRL", SDT_LoongArchV1RUimm>;
+def loongarch_vldrepl
+ : SDNode<"LoongArchISD::VLDREPL",
+ SDT_LoongArchVLDREPL, [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
+
def immZExt1 : ImmLeaf<i64, [{return isUInt<1>(Imm);}]>;
def immZExt2 : ImmLeaf<i64, [{return isUInt<2>(Imm);}]>;
def immZExt3 : ImmLeaf<i64, [{return isUInt<3>(Imm);}]>;
@@ -1433,6 +1438,14 @@ multiclass PatCCVrVrF<CondCode CC, string Inst> {
(!cast<LAInst>(Inst#"_D") LSX128:$vj, LSX128:$vk)>;
}
+multiclass VldreplPat<ValueType vt, LAInst Inst, Operand ImmOpnd> {
+ def : Pat<(vt(loongarch_vldrepl BaseAddr:$rj)), (Inst BaseAddr:$rj, 0)>;
+ def : Pat<(vt(loongarch_vldrepl(AddrConstant GPR:$rj, ImmOpnd:$imm))),
+ (Inst GPR:$rj, ImmOpnd:$imm)>;
+ def : Pat<(vt(loongarch_vldrepl(AddLike BaseAddr:$rj, ImmOpnd:$imm))),
+ (Inst BaseAddr:$rj, ImmOpnd:$imm)>;
+}
+
let Predicates = [HasExtLSX] in {
// VADD_{B/H/W/D}
@@ -2338,6 +2351,7 @@ def : Pat<(int_loongarch_lsx_vld GPR:$rj, timm:$imm),
def : Pat<(int_loongarch_lsx_vldx GPR:$rj, GPR:$rk),
(VLDX GPR:$rj, GPR:$rk)>;
+// vldrepl
def : Pat<(int_loongarch_lsx_vldrepl_b GPR:$rj, timm:$imm),
(VLDREPL_B GPR:$rj, (to_valid_timm timm:$imm))>;
def : Pat<(int_loongarch_lsx_vldrepl_h GPR:$rj, timm:$imm),
@@ -2347,6 +2361,11 @@ def : Pat<(int_loongarch_lsx_vldrepl_w GPR:$rj, timm:$imm),
def : Pat<(int_loongarch_lsx_vldrepl_d GPR:$rj, timm:$imm),
(VLDREPL_D GPR:$rj, (to_valid_timm timm:$imm))>;
+defm : VldreplPat<v16i8, VLDREPL_B, simm12_addlike>;
+defm : VldreplPat<v8i16, VLDREPL_H, simm11_lsl1>;
+defm : VldreplPat<v4i32, VLDREPL_W, simm10_lsl2>;
+defm : VldreplPat<v2i64, VLDREPL_D, simm9_lsl3>;
+
// store
def : Pat<(int_loongarch_lsx_vst LSX128:$vd, GPR:$rj, timm:$imm),
(VST LSX128:$vd, GPR:$rj, (to_valid_timm timm:$imm))>;
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
e412c72
to
e7f09af
Compare
Here are the files optimized by this pr on llvm-test-suite: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/18/builds/14794 Here is the relevant piece of the build log for the reference
|
No description provided.