Skip to content

[CmpInstAnalysis] Decompose icmp eq (and x, C) C2 #136367

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

jrbyrnes
Copy link
Contributor

This type of decomposition is used in multiple places already. Adding it to CmpInstAnalysis reduces code duplication.

Change-Id: I1dd786a4652ccd2e486db6903e16e58ffa1a7959
@llvmbot
Copy link
Member

llvmbot commented Apr 18, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Jeffrey Byrnes (jrbyrnes)

Changes

This type of decomposition is used in multiple places already. Adding it to CmpInstAnalysis reduces code duplication.


Full diff: https://github.com/llvm/llvm-project/pull/136367.diff

5 Files Affected:

  • (modified) llvm/include/llvm/Analysis/CmpInstAnalysis.h (+9-5)
  • (modified) llvm/lib/Analysis/CmpInstAnalysis.cpp (+26-8)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp (+4-10)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp (+15-17)
  • (modified) llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp (+5-9)
diff --git a/llvm/include/llvm/Analysis/CmpInstAnalysis.h b/llvm/include/llvm/Analysis/CmpInstAnalysis.h
index aeda58ac7535d..a796a38feff88 100644
--- a/llvm/include/llvm/Analysis/CmpInstAnalysis.h
+++ b/llvm/include/llvm/Analysis/CmpInstAnalysis.h
@@ -95,24 +95,28 @@ namespace llvm {
   /// Represents the operation icmp (X & Mask) pred C, where pred can only be
   /// eq or ne.
   struct DecomposedBitTest {
-    Value *X;
+    Value *X = nullptr;
     CmpInst::Predicate Pred;
     APInt Mask;
     APInt C;
   };
 
   /// Decompose an icmp into the form ((X & Mask) pred C) if possible.
-  /// Unless \p AllowNonZeroC is true, C will always be 0.
+  /// Unless \p AllowNonZeroC is true, C will always be 0. If \p
+  /// DecomposeBitMask is specified, then, for equality predicates, this will
+  /// decompose bitmasking (e.g. implemented via `and`).
   std::optional<DecomposedBitTest>
   decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
-                       bool LookThroughTrunc = true,
-                       bool AllowNonZeroC = false);
+                       bool LookThroughTrunc = true, bool AllowNonZeroC = false,
+                       bool DecomposeBitMask = false);
 
   /// Decompose an icmp into the form ((X & Mask) pred C) if
   /// possible. Unless \p AllowNonZeroC is true, C will always be 0.
+  /// If \p DecomposeBitMask is specified, then, for equality predicates, this
+  /// will decompose bitmasking (e.g. implemented via `and`).
   std::optional<DecomposedBitTest>
   decomposeBitTest(Value *Cond, bool LookThroughTrunc = true,
-                   bool AllowNonZeroC = false);
+                   bool AllowNonZeroC = false, bool DecomposeBitMask = false);
 
 } // end namespace llvm
 
diff --git a/llvm/lib/Analysis/CmpInstAnalysis.cpp b/llvm/lib/Analysis/CmpInstAnalysis.cpp
index 5c0d1dd1c74b0..6714088097127 100644
--- a/llvm/lib/Analysis/CmpInstAnalysis.cpp
+++ b/llvm/lib/Analysis/CmpInstAnalysis.cpp
@@ -75,11 +75,13 @@ Constant *llvm::getPredForFCmpCode(unsigned Code, Type *OpTy,
 
 std::optional<DecomposedBitTest>
 llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
-                           bool LookThruTrunc, bool AllowNonZeroC) {
+                           bool LookThruTrunc, bool AllowNonZeroC,
+                           bool DecomposeBitMask) {
   using namespace PatternMatch;
 
   const APInt *OrigC;
-  if (!ICmpInst::isRelational(Pred) || !match(RHS, m_APIntAllowPoison(OrigC)))
+  if ((ICmpInst::isEquality(Pred) && !DecomposeBitMask) ||
+      !match(RHS, m_APIntAllowPoison(OrigC)))
     return std::nullopt;
 
   bool Inverted = false;
@@ -97,9 +99,10 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
   }
 
   DecomposedBitTest Result;
+
   switch (Pred) {
   default:
-    llvm_unreachable("Unexpected predicate");
+    return std::nullopt;
   case ICmpInst::ICMP_SLT: {
     // X < 0 is equivalent to (X & SignMask) != 0.
     if (C.isZero()) {
@@ -128,7 +131,7 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
 
     return std::nullopt;
   }
-  case ICmpInst::ICMP_ULT:
+  case ICmpInst::ICMP_ULT: {
     // X <u 2^n is equivalent to (X & ~(2^n-1)) == 0.
     if (C.isPowerOf2()) {
       Result.Mask = -C;
@@ -147,6 +150,19 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
 
     return std::nullopt;
   }
+  case ICmpInst::ICMP_EQ:
+  case ICmpInst::ICMP_NE: {
+    assert(DecomposeBitMask);
+    const APInt *AndC;
+    Value *AndVal;
+    if (match(LHS, m_And(m_Value(AndVal), m_APIntAllowPoison(AndC)))) {
+      Result = {AndVal /*X*/, Pred /*Pred*/, *AndC /*Mask*/, *OrigC /*C*/};
+      break;
+    }
+
+    return std::nullopt;
+  }
+  }
 
   if (!AllowNonZeroC && !Result.C.isZero())
     return std::nullopt;
@@ -159,15 +175,17 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
     Result.X = X;
     Result.Mask = Result.Mask.zext(X->getType()->getScalarSizeInBits());
     Result.C = Result.C.zext(X->getType()->getScalarSizeInBits());
-  } else {
+  } else if (!Result.X) {
     Result.X = LHS;
   }
 
   return Result;
 }
 
-std::optional<DecomposedBitTest>
-llvm::decomposeBitTest(Value *Cond, bool LookThruTrunc, bool AllowNonZeroC) {
+std::optional<DecomposedBitTest> llvm::decomposeBitTest(Value *Cond,
+                                                        bool LookThruTrunc,
+                                                        bool AllowNonZeroC,
+                                                        bool DecomposeBitMask) {
   using namespace PatternMatch;
   if (auto *ICmp = dyn_cast<ICmpInst>(Cond)) {
     // Don't allow pointers. Splat vectors are fine.
@@ -175,7 +193,7 @@ llvm::decomposeBitTest(Value *Cond, bool LookThruTrunc, bool AllowNonZeroC) {
       return std::nullopt;
     return decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
                                 ICmp->getPredicate(), LookThruTrunc,
-                                AllowNonZeroC);
+                                AllowNonZeroC, DecomposeBitMask);
   }
   Value *X;
   if (Cond->getType()->isIntOrIntVectorTy(1) &&
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 19bf81137aab7..60461a6e8b338 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -875,22 +875,16 @@ static Value *foldSignedTruncationCheck(ICmpInst *ICmp0, ICmpInst *ICmp1,
                            APInt &UnsetBitsMask) -> bool {
     CmpPredicate Pred = ICmp->getPredicate();
     // Can it be decomposed into  icmp eq (X & Mask), 0  ?
-    auto Res =
-        llvm::decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
-                                   Pred, /*LookThroughTrunc=*/false);
+    auto Res = llvm::decomposeBitTestICmp(
+        ICmp->getOperand(0), ICmp->getOperand(1), Pred,
+        /*LookThroughTrunc=*/true, /*AllowNonZeroC=*/false,
+        /*DecomposeBitMask=*/true);
     if (Res && Res->Pred == ICmpInst::ICMP_EQ) {
       X = Res->X;
       UnsetBitsMask = Res->Mask;
       return true;
     }
 
-    // Is it  icmp eq (X & Mask), 0  already?
-    const APInt *Mask;
-    if (match(ICmp, m_ICmp(Pred, m_And(m_Value(X), m_APInt(Mask)), m_Zero())) &&
-        Pred == ICmpInst::ICMP_EQ) {
-      UnsetBitsMask = *Mask;
-      return true;
-    }
     return false;
   };
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 4bba2f406b4c1..8cdb00fb44a48 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3730,31 +3730,29 @@ static Value *foldSelectBitTest(SelectInst &Sel, Value *CondVal, Value *TrueVal,
   Value *CmpLHS, *CmpRHS;
 
   if (match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) {
-    if (ICmpInst::isEquality(Pred)) {
-      if (!match(CmpRHS, m_Zero()))
-        return nullptr;
+    auto Res = decomposeBitTestICmp(
+        CmpLHS, CmpRHS, Pred, /*LookThroughTrunc=*/true,
+        /*AllowNonZeroC=*/false, /*DecomposeBitMask=*/true);
 
-      V = CmpLHS;
-      const APInt *AndRHS;
-      if (!match(CmpLHS, m_And(m_Value(), m_Power2(AndRHS))))
-        return nullptr;
+    if (!Res)
+      return nullptr;
 
-      AndMask = *AndRHS;
-    } else if (auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred)) {
-      assert(ICmpInst::isEquality(Res->Pred) && "Not equality test?");
-      AndMask = Res->Mask;
+    V = CmpLHS;
+    AndMask = Res->Mask;
+
+    if (!ICmpInst::isEquality(Pred)) {
       V = Res->X;
       KnownBits Known =
           computeKnownBits(V, /*Depth=*/0, SQ.getWithInstruction(&Sel));
       AndMask &= Known.getMaxValue();
-      if (!AndMask.isPowerOf2())
-        return nullptr;
-
-      Pred = Res->Pred;
       CreateAnd = true;
-    } else {
-      return nullptr;
     }
+
+    Pred = Res->Pred;
+
+    if (!AndMask.isPowerOf2())
+      return nullptr;
+
   } else if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) {
     V = Trunc->getOperand(0);
     AndMask = APInt(V->getType()->getScalarSizeInBits(), 1);
diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
index eacc67cd5a475..ced1557ec3060 100644
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -2761,14 +2761,11 @@ static bool detectShiftUntilBitTestIdiom(Loop *CurLoop, Value *&BaseX,
                              m_LoopInvariant(m_Shl(m_One(), m_Value(BitPos)),
                                              CurLoop))));
   };
-  auto MatchConstantBitMask = [&]() {
-    return ICmpInst::isEquality(Pred) && match(CmpRHS, m_Zero()) &&
-           match(CmpLHS, m_And(m_Value(CurrX),
-                               m_CombineAnd(m_Value(BitMask), m_Power2()))) &&
-           (BitPos = ConstantExpr::getExactLogBase2(cast<Constant>(BitMask)));
-  };
+
   auto MatchDecomposableConstantBitMask = [&]() {
-    auto Res = llvm::decomposeBitTestICmp(CmpLHS, CmpRHS, Pred);
+    auto Res = llvm::decomposeBitTestICmp(
+        CmpLHS, CmpRHS, Pred, /*LookThroughTrunc=*/true,
+        /*AllowNonZeroC=*/false, /*DecomposeBitMask=*/true);
     if (Res && Res->Mask.isPowerOf2()) {
       assert(ICmpInst::isEquality(Res->Pred));
       Pred = Res->Pred;
@@ -2780,8 +2777,7 @@ static bool detectShiftUntilBitTestIdiom(Loop *CurLoop, Value *&BaseX,
     return false;
   };
 
-  if (!MatchVariableBitMask() && !MatchConstantBitMask() &&
-      !MatchDecomposableConstantBitMask()) {
+  if (!MatchVariableBitMask() && !MatchDecomposableConstantBitMask()) {
     LLVM_DEBUG(dbgs() << DEBUG_TYPE " Bad backedge comparison.\n");
     return false;
   }

@llvmbot
Copy link
Member

llvmbot commented Apr 18, 2025

@llvm/pr-subscribers-llvm-analysis

Author: Jeffrey Byrnes (jrbyrnes)

Changes

This type of decomposition is used in multiple places already. Adding it to CmpInstAnalysis reduces code duplication.


Full diff: https://github.com/llvm/llvm-project/pull/136367.diff

5 Files Affected:

  • (modified) llvm/include/llvm/Analysis/CmpInstAnalysis.h (+9-5)
  • (modified) llvm/lib/Analysis/CmpInstAnalysis.cpp (+26-8)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp (+4-10)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp (+15-17)
  • (modified) llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp (+5-9)
diff --git a/llvm/include/llvm/Analysis/CmpInstAnalysis.h b/llvm/include/llvm/Analysis/CmpInstAnalysis.h
index aeda58ac7535d..a796a38feff88 100644
--- a/llvm/include/llvm/Analysis/CmpInstAnalysis.h
+++ b/llvm/include/llvm/Analysis/CmpInstAnalysis.h
@@ -95,24 +95,28 @@ namespace llvm {
   /// Represents the operation icmp (X & Mask) pred C, where pred can only be
   /// eq or ne.
   struct DecomposedBitTest {
-    Value *X;
+    Value *X = nullptr;
     CmpInst::Predicate Pred;
     APInt Mask;
     APInt C;
   };
 
   /// Decompose an icmp into the form ((X & Mask) pred C) if possible.
-  /// Unless \p AllowNonZeroC is true, C will always be 0.
+  /// Unless \p AllowNonZeroC is true, C will always be 0. If \p
+  /// DecomposeBitMask is specified, then, for equality predicates, this will
+  /// decompose bitmasking (e.g. implemented via `and`).
   std::optional<DecomposedBitTest>
   decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
-                       bool LookThroughTrunc = true,
-                       bool AllowNonZeroC = false);
+                       bool LookThroughTrunc = true, bool AllowNonZeroC = false,
+                       bool DecomposeBitMask = false);
 
   /// Decompose an icmp into the form ((X & Mask) pred C) if
   /// possible. Unless \p AllowNonZeroC is true, C will always be 0.
+  /// If \p DecomposeBitMask is specified, then, for equality predicates, this
+  /// will decompose bitmasking (e.g. implemented via `and`).
   std::optional<DecomposedBitTest>
   decomposeBitTest(Value *Cond, bool LookThroughTrunc = true,
-                   bool AllowNonZeroC = false);
+                   bool AllowNonZeroC = false, bool DecomposeBitMask = false);
 
 } // end namespace llvm
 
diff --git a/llvm/lib/Analysis/CmpInstAnalysis.cpp b/llvm/lib/Analysis/CmpInstAnalysis.cpp
index 5c0d1dd1c74b0..6714088097127 100644
--- a/llvm/lib/Analysis/CmpInstAnalysis.cpp
+++ b/llvm/lib/Analysis/CmpInstAnalysis.cpp
@@ -75,11 +75,13 @@ Constant *llvm::getPredForFCmpCode(unsigned Code, Type *OpTy,
 
 std::optional<DecomposedBitTest>
 llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
-                           bool LookThruTrunc, bool AllowNonZeroC) {
+                           bool LookThruTrunc, bool AllowNonZeroC,
+                           bool DecomposeBitMask) {
   using namespace PatternMatch;
 
   const APInt *OrigC;
-  if (!ICmpInst::isRelational(Pred) || !match(RHS, m_APIntAllowPoison(OrigC)))
+  if ((ICmpInst::isEquality(Pred) && !DecomposeBitMask) ||
+      !match(RHS, m_APIntAllowPoison(OrigC)))
     return std::nullopt;
 
   bool Inverted = false;
@@ -97,9 +99,10 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
   }
 
   DecomposedBitTest Result;
+
   switch (Pred) {
   default:
-    llvm_unreachable("Unexpected predicate");
+    return std::nullopt;
   case ICmpInst::ICMP_SLT: {
     // X < 0 is equivalent to (X & SignMask) != 0.
     if (C.isZero()) {
@@ -128,7 +131,7 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
 
     return std::nullopt;
   }
-  case ICmpInst::ICMP_ULT:
+  case ICmpInst::ICMP_ULT: {
     // X <u 2^n is equivalent to (X & ~(2^n-1)) == 0.
     if (C.isPowerOf2()) {
       Result.Mask = -C;
@@ -147,6 +150,19 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
 
     return std::nullopt;
   }
+  case ICmpInst::ICMP_EQ:
+  case ICmpInst::ICMP_NE: {
+    assert(DecomposeBitMask);
+    const APInt *AndC;
+    Value *AndVal;
+    if (match(LHS, m_And(m_Value(AndVal), m_APIntAllowPoison(AndC)))) {
+      Result = {AndVal /*X*/, Pred /*Pred*/, *AndC /*Mask*/, *OrigC /*C*/};
+      break;
+    }
+
+    return std::nullopt;
+  }
+  }
 
   if (!AllowNonZeroC && !Result.C.isZero())
     return std::nullopt;
@@ -159,15 +175,17 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
     Result.X = X;
     Result.Mask = Result.Mask.zext(X->getType()->getScalarSizeInBits());
     Result.C = Result.C.zext(X->getType()->getScalarSizeInBits());
-  } else {
+  } else if (!Result.X) {
     Result.X = LHS;
   }
 
   return Result;
 }
 
-std::optional<DecomposedBitTest>
-llvm::decomposeBitTest(Value *Cond, bool LookThruTrunc, bool AllowNonZeroC) {
+std::optional<DecomposedBitTest> llvm::decomposeBitTest(Value *Cond,
+                                                        bool LookThruTrunc,
+                                                        bool AllowNonZeroC,
+                                                        bool DecomposeBitMask) {
   using namespace PatternMatch;
   if (auto *ICmp = dyn_cast<ICmpInst>(Cond)) {
     // Don't allow pointers. Splat vectors are fine.
@@ -175,7 +193,7 @@ llvm::decomposeBitTest(Value *Cond, bool LookThruTrunc, bool AllowNonZeroC) {
       return std::nullopt;
     return decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
                                 ICmp->getPredicate(), LookThruTrunc,
-                                AllowNonZeroC);
+                                AllowNonZeroC, DecomposeBitMask);
   }
   Value *X;
   if (Cond->getType()->isIntOrIntVectorTy(1) &&
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 19bf81137aab7..60461a6e8b338 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -875,22 +875,16 @@ static Value *foldSignedTruncationCheck(ICmpInst *ICmp0, ICmpInst *ICmp1,
                            APInt &UnsetBitsMask) -> bool {
     CmpPredicate Pred = ICmp->getPredicate();
     // Can it be decomposed into  icmp eq (X & Mask), 0  ?
-    auto Res =
-        llvm::decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
-                                   Pred, /*LookThroughTrunc=*/false);
+    auto Res = llvm::decomposeBitTestICmp(
+        ICmp->getOperand(0), ICmp->getOperand(1), Pred,
+        /*LookThroughTrunc=*/true, /*AllowNonZeroC=*/false,
+        /*DecomposeBitMask=*/true);
     if (Res && Res->Pred == ICmpInst::ICMP_EQ) {
       X = Res->X;
       UnsetBitsMask = Res->Mask;
       return true;
     }
 
-    // Is it  icmp eq (X & Mask), 0  already?
-    const APInt *Mask;
-    if (match(ICmp, m_ICmp(Pred, m_And(m_Value(X), m_APInt(Mask)), m_Zero())) &&
-        Pred == ICmpInst::ICMP_EQ) {
-      UnsetBitsMask = *Mask;
-      return true;
-    }
     return false;
   };
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 4bba2f406b4c1..8cdb00fb44a48 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3730,31 +3730,29 @@ static Value *foldSelectBitTest(SelectInst &Sel, Value *CondVal, Value *TrueVal,
   Value *CmpLHS, *CmpRHS;
 
   if (match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) {
-    if (ICmpInst::isEquality(Pred)) {
-      if (!match(CmpRHS, m_Zero()))
-        return nullptr;
+    auto Res = decomposeBitTestICmp(
+        CmpLHS, CmpRHS, Pred, /*LookThroughTrunc=*/true,
+        /*AllowNonZeroC=*/false, /*DecomposeBitMask=*/true);
 
-      V = CmpLHS;
-      const APInt *AndRHS;
-      if (!match(CmpLHS, m_And(m_Value(), m_Power2(AndRHS))))
-        return nullptr;
+    if (!Res)
+      return nullptr;
 
-      AndMask = *AndRHS;
-    } else if (auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred)) {
-      assert(ICmpInst::isEquality(Res->Pred) && "Not equality test?");
-      AndMask = Res->Mask;
+    V = CmpLHS;
+    AndMask = Res->Mask;
+
+    if (!ICmpInst::isEquality(Pred)) {
       V = Res->X;
       KnownBits Known =
           computeKnownBits(V, /*Depth=*/0, SQ.getWithInstruction(&Sel));
       AndMask &= Known.getMaxValue();
-      if (!AndMask.isPowerOf2())
-        return nullptr;
-
-      Pred = Res->Pred;
       CreateAnd = true;
-    } else {
-      return nullptr;
     }
+
+    Pred = Res->Pred;
+
+    if (!AndMask.isPowerOf2())
+      return nullptr;
+
   } else if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) {
     V = Trunc->getOperand(0);
     AndMask = APInt(V->getType()->getScalarSizeInBits(), 1);
diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
index eacc67cd5a475..ced1557ec3060 100644
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -2761,14 +2761,11 @@ static bool detectShiftUntilBitTestIdiom(Loop *CurLoop, Value *&BaseX,
                              m_LoopInvariant(m_Shl(m_One(), m_Value(BitPos)),
                                              CurLoop))));
   };
-  auto MatchConstantBitMask = [&]() {
-    return ICmpInst::isEquality(Pred) && match(CmpRHS, m_Zero()) &&
-           match(CmpLHS, m_And(m_Value(CurrX),
-                               m_CombineAnd(m_Value(BitMask), m_Power2()))) &&
-           (BitPos = ConstantExpr::getExactLogBase2(cast<Constant>(BitMask)));
-  };
+
   auto MatchDecomposableConstantBitMask = [&]() {
-    auto Res = llvm::decomposeBitTestICmp(CmpLHS, CmpRHS, Pred);
+    auto Res = llvm::decomposeBitTestICmp(
+        CmpLHS, CmpRHS, Pred, /*LookThroughTrunc=*/true,
+        /*AllowNonZeroC=*/false, /*DecomposeBitMask=*/true);
     if (Res && Res->Mask.isPowerOf2()) {
       assert(ICmpInst::isEquality(Res->Pred));
       Pred = Res->Pred;
@@ -2780,8 +2777,7 @@ static bool detectShiftUntilBitTestIdiom(Loop *CurLoop, Value *&BaseX,
     return false;
   };
 
-  if (!MatchVariableBitMask() && !MatchConstantBitMask() &&
-      !MatchDecomposableConstantBitMask()) {
+  if (!MatchVariableBitMask() && !MatchDecomposableConstantBitMask()) {
     LLVM_DEBUG(dbgs() << DEBUG_TYPE " Bad backedge comparison.\n");
     return false;
   }

@jrbyrnes
Copy link
Contributor Author

It's NFC for the current users.

bool LookThroughTrunc = true,
bool AllowNonZeroC = false);
bool LookThroughTrunc = true, bool AllowNonZeroC = false,
bool DecomposeBitMask = false);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is concern about naming of DecomposeBitMask -- see #135274

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants