-
Notifications
You must be signed in to change notification settings - Fork 13.3k
[InstCombine] Combine and->cmp->sel->or-disjoint into and->mul #135274
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Change-Id: Id45315f1e5f71077800d3a8141b85bb3b5d8f38a
%tmp2 = and <16 x i1> %arg1, <i1 false, i1 false, i1 false, i1 false, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 true, i1 true, i1 true> | ||
%tmp3 = or <16 x i1> %tmp, %tmp2 | ||
ret <16 x i1> %tmp3 | ||
%temp = and <16 x i1> %arg, <i1 true, i1 true, i1 true, i1 true, i1 false, i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To silence an update_test_checks warning
@llvm/pr-subscribers-llvm-analysis @llvm/pr-subscribers-llvm-transforms Author: Jeffrey Byrnes (jrbyrnes) Changes: While and->cmp->set combines into and->mul may result in worse code on some targets, this combine should be uniformly beneficial. https://alive2.llvm.org/ce/z/3Dnw2u Full diff: https://github.com/llvm/llvm-project/pull/135274.diff 2 Files Affected:
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 6cc241781d112..6dc4b97686f97 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3643,6 +3643,48 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
foldAddLikeCommutative(I.getOperand(1), I.getOperand(0),
/*NSW=*/true, /*NUW=*/true))
return R;
+
+ Value *Cond0 = nullptr, *Cond1 = nullptr;
+ ConstantInt *Op0True = nullptr, *Op0False = nullptr;
+ ConstantInt *Op1True = nullptr, *Op1False = nullptr;
+
+ // (!(A & N) ? 0 : N * C) + (!(A & M) ? 0 : M * C) -> A & (N + M) * C
+ if (match(I.getOperand(0), m_Select(m_Value(Cond0), m_ConstantInt(Op0True),
+ m_ConstantInt(Op0False))) &&
+ match(I.getOperand(1), m_Select(m_Value(Cond1), m_ConstantInt(Op1True),
+ m_ConstantInt(Op1False))) &&
+ Op0True->isZero() && Op1True->isZero() &&
+ Op0False->getValue().tryZExtValue() &&
+ Op1False->getValue().tryZExtValue()) {
+ CmpPredicate Pred0, Pred1;
+ Value *CmpOp0 = nullptr, *CmpOp1 = nullptr;
+ ConstantInt *Op0Cond = nullptr, *Op1Cond = nullptr;
+ if (match(Cond0,
+ m_c_ICmp(Pred0, m_Value(CmpOp0), m_ConstantInt(Op0Cond))) &&
+ match(Cond1,
+ m_c_ICmp(Pred1, m_Value(CmpOp1), m_ConstantInt(Op1Cond))) &&
+ Pred0 == ICmpInst::ICMP_EQ && Pred1 == ICmpInst::ICMP_EQ &&
+ Op0Cond->isZero() && Op1Cond->isZero()) {
+ Value *AndSrc0 = nullptr, *AndSrc1 = nullptr;
+ ConstantInt *BitSel0 = nullptr, *BitSel1 = nullptr;
+ if (match(CmpOp0, m_And(m_Value(AndSrc0), m_ConstantInt(BitSel0))) &&
+ match(CmpOp1, m_And(m_Value(AndSrc1), m_ConstantInt(BitSel1))) &&
+ AndSrc0 == AndSrc1 && BitSel0->getValue().tryZExtValue() &&
+ BitSel1->getValue().tryZExtValue()) {
+ unsigned Out0 = Op0False->getValue().getZExtValue();
+ unsigned Out1 = Op1False->getValue().getZExtValue();
+ unsigned Sel0 = BitSel0->getValue().getZExtValue();
+ unsigned Sel1 = BitSel1->getValue().getZExtValue();
+ if (!(Out0 % Sel0) && !(Out1 % Sel1) &&
+ ((Out0 / Sel0) == (Out1 / Sel1))) {
+ auto NewAnd = Builder.CreateAnd(
+ AndSrc0, ConstantInt::get(AndSrc0->getType(), Sel0 + Sel1));
+ return BinaryOperator::CreateMul(
+ NewAnd, ConstantInt::get(NewAnd->getType(), (Out1 / Sel1)));
+ }
+ }
+ }
+ }
}
Value *X, *Y;
diff --git a/llvm/test/Transforms/InstCombine/or.ll b/llvm/test/Transforms/InstCombine/or.ll
index 95f89e4ce11cd..f2b21ca966592 100644
--- a/llvm/test/Transforms/InstCombine/or.ll
+++ b/llvm/test/Transforms/InstCombine/or.ll
@@ -1281,10 +1281,10 @@ define <16 x i1> @test51(<16 x i1> %arg, <16 x i1> %arg1) {
; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <16 x i1> [[ARG:%.*]], <16 x i1> [[ARG1:%.*]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 20, i32 5, i32 6, i32 23, i32 24, i32 9, i32 10, i32 27, i32 28, i32 29, i32 30, i32 31>
; CHECK-NEXT: ret <16 x i1> [[TMP3]]
;
- %tmp = and <16 x i1> %arg, <i1 true, i1 true, i1 true, i1 true, i1 false, i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false>
- %tmp2 = and <16 x i1> %arg1, <i1 false, i1 false, i1 false, i1 false, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 true, i1 true, i1 true>
- %tmp3 = or <16 x i1> %tmp, %tmp2
- ret <16 x i1> %tmp3
+ %temp = and <16 x i1> %arg, <i1 true, i1 true, i1 true, i1 true, i1 false, i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false>
+ %temp2 = and <16 x i1> %arg1, <i1 false, i1 false, i1 false, i1 false, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 true, i1 true, i1 true>
+ %temp3 = or <16 x i1> %temp, %temp2
+ ret <16 x i1> %temp3
}
; This would infinite loop because it reaches a transform
@@ -2035,3 +2035,109 @@ define i32 @or_xor_and_commuted3(i32 %x, i32 %y, i32 %z) {
%or1 = or i32 %xor, %yy
ret i32 %or1
}
+
+define i32 @add_select_cmp_and1(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_and1(
+; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
+; CHECK-NEXT: [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT: ret i32 [[OUT]]
+;
+ %bitop0 = and i32 %in, 1
+ %cmp0 = icmp eq i32 %bitop0, 0
+ %bitop1 = and i32 %in, 2
+ %cmp1 = icmp eq i32 %bitop1, 0
+ %sel0 = select i1 %cmp0, i32 0, i32 72
+ %sel1 = select i1 %cmp1, i32 0, i32 144
+ %out = or disjoint i32 %sel0, %sel1
+ ret i32 %out
+}
+
+define i32 @add_select_cmp_and2(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_and2(
+; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 5
+; CHECK-NEXT: [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT: ret i32 [[OUT]]
+;
+ %bitop0 = and i32 %in, 1
+ %cmp0 = icmp eq i32 %bitop0, 0
+ %bitop1 = and i32 %in, 4
+ %cmp1 = icmp eq i32 %bitop1, 0
+ %sel0 = select i1 %cmp0, i32 0, i32 72
+ %sel1 = select i1 %cmp1, i32 0, i32 288
+ %out = or disjoint i32 %sel0, %sel1
+ ret i32 %out
+}
+
+define i32 @add_select_cmp_and3(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_and3(
+; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
+; CHECK-NEXT: [[TEMP:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT: [[BITOP2:%.*]] = and i32 [[IN]], 4
+; CHECK-NEXT: [[CMP2:%.*]] = icmp eq i32 [[BITOP2]], 0
+; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[CMP2]], i32 0, i32 288
+; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[TEMP]], [[SEL2]]
+; CHECK-NEXT: ret i32 [[OUT]]
+;
+ %bitop0 = and i32 %in, 1
+ %cmp0 = icmp eq i32 %bitop0, 0
+ %bitop1 = and i32 %in, 2
+ %cmp1 = icmp eq i32 %bitop1, 0
+ %sel0 = select i1 %cmp0, i32 0, i32 72
+ %sel1 = select i1 %cmp1, i32 0, i32 144
+ %temp = or disjoint i32 %sel0, %sel1
+ %bitop2 = and i32 %in, 4
+ %cmp2 = icmp eq i32 %bitop2, 0
+ %sel2 = select i1 %cmp2, i32 0, i32 288
+ %out = or disjoint i32 %temp, %sel2
+ ret i32 %out
+}
+
+define i32 @add_select_cmp_and4(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_and4(
+; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
+; CHECK-NEXT: [[TEMP:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[IN]], 12
+; CHECK-NEXT: [[TEMP2:%.*]] = mul nuw nsw i32 [[TMP2]], 72
+; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[TEMP]], [[TEMP2]]
+; CHECK-NEXT: ret i32 [[OUT]]
+;
+ %bitop0 = and i32 %in, 1
+ %cmp0 = icmp eq i32 %bitop0, 0
+ %bitop1 = and i32 %in, 2
+ %cmp1 = icmp eq i32 %bitop1, 0
+ %sel0 = select i1 %cmp0, i32 0, i32 72
+ %sel1 = select i1 %cmp1, i32 0, i32 144
+ %temp = or disjoint i32 %sel0, %sel1
+ %bitop2 = and i32 %in, 4
+ %cmp2 = icmp eq i32 %bitop2, 0
+ %bitop3 = and i32 %in, 8
+ %cmp3 = icmp eq i32 %bitop3, 0
+ %sel2 = select i1 %cmp2, i32 0, i32 288
+ %sel3 = select i1 %cmp3, i32 0, i32 576
+ %temp2 = or disjoint i32 %sel2, %sel3
+ %out = or disjoint i32 %temp, %temp2
+ ret i32 %out
+}
+
+
+
+define i32 @add_select_cmp_and_mismatch(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_and_mismatch(
+; CHECK-NEXT: [[BITOP0:%.*]] = and i32 [[IN:%.*]], 1
+; CHECK-NEXT: [[CMP0:%.*]] = icmp eq i32 [[BITOP0]], 0
+; CHECK-NEXT: [[BITOP1:%.*]] = and i32 [[IN]], 3
+; CHECK-NEXT: [[CMP1:%.*]] = icmp eq i32 [[BITOP1]], 0
+; CHECK-NEXT: [[SEL0:%.*]] = select i1 [[CMP0]], i32 0, i32 72
+; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[CMP1]], i32 0, i32 288
+; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[SEL0]], [[SEL1]]
+; CHECK-NEXT: ret i32 [[OUT]]
+;
+ %bitop0 = and i32 %in, 1
+ %cmp0 = icmp eq i32 %bitop0, 0
+ %bitop1 = and i32 %in, 3
+ %cmp1 = icmp eq i32 %bitop1, 0
+ %sel0 = select i1 %cmp0, i32 0, i32 72
+ %sel1 = select i1 %cmp1, i32 0, i32 288
+ %out = or disjoint i32 %sel0, %sel1
+ ret i32 %out
+}
|
Links for the other changed tests: https://alive2.llvm.org/ce/z/cDSsrr |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do not use fixed-length integers.
Can you please provide a generalized alive2 proof?
Change-Id: I630d506375b0eb4b16dad1437bff2da357be2059
Generalized proof: https://alive2.llvm.org/ce/z/MibAcN |
@@ -96,10 +97,27 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred, | |||
Pred = ICmpInst::getStrictPredicate(Pred); | |||
} | |||
|
|||
auto decomposeBitMask = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am confused why the original code doesn't handle this pattern.
cc @andjo403
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure -- but maybe it has something to do with the thought that `and` -> `icmp` -> `sel` is the canonical sequence. Thus, we never thought to break the `icmp` into composite parts for `and`.
Change-Id: I24786ee6dc53a33fc7afbd80d226cda4e4a4df03
bool LookThroughTrunc = true, | ||
bool AllowNonZeroC = false); | ||
bool LookThroughTrunc = true, bool AllowNonZeroC = false, | ||
bool LookThruBitSel = false); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the name LookThruBitSel is a bit confusing: what is it looking through? Maybe use something that says it also matches the bit test pattern, like allowBitTest or something. It would also be good to describe it in the comment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you prefer DecomposeBitMask?
Change-Id: I897641e07a0e035c106f8d0d018f8a1b7a92aa6b
Change-Id: I2a995aa22358d9d532a7f839dd64d8c8633990fa
Change-Id: I769481b102dc943039b2bd07a815f2716e58fed4
Change-Id: I74ff3a7a1f7eebef1e9194a40c0c49834dd23117
✅ With the latest revision this PR passed the C/C++ code formatter. |
Change-Id: I360e4c6747cf445954143a70e7f8a57ace011dd8
CmpInst::Predicate Pred; | ||
APInt Mask; | ||
APInt C; | ||
}; | ||
|
||
/// Decompose an icmp into the form ((X & Mask) pred C) if possible. | ||
/// Unless \p AllowNonZeroC is true, C will always be 0. | ||
/// Unless \p AllowNonZeroC is true, C will always be 0. If \p | ||
/// DecomposeBitMask is specified, then, for equality predicates, this will |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I recommend landing this change separately. Then we can remove duplicate logic in other places.
llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
Lines 873 to 895 in 3ed8363
// Try to match/decompose into: icmp eq (X & Mask), 0 | |
auto tryToDecompose = [](ICmpInst *ICmp, Value *&X, | |
APInt &UnsetBitsMask) -> bool { | |
CmpPredicate Pred = ICmp->getPredicate(); | |
// Can it be decomposed into icmp eq (X & Mask), 0 ? | |
auto Res = | |
llvm::decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1), | |
Pred, /*LookThroughTrunc=*/false); | |
if (Res && Res->Pred == ICmpInst::ICMP_EQ) { | |
X = Res->X; | |
UnsetBitsMask = Res->Mask; | |
return true; | |
} | |
// Is it icmp eq (X & Mask), 0 already? | |
const APInt *Mask; | |
if (match(ICmp, m_ICmp(Pred, m_And(m_Value(X), m_APInt(Mask)), m_Zero())) && | |
Pred == ICmpInst::ICMP_EQ) { | |
UnsetBitsMask = *Mask; | |
return true; | |
} | |
return false; | |
}; |
llvm-project/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
Lines 2755 to 2787 in 3ed8363
auto MatchVariableBitMask = [&]() { | |
return ICmpInst::isEquality(Pred) && match(CmpRHS, m_Zero()) && | |
match(CmpLHS, | |
m_c_And(m_Value(CurrX), | |
m_CombineAnd( | |
m_Value(BitMask), | |
m_LoopInvariant(m_Shl(m_One(), m_Value(BitPos)), | |
CurLoop)))); | |
}; | |
auto MatchConstantBitMask = [&]() { | |
return ICmpInst::isEquality(Pred) && match(CmpRHS, m_Zero()) && | |
match(CmpLHS, m_And(m_Value(CurrX), | |
m_CombineAnd(m_Value(BitMask), m_Power2()))) && | |
(BitPos = ConstantExpr::getExactLogBase2(cast<Constant>(BitMask))); | |
}; | |
auto MatchDecomposableConstantBitMask = [&]() { | |
auto Res = llvm::decomposeBitTestICmp(CmpLHS, CmpRHS, Pred); | |
if (Res && Res->Mask.isPowerOf2()) { | |
assert(ICmpInst::isEquality(Res->Pred)); | |
Pred = Res->Pred; | |
CurrX = Res->X; | |
BitMask = ConstantInt::get(CurrX->getType(), Res->Mask); | |
BitPos = ConstantInt::get(CurrX->getType(), Res->Mask.logBase2()); | |
return true; | |
} | |
return false; | |
}; | |
if (!MatchVariableBitMask() && !MatchConstantBitMask() && | |
!MatchDecomposableConstantBitMask()) { | |
LLVM_DEBUG(dbgs() << DEBUG_TYPE " Bad backedge comparison.\n"); | |
return false; | |
} |
llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
Lines 3732 to 3765 in 3ed8363
if (match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) { | |
if (ICmpInst::isEquality(Pred)) { | |
if (!match(CmpRHS, m_Zero())) | |
return nullptr; | |
V = CmpLHS; | |
const APInt *AndRHS; | |
if (!match(CmpLHS, m_And(m_Value(), m_Power2(AndRHS)))) | |
return nullptr; | |
AndMask = *AndRHS; | |
} else if (auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred)) { | |
assert(ICmpInst::isEquality(Res->Pred) && "Not equality test?"); | |
AndMask = Res->Mask; | |
V = Res->X; | |
KnownBits Known = | |
computeKnownBits(V, /*Depth=*/0, SQ.getWithInstruction(&Sel)); | |
AndMask &= Known.getMaxValue(); | |
if (!AndMask.isPowerOf2()) | |
return nullptr; | |
Pred = Res->Pred; | |
CreateAnd = true; | |
} else { | |
return nullptr; | |
} | |
} else if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) { | |
V = Trunc->getOperand(0); | |
AndMask = APInt(V->getType()->getScalarSizeInBits(), 1); | |
Pred = ICmpInst::ICMP_NE; | |
CreateAnd = !Trunc->hasNoUnsignedWrap(); | |
} else { | |
return nullptr; | |
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BTW, I am still confused by the parameter name "DecomposeBitMask". It is just what this helper function does.
I think "MatchCanonicalForm" would be better.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The function decomposes and finds better forms for comparisons by testing for bit conditions of the compared value -> "decomposeBitTest"
The new capability decomposes cases where we bitmask a value via `and` -> "decomposeBitMask"
Does this explanation help, or is it still too confusing?
Either way, "MatchCanonicalForm" seems to be much more vague and isn't really describing what the capability does.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added #136367 -- we can discuss there.
CmpPredicate Pred0, Pred1; | ||
|
||
auto LHSDecompose = | ||
decomposeBitTest(Cond0, /*LookThruTrunc=*/true, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@andjo403 Do you think it is overkill to use `decomposeBitTest`? The only things I want to match are `icmp ne/eq (and X, Power2), 0` and `trunc X to i1`. It seems that the mask of `decomposeBitTest(relational icmp)` is always non-power-of-2 :(
While and->cmp->set combines into and->mul may result in worse code on some targets, this combine should be uniformly beneficial.
https://alive2.llvm.org/ce/z/3Dnw2u