Skip to content

[InstCombine] Combine and->cmp->sel->or-disjoint into and->mul #135274

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from

Conversation

jrbyrnes
Copy link
Contributor

While and->cmp->set combines into and->mul may result in worse code on some targets, this combine should be uniformly beneficial.

https://alive2.llvm.org/ce/z/3Dnw2u

Change-Id: Id45315f1e5f71077800d3a8141b85bb3b5d8f38a
%tmp2 = and <16 x i1> %arg1, <i1 false, i1 false, i1 false, i1 false, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 true, i1 true, i1 true>
%tmp3 = or <16 x i1> %tmp, %tmp2
ret <16 x i1> %tmp3
%temp = and <16 x i1> %arg, <i1 true, i1 true, i1 true, i1 true, i1 false, i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false>
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To silence a update_test_checks warning

@llvmbot
Copy link
Member

llvmbot commented Apr 10, 2025

@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-llvm-transforms

Author: Jeffrey Byrnes (jrbyrnes)

Changes

While and->cmp->set combines into and->mul may result in worse code on some targets, this combine should be uniformly beneficial.

https://alive2.llvm.org/ce/z/3Dnw2u


Full diff: https://github.com/llvm/llvm-project/pull/135274.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp (+42)
  • (modified) llvm/test/Transforms/InstCombine/or.ll (+110-4)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 6cc241781d112..6dc4b97686f97 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3643,6 +3643,48 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
             foldAddLikeCommutative(I.getOperand(1), I.getOperand(0),
                                    /*NSW=*/true, /*NUW=*/true))
       return R;
+
+    Value *Cond0 = nullptr, *Cond1 = nullptr;
+    ConstantInt *Op0True = nullptr, *Op0False = nullptr;
+    ConstantInt *Op1True = nullptr, *Op1False = nullptr;
+
+    //  (!(A & N) ? 0 : N * C) + (!(A & M) ? 0 : M * C) -> A & (N + M) * C
+    if (match(I.getOperand(0), m_Select(m_Value(Cond0), m_ConstantInt(Op0True),
+                                        m_ConstantInt(Op0False))) &&
+        match(I.getOperand(1), m_Select(m_Value(Cond1), m_ConstantInt(Op1True),
+                                        m_ConstantInt(Op1False))) &&
+        Op0True->isZero() && Op1True->isZero() &&
+        Op0False->getValue().tryZExtValue() &&
+        Op1False->getValue().tryZExtValue()) {
+      CmpPredicate Pred0, Pred1;
+      Value *CmpOp0 = nullptr, *CmpOp1 = nullptr;
+      ConstantInt *Op0Cond = nullptr, *Op1Cond = nullptr;
+      if (match(Cond0,
+                m_c_ICmp(Pred0, m_Value(CmpOp0), m_ConstantInt(Op0Cond))) &&
+          match(Cond1,
+                m_c_ICmp(Pred1, m_Value(CmpOp1), m_ConstantInt(Op1Cond))) &&
+          Pred0 == ICmpInst::ICMP_EQ && Pred1 == ICmpInst::ICMP_EQ &&
+          Op0Cond->isZero() && Op1Cond->isZero()) {
+        Value *AndSrc0 = nullptr, *AndSrc1 = nullptr;
+        ConstantInt *BitSel0 = nullptr, *BitSel1 = nullptr;
+        if (match(CmpOp0, m_And(m_Value(AndSrc0), m_ConstantInt(BitSel0))) &&
+            match(CmpOp1, m_And(m_Value(AndSrc1), m_ConstantInt(BitSel1))) &&
+            AndSrc0 == AndSrc1 && BitSel0->getValue().tryZExtValue() &&
+            BitSel1->getValue().tryZExtValue()) {
+          unsigned Out0 = Op0False->getValue().getZExtValue();
+          unsigned Out1 = Op1False->getValue().getZExtValue();
+          unsigned Sel0 = BitSel0->getValue().getZExtValue();
+          unsigned Sel1 = BitSel1->getValue().getZExtValue();
+          if (!(Out0 % Sel0) && !(Out1 % Sel1) &&
+              ((Out0 / Sel0) == (Out1 / Sel1))) {
+            auto NewAnd = Builder.CreateAnd(
+                AndSrc0, ConstantInt::get(AndSrc0->getType(), Sel0 + Sel1));
+            return BinaryOperator::CreateMul(
+                NewAnd, ConstantInt::get(NewAnd->getType(), (Out1 / Sel1)));
+          }
+        }
+      }
+    }
   }
 
   Value *X, *Y;
diff --git a/llvm/test/Transforms/InstCombine/or.ll b/llvm/test/Transforms/InstCombine/or.ll
index 95f89e4ce11cd..f2b21ca966592 100644
--- a/llvm/test/Transforms/InstCombine/or.ll
+++ b/llvm/test/Transforms/InstCombine/or.ll
@@ -1281,10 +1281,10 @@ define <16 x i1> @test51(<16 x i1> %arg, <16 x i1> %arg1) {
 ; CHECK-NEXT:    [[TMP3:%.*]] = shufflevector <16 x i1> [[ARG:%.*]], <16 x i1> [[ARG1:%.*]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 20, i32 5, i32 6, i32 23, i32 24, i32 9, i32 10, i32 27, i32 28, i32 29, i32 30, i32 31>
 ; CHECK-NEXT:    ret <16 x i1> [[TMP3]]
 ;
-  %tmp = and <16 x i1> %arg, <i1 true, i1 true, i1 true, i1 true, i1 false, i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false>
-  %tmp2 = and <16 x i1> %arg1, <i1 false, i1 false, i1 false, i1 false, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 true, i1 true, i1 true>
-  %tmp3 = or <16 x i1> %tmp, %tmp2
-  ret <16 x i1> %tmp3
+  %temp = and <16 x i1> %arg, <i1 true, i1 true, i1 true, i1 true, i1 false, i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false>
+  %temp2 = and <16 x i1> %arg1, <i1 false, i1 false, i1 false, i1 false, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 true, i1 true, i1 true>
+  %temp3 = or <16 x i1> %temp, %temp2
+  ret <16 x i1> %temp3
 }
 
 ; This would infinite loop because it reaches a transform
@@ -2035,3 +2035,109 @@ define i32 @or_xor_and_commuted3(i32 %x, i32 %y, i32 %z) {
   %or1 = or i32 %xor, %yy
   ret i32 %or1
 }
+
+define i32 @add_select_cmp_and1(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_and1(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
+; CHECK-NEXT:    [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %bitop0 = and i32 %in, 1
+  %cmp0 = icmp eq i32 %bitop0, 0
+  %bitop1 = and i32 %in, 2
+  %cmp1 = icmp eq i32 %bitop1, 0
+  %sel0 = select i1 %cmp0, i32 0, i32 72
+  %sel1 = select i1 %cmp1, i32 0, i32 144
+  %out = or disjoint i32 %sel0, %sel1
+  ret i32 %out
+}
+
+define i32 @add_select_cmp_and2(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_and2(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 5
+; CHECK-NEXT:    [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %bitop0 = and i32 %in, 1
+  %cmp0 = icmp eq i32 %bitop0, 0
+  %bitop1 = and i32 %in, 4
+  %cmp1 = icmp eq i32 %bitop1, 0
+  %sel0 = select i1 %cmp0, i32 0, i32 72
+  %sel1 = select i1 %cmp1, i32 0, i32 288
+  %out = or disjoint i32 %sel0, %sel1
+  ret i32 %out
+}
+
+define i32 @add_select_cmp_and3(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_and3(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
+; CHECK-NEXT:    [[TEMP:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT:    [[BITOP2:%.*]] = and i32 [[IN]], 4
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i32 [[BITOP2]], 0
+; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[CMP2]], i32 0, i32 288
+; CHECK-NEXT:    [[OUT:%.*]] = or disjoint i32 [[TEMP]], [[SEL2]]
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %bitop0 = and i32 %in, 1
+  %cmp0 = icmp eq i32 %bitop0, 0
+  %bitop1 = and i32 %in, 2
+  %cmp1 = icmp eq i32 %bitop1, 0
+  %sel0 = select i1 %cmp0, i32 0, i32 72
+  %sel1 = select i1 %cmp1, i32 0, i32 144
+  %temp = or disjoint i32 %sel0, %sel1
+  %bitop2 = and i32 %in, 4
+  %cmp2 = icmp eq i32 %bitop2, 0
+  %sel2 = select i1 %cmp2, i32 0, i32 288
+  %out = or disjoint i32 %temp, %sel2
+  ret i32 %out
+}
+
+define i32 @add_select_cmp_and4(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_and4(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
+; CHECK-NEXT:    [[TEMP:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[IN]], 12
+; CHECK-NEXT:    [[TEMP2:%.*]] = mul nuw nsw i32 [[TMP2]], 72
+; CHECK-NEXT:    [[OUT:%.*]] = or disjoint i32 [[TEMP]], [[TEMP2]]
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %bitop0 = and i32 %in, 1
+  %cmp0 = icmp eq i32 %bitop0, 0
+  %bitop1 = and i32 %in, 2
+  %cmp1 = icmp eq i32 %bitop1, 0
+  %sel0 = select i1 %cmp0, i32 0, i32 72
+  %sel1 = select i1 %cmp1, i32 0, i32 144
+  %temp = or disjoint i32 %sel0, %sel1
+  %bitop2 = and i32 %in, 4
+  %cmp2 = icmp eq i32 %bitop2, 0
+  %bitop3 = and i32 %in, 8
+  %cmp3 = icmp eq i32 %bitop3, 0
+  %sel2 = select i1 %cmp2, i32 0, i32 288
+  %sel3 = select i1 %cmp3, i32 0, i32 576
+  %temp2 = or disjoint i32 %sel2, %sel3
+  %out = or disjoint i32 %temp, %temp2
+  ret i32 %out
+}
+
+
+
+define i32 @add_select_cmp_and_mismatch(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_and_mismatch(
+; CHECK-NEXT:    [[BITOP0:%.*]] = and i32 [[IN:%.*]], 1
+; CHECK-NEXT:    [[CMP0:%.*]] = icmp eq i32 [[BITOP0]], 0
+; CHECK-NEXT:    [[BITOP1:%.*]] = and i32 [[IN]], 3
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq i32 [[BITOP1]], 0
+; CHECK-NEXT:    [[SEL0:%.*]] = select i1 [[CMP0]], i32 0, i32 72
+; CHECK-NEXT:    [[SEL1:%.*]] = select i1 [[CMP1]], i32 0, i32 288
+; CHECK-NEXT:    [[OUT:%.*]] = or disjoint i32 [[SEL0]], [[SEL1]]
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %bitop0 = and i32 %in, 1
+  %cmp0 = icmp eq i32 %bitop0, 0
+  %bitop1 = and i32 %in, 3
+  %cmp1 = icmp eq i32 %bitop1, 0
+  %sel0 = select i1 %cmp0, i32 0, i32 72
+  %sel1 = select i1 %cmp1, i32 0, i32 288
+  %out = or disjoint i32 %sel0, %sel1
+  ret i32 %out
+}

@jrbyrnes
Copy link
Contributor Author

Copy link
Member

@dtcxzyw dtcxzyw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do not use fixed-length integers.
Can you please provide a generalized alive2 proof?

Change-Id: I630d506375b0eb4b16dad1437bff2da357be2059
@jrbyrnes
Copy link
Contributor Author

Generalized proof: https://alive2.llvm.org/ce/z/MibAcN

@dtcxzyw dtcxzyw requested a review from andjo403 April 15, 2025 00:53
@@ -96,10 +97,27 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
Pred = ICmpInst::getStrictPredicate(Pred);
}

auto decomposeBitMask =
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am confused why the original code doesn't handle this pattern.
cc @andjo403

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure -- but maybe it has something to do with the thought that and->icmp->sel is the canonical sequence. Thus, we never thought to break the icmp into composite parts for and.

Change-Id: I24786ee6dc53a33fc7afbd80d226cda4e4a4df03
bool LookThroughTrunc = true,
bool AllowNonZeroC = false);
bool LookThroughTrunc = true, bool AllowNonZeroC = false,
bool LookThruBitSel = false);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

think that the name LookThruBitSel is a bit confusing what is it looking through? maybe something with it matches the bit test pattern also, like allowBitTest or something. but also good to describe it in the comment.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you prefer DecomposeBitMask ?

Change-Id: I897641e07a0e035c106f8d0d018f8a1b7a92aa6b
Change-Id: I2a995aa22358d9d532a7f839dd64d8c8633990fa
Change-Id: I769481b102dc943039b2bd07a815f2716e58fed4
Change-Id: I74ff3a7a1f7eebef1e9194a40c0c49834dd23117
Copy link

github-actions bot commented Apr 17, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Change-Id: I360e4c6747cf445954143a70e7f8a57ace011dd8
@jrbyrnes jrbyrnes requested a review from dtcxzyw April 17, 2025 20:30
CmpInst::Predicate Pred;
APInt Mask;
APInt C;
};

/// Decompose an icmp into the form ((X & Mask) pred C) if possible.
/// Unless \p AllowNonZeroC is true, C will always be 0.
/// Unless \p AllowNonZeroC is true, C will always be 0. If \p
/// DecomposeBitMask is specified, then, for equality predicates, this will
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I recommend landing this change separately. Then we can remove duplicate logic in other places.

// Try to match/decompose into: icmp eq (X & Mask), 0
auto tryToDecompose = [](ICmpInst *ICmp, Value *&X,
APInt &UnsetBitsMask) -> bool {
CmpPredicate Pred = ICmp->getPredicate();
// Can it be decomposed into icmp eq (X & Mask), 0 ?
auto Res =
llvm::decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
Pred, /*LookThroughTrunc=*/false);
if (Res && Res->Pred == ICmpInst::ICMP_EQ) {
X = Res->X;
UnsetBitsMask = Res->Mask;
return true;
}
// Is it icmp eq (X & Mask), 0 already?
const APInt *Mask;
if (match(ICmp, m_ICmp(Pred, m_And(m_Value(X), m_APInt(Mask)), m_Zero())) &&
Pred == ICmpInst::ICMP_EQ) {
UnsetBitsMask = *Mask;
return true;
}
return false;
};

auto MatchVariableBitMask = [&]() {
return ICmpInst::isEquality(Pred) && match(CmpRHS, m_Zero()) &&
match(CmpLHS,
m_c_And(m_Value(CurrX),
m_CombineAnd(
m_Value(BitMask),
m_LoopInvariant(m_Shl(m_One(), m_Value(BitPos)),
CurLoop))));
};
auto MatchConstantBitMask = [&]() {
return ICmpInst::isEquality(Pred) && match(CmpRHS, m_Zero()) &&
match(CmpLHS, m_And(m_Value(CurrX),
m_CombineAnd(m_Value(BitMask), m_Power2()))) &&
(BitPos = ConstantExpr::getExactLogBase2(cast<Constant>(BitMask)));
};
auto MatchDecomposableConstantBitMask = [&]() {
auto Res = llvm::decomposeBitTestICmp(CmpLHS, CmpRHS, Pred);
if (Res && Res->Mask.isPowerOf2()) {
assert(ICmpInst::isEquality(Res->Pred));
Pred = Res->Pred;
CurrX = Res->X;
BitMask = ConstantInt::get(CurrX->getType(), Res->Mask);
BitPos = ConstantInt::get(CurrX->getType(), Res->Mask.logBase2());
return true;
}
return false;
};
if (!MatchVariableBitMask() && !MatchConstantBitMask() &&
!MatchDecomposableConstantBitMask()) {
LLVM_DEBUG(dbgs() << DEBUG_TYPE " Bad backedge comparison.\n");
return false;
}

if (match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) {
if (ICmpInst::isEquality(Pred)) {
if (!match(CmpRHS, m_Zero()))
return nullptr;
V = CmpLHS;
const APInt *AndRHS;
if (!match(CmpLHS, m_And(m_Value(), m_Power2(AndRHS))))
return nullptr;
AndMask = *AndRHS;
} else if (auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred)) {
assert(ICmpInst::isEquality(Res->Pred) && "Not equality test?");
AndMask = Res->Mask;
V = Res->X;
KnownBits Known =
computeKnownBits(V, /*Depth=*/0, SQ.getWithInstruction(&Sel));
AndMask &= Known.getMaxValue();
if (!AndMask.isPowerOf2())
return nullptr;
Pred = Res->Pred;
CreateAnd = true;
} else {
return nullptr;
}
} else if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) {
V = Trunc->getOperand(0);
AndMask = APInt(V->getType()->getScalarSizeInBits(), 1);
Pred = ICmpInst::ICMP_NE;
CreateAnd = !Trunc->hasNoUnsignedWrap();
} else {
return nullptr;
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW, I am still confused by the parameter name "DecomposeBitMask". It is just what this helper function does.
I think "MatchCanonicalForm" would be better.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function decomposes and finds better forms for comparisons by testing for bit conditions of the compared value -> "decomposeBitTest"

The new capability decomposes cases where we bitmask a value via and -> "decomposeBitMask"

Does this explanation help, or is it still too confusing?

Either way, "MatchCanonicalForm" seems to be much more vague and isn't really describing what the capability does.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added #136367 -- we can discuss there.

CmpPredicate Pred0, Pred1;

auto LHSDecompose =
decomposeBitTest(Cond0, /*LookThruTrunc=*/true,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@andjo403 Do you think it is overkill to use decomposeBitTest? The only thing I want to match are icmp ne/eq (and X, Power2), 0 and trunc X to i1. It seems that the mask of decomposeBitTest(relational icmp) is always non-power-of-2 :(

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants