[DAGCombiner] Don't fold cheap extracts of multiple use splats #134120

Open — wants to merge 3 commits into main

Conversation

@lukel97 lukel97 (Contributor) commented Apr 2, 2025

For out-of-loop sum reductions, the loop vectorizer will emit an initial reduction vector like the one below in the preheader, where there may be an initial value to begin summing from:

%v = insertelement <vscale x 4 x i32> zeroinitializer, i32 %initial, i64 0

On RISC-V we currently lower this quite poorly with two splats of 0, one at m1 and one at m2:

vsetvli a1, zero, e32, m1, ta, ma  
vmv.v.i v10, 0                     
vsetvli zero, zero, e32, m1, tu, ma
vmv.s.x v10, a0                    
vsetvli a0, zero, e32, m2, ta, ma  
vmv.v.i v8, 0                      
vmv1r.v v8, v10                    

The underlying reason is that we fold ty1 extract_subvector(ty2 splat(V)) -> ty1 splat(V) even if the splat has multiple uses. In the DAG below, the splat t5 is used by both the extract_subvector (t12) and the insert_subvector (t16):

  t5: nxv4i32 = splat_vector Constant:i64<0>
        t12: nxv2i32 = extract_subvector t5, Constant:i64<0>
        t2: i64,ch = CopyFromReg t0, Register:i64 %0
      t15: nxv2i32 = RISCVISD::VMV_S_X_VL t12, t2, Register:i64 $x0
    t16: nxv4i32 = insert_subvector t5, t15, Constant:i64<0>

From what I understand, on AArch64 this is fine because ty1 splat(V) and ty2 splat(V) will produce the same MachineInstr in the same register class, which can be deduplicated.

On RISC-V the two splats may land in separate register classes (e.g. m1 and m2), so we end up creating a second splat.

This patch fixes the issue by creating a new splat only if the extract isn't cheap. If an extract is cheap, it should just be a subregister extract, which shouldn't introduce any instructions on RISC-V, so there's no benefit to folding.
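Concretely, the fold is now guarded roughly like this (a conceptual restatement of the DAGCombiner change in the diff below, with explanatory comments added; the names are those used in visitEXTRACT_SUBVECTOR):

  // ty1 extract_subvector(ty2 splat(V)) -> ty1 splat(V), but only when the
  // fold can't duplicate the splat: either the original splat dies (single
  // use), or the extract isn't a free subregister read, so materializing a
  // new splat still saves an instruction.
  if (V.getOpcode() == ISD::SPLAT_VECTOR)
    if (DAG.isConstantValueOfAnyType(V.getOperand(0)) || V.hasOneUse())
      if (!TLI.isExtractSubvectorCheap(NVT, V.getValueType(), ExtIdx) ||
          V.hasOneUse())
        if (!LegalOperations || TLI.isOperationLegal(ISD::SPLAT_VECTOR, NVT))
          return DAG.getSplatVector(NVT, DL, V.getOperand(0));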

This on its own caused regressions on AArch64: I've included two commits here so reviewers can see them. The second commit works around the regressions by enabling the AArch64-specific combine in more places, without the TLI.isExtractSubvectorCheap restriction. This was a quick hack and there's likely a better way to fix it, so I'm open to any suggestions.
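For context, the hook that gates the new behaviour is TargetLowering::isExtractSubvectorCheap(EVT ResVT, EVT SrcVT, unsigned Index). A minimal sketch of a target override, using a hypothetical MyTargetLowering (simplified for illustration; not the actual RISC-V implementation):

  bool MyTargetLowering::isExtractSubvectorCheap(EVT ResVT, EVT SrcVT,
                                                 unsigned Index) const {
    // Only reason about simple (MVT-representable) vector types.
    if (!ResVT.isSimple() || !SrcVT.isSimple())
      return false;
    // Extracting the low subvector is just a subregister read, which
    // shouldn't introduce any instructions.
    return Index == 0;
  }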

@llvmbot llvmbot (Member) commented Apr 2, 2025

@llvm/pr-subscribers-backend-aarch64
@llvm/pr-subscribers-llvm-selectiondag

Author: Luke Lau (lukel97)

Full diff: https://github.com/llvm/llvm-project/pull/134120.diff

3 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+4-2)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+5-4)
  • (modified) llvm/test/CodeGen/RISCV/rvv/insertelt-int-rv64.ll (+6-12)
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index dc5c5f38e3bd8..9f0a1ecbe27fa 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -25383,8 +25383,10 @@ SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) {
   // ty1 extract_vector(ty2 splat(V))) -> ty1 splat(V)
   if (V.getOpcode() == ISD::SPLAT_VECTOR)
     if (DAG.isConstantValueOfAnyType(V.getOperand(0)) || V.hasOneUse())
-      if (!LegalOperations || TLI.isOperationLegal(ISD::SPLAT_VECTOR, NVT))
-        return DAG.getSplatVector(NVT, DL, V.getOperand(0));
+      if (!TLI.isExtractSubvectorCheap(NVT, V.getValueType(), ExtIdx) ||
+          V.hasOneUse())
+        if (!LegalOperations || TLI.isOperationLegal(ISD::SPLAT_VECTOR, NVT))
+          return DAG.getSplatVector(NVT, DL, V.getOperand(0));
 
   // extract_subvector(insert_subvector(x,y,c1),c2)
   //  --> extract_subvector(y,c2-c1)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index e0be0d83f7513..592d5aebff97c 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -20223,17 +20223,18 @@ performExtractSubvectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
     return SDValue();
 
   EVT VT = N->getValueType(0);
-  if (!VT.isScalableVector() || VT.getVectorElementType() != MVT::i1)
-    return SDValue();
-
   SDValue V = N->getOperand(0);
 
+  if (VT.isScalableVector() != V->getValueType(0).isScalableVector())
+    return SDValue();
+
   // NOTE: This combine exists in DAGCombiner, but that version's legality check
   // blocks this combine because the non-const case requires custom lowering.
+  // We also want to perform it even when the splat has multiple uses.
   //
   // ty1 extract_vector(ty2 splat(const))) -> ty1 splat(const)
   if (V.getOpcode() == ISD::SPLAT_VECTOR)
-    if (isa<ConstantSDNode>(V.getOperand(0)))
+    if (isa<ConstantSDNode, ConstantFPSDNode>(V.getOperand(0)))
       return DAG.getNode(ISD::SPLAT_VECTOR, SDLoc(N), VT, V.getOperand(0));
 
   return SDValue();
diff --git a/llvm/test/CodeGen/RISCV/rvv/insertelt-int-rv64.ll b/llvm/test/CodeGen/RISCV/rvv/insertelt-int-rv64.ll
index 0e43cbf0f4518..2d5216c97d397 100644
--- a/llvm/test/CodeGen/RISCV/rvv/insertelt-int-rv64.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/insertelt-int-rv64.ll
@@ -761,13 +761,10 @@ define <vscale x 8 x i64> @insertelt_nxv8i64_idx(<vscale x 8 x i64> %v, i64 %elt
 define <vscale x 4 x i32> @insertelt_nxv4i32_zeroinitializer_0(i32 %x) {
 ; CHECK-LABEL: insertelt_nxv4i32_zeroinitializer_0:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli a1, zero, e32, m1, ta, ma
-; CHECK-NEXT:    vmv.v.i v10, 0
-; CHECK-NEXT:    vsetvli zero, zero, e32, m1, tu, ma
-; CHECK-NEXT:    vmv.s.x v10, a0
-; CHECK-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; CHECK-NEXT:    vsetvli a1, zero, e32, m2, ta, ma
 ; CHECK-NEXT:    vmv.v.i v8, 0
-; CHECK-NEXT:    vmv1r.v v8, v10
+; CHECK-NEXT:    vsetvli zero, zero, e32, m2, tu, ma
+; CHECK-NEXT:    vmv.s.x v8, a0
 ; CHECK-NEXT:    ret
   %v = insertelement <vscale x 4 x i32> zeroinitializer, i32 %x, i64 0
   ret <vscale x 4 x i32> %v
@@ -776,14 +773,11 @@ define <vscale x 4 x i32> @insertelt_nxv4i32_zeroinitializer_0(i32 %x) {
 define <vscale x 4 x i32> @insertelt_imm_nxv4i32_zeroinitializer_0(i32 %x) {
 ; CHECK-LABEL: insertelt_imm_nxv4i32_zeroinitializer_0:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli a0, zero, e32, m1, ta, ma
-; CHECK-NEXT:    vmv.v.i v10, 0
-; CHECK-NEXT:    li a0, 42
-; CHECK-NEXT:    vsetvli zero, zero, e32, m1, tu, ma
-; CHECK-NEXT:    vmv.s.x v10, a0
 ; CHECK-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
 ; CHECK-NEXT:    vmv.v.i v8, 0
-; CHECK-NEXT:    vmv1r.v v8, v10
+; CHECK-NEXT:    li a0, 42
+; CHECK-NEXT:    vsetvli zero, zero, e32, m2, tu, ma
+; CHECK-NEXT:    vmv.s.x v8, a0
 ; CHECK-NEXT:    ret
   %v = insertelement <vscale x 4 x i32> zeroinitializer, i32 42, i64 0
   ret <vscale x 4 x i32> %v
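One detail worth calling out in the AArch64 hunk above: llvm::isa (from llvm/Support/Casting.h) accepts multiple types, so the new check is shorthand for testing either constant kind, which lets the combine also fire for floating-point constant splats:

  // isa<A, B>(V) is equivalent to isa<A>(V) || isa<B>(V).
  if (isa<ConstantSDNode, ConstantFPSDNode>(V.getOperand(0)))
    return DAG.getNode(ISD::SPLAT_VECTOR, SDLoc(N), VT, V.getOperand(0));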

Comment on lines +20228 to +20229
if (VT.isScalableVector() != V->getValueType(0).isScalableVector())
return SDValue();
lukel97 (Contributor, Author):

This was needed to prevent an infinite loop where fixed-length splats got legalized to a scalable splat plus a fixed-length extract.
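A comment along these lines could make that explicit (hypothetical wording, not in-tree text):

  // Bail if the result and operand mix scalable and fixed-length vectors:
  // a fixed-length splat can be legalized into a scalable splat plus a
  // fixed-length extract_subvector, and re-running this combine on that
  // pattern would splat again and loop forever.
  if (VT.isScalableVector() != V->getValueType(0).isScalableVector())
    return SDValue();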

RKSimon (Collaborator):

Maybe add an explanatory comment to explain this?

Collaborator:

Yes please.

@RKSimon RKSimon (Collaborator) left a comment

LGTM with one minor

@lukel97 lukel97 (Contributor, Author) commented Apr 3, 2025

It looks like this is also causing some regressions on Hexagon. Taking a look.

Labels: backend:AArch64, llvm:SelectionDAG