[ValueTracking] Support horizontal vector add in computeKnownBits #174410
@llvm/pr-subscribers-llvm-support @llvm/pr-subscribers-llvm-transforms

Author: Valeriy Savchenko (SavchenkoValeriy)

Changes

Alive2 proofs:

Full diff: https://github.com/llvm/llvm-project/pull/174410.diff

3 Files Affected:
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 9cb6f19b9340c..9a68df8d2b028 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -2132,6 +2132,44 @@ static void computeKnownBitsFromOperator(const Operator *I,
Known.One.clearAllBits();
break;
}
+ case Intrinsic::vector_reduce_add: {
+ auto *VecTy = dyn_cast<FixedVectorType>(I->getOperand(0)->getType());
+ if (!VecTy)
+ break;
+ computeKnownBits(I->getOperand(0), Known, Q, Depth + 1);
+ unsigned NumElems = VecTy->getNumElements();
+ // The main idea is as follows.
+ //
+ // If KnownBits for the operand of vector.reduce.add has L leading
+ // zeros, then X_i < 2^(W - L) for every i in [1, N].
+ //
+ //   SUM X_i <= SUM max(X_i) = N * max(X_i)
+ //           <  N * 2^(W - L)
+ //           <= 2^(W - L + ceil(log2(N)))
+ //
+ // As a result, we can conclude that the reduction has at least
+ //
+ //   L' = L - ceil(log2(N)) = L - bit_width(N - 1)
+ //
+ // leading zeros. Similar logic applies to leading ones.
+ unsigned LostBits = NumElems > 1 ? llvm::bit_width(NumElems - 1) : 0;
+ if (Known.isNonNegative()) {
+ unsigned LeadingZeros = Known.countMinLeadingZeros();
+ LeadingZeros = LeadingZeros > LostBits ? LeadingZeros - LostBits : 0;
+ Known.Zero.clearAllBits();
+ Known.Zero.setHighBits(LeadingZeros);
+ Known.One.clearAllBits();
+ } else if (Known.isNegative()) {
+ unsigned LeadingOnes = Known.countMinLeadingOnes();
+ LeadingOnes = LeadingOnes > LostBits ? LeadingOnes - LostBits : 0;
+ Known.One.clearAllBits();
+ Known.One.setHighBits(LeadingOnes);
+ Known.Zero.clearAllBits();
+ } else {
+ Known.resetAll();
+ }
+ break;
+ }
case Intrinsic::umin:
computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
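To make the bound concrete, here is a minimal standalone sketch (plain C++20, not LLVM code; the function name is ours) of the same leading-zero arithmetic the patch performs:

```cpp
// Minimal sketch of the patch's leading-zero bound; assumes C++20 <bit>.
#include <bit>
#include <cstdio>

unsigned reducedLeadingZeros(unsigned NumElems, unsigned LeadingZeros) {
  // bit_width(N - 1) == ceil(log2(N)) for N >= 2, matching the comment above.
  unsigned LostBits = NumElems > 1 ? std::bit_width(NumElems - 1) : 0u;
  return LeadingZeros > LostBits ? LeadingZeros - LostBits : 0;
}

int main() {
  // Mirrors the first test below: <4 x i32> lanes masked to 0x0FFFFFFF
  // have 4 leading zeros each; the reduction keeps 4 - bit_width(3) = 2,
  // so the sum is provably < 2^30 and the outer 30-bit mask is redundant.
  std::printf("%u\n", reducedLeadingZeros(4, 4)); // prints 2
}
```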
diff --git a/llvm/test/Transforms/InstCombine/vector-reduce-add-known-bits.ll b/llvm/test/Transforms/InstCombine/vector-reduce-add-known-bits.ll
new file mode 100644
index 0000000000000..60b898b492063
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/vector-reduce-add-known-bits.ll
@@ -0,0 +1,45 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define i32 @reduce_add_eliminate_mask(ptr %p) {
+; CHECK-LABEL: define i32 @reduce_add_eliminate_mask(
+; CHECK-SAME: ptr [[P:%.*]]) {
+; CHECK-NEXT: [[VEC:%.*]] = load <4 x i32>, ptr [[P]], align 16
+; CHECK-NEXT: [[AND:%.*]] = and <4 x i32> [[VEC]], splat (i32 268435455)
+; CHECK-NEXT: [[SUM:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[AND]])
+; CHECK-NEXT: ret i32 [[SUM]]
+;
+ %vec = load <4 x i32>, ptr %p
+ %and = and <4 x i32> %vec, splat (i32 268435455)
+ %sum = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %and)
+ %masked = and i32 %sum, 1073741823
+ ret i32 %masked
+}
+
+define i1 @reduce_add_simplify_comparison(ptr %p) {
+; CHECK-LABEL: define i1 @reduce_add_simplify_comparison(
+; CHECK-SAME: ptr [[P:%.*]]) {
+; CHECK-NEXT: ret i1 true
+;
+ %vec = load <8 x i32>, ptr %p
+ %and = and <8 x i32> %vec, splat (i32 16777215)
+ %sum = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %and)
+ %cmp = icmp ult i32 %sum, 134217728
+ ret i1 %cmp
+}
+
+define i64 @reduce_add_sext(ptr %p) {
+; CHECK-LABEL: define i64 @reduce_add_sext(
+; CHECK-SAME: ptr [[P:%.*]]) {
+; CHECK-NEXT: [[VEC:%.*]] = load <2 x i32>, ptr [[P]], align 8
+; CHECK-NEXT: [[AND:%.*]] = and <2 x i32> [[VEC]], splat (i32 4194303)
+; CHECK-NEXT: [[SUM:%.*]] = call i32 @llvm.vector.reduce.add.v2i32(<2 x i32> [[AND]])
+; CHECK-NEXT: [[EXT:%.*]] = zext nneg i32 [[SUM]] to i64
+; CHECK-NEXT: ret i64 [[EXT]]
+;
+ %vec = load <2 x i32>, ptr %p
+ %and = and <2 x i32> %vec, splat (i32 4194303)
+ %sum = call i32 @llvm.vector.reduce.add.v2i32(<2 x i32> %and)
+ %ext = sext i32 %sum to i64
+ ret i64 %ext
+}
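As a sanity check on the comparison test (our arithmetic, not spelled out in the PR): the mask 16777215 is $2^{24} - 1$, so every lane is below $2^{24}$, and with $N = 8$ the analysis keeps $8 - \mathrm{bit\\_width}(7) = 5$ leading zeros, giving

$$\sum_{i=1}^{8} X_i < 8 \cdot 2^{24} = 2^{27} = 134217728,$$

so `icmp ult i32 %sum, 134217728` is always true and the function folds to `ret i1 true`.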
diff --git a/llvm/test/Transforms/PhaseOrdering/AArch64/udotabd.ll b/llvm/test/Transforms/PhaseOrdering/AArch64/udotabd.ll
index 4c7e39d31b5c6..e2f7f8f7e5cac 100644
--- a/llvm/test/Transforms/PhaseOrdering/AArch64/udotabd.ll
+++ b/llvm/test/Transforms/PhaseOrdering/AArch64/udotabd.ll
@@ -29,7 +29,7 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
; CHECK-O3-NEXT: [[TMP13:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP12]], i1 false)
; CHECK-O3-NEXT: [[TMP14:%.*]] = zext <16 x i16> [[TMP13]] to <16 x i32>
; CHECK-O3-NEXT: [[TMP15:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP14]])
-; CHECK-O3-NEXT: [[OP_RDX_1:%.*]] = add i32 [[TMP15]], [[TMP7]]
+; CHECK-O3-NEXT: [[OP_RDX_1:%.*]] = add nuw nsw i32 [[TMP15]], [[TMP7]]
; CHECK-O3-NEXT: [[ADD_PTR_1:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR]], i64 [[IDX_EXT]]
; CHECK-O3-NEXT: [[ADD_PTR9_1:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9]], i64 [[IDX_EXT8]]
; CHECK-O3-NEXT: [[TMP16:%.*]] = load <16 x i8>, ptr [[ADD_PTR_1]], align 1, !tbaa [[CHAR_TBAA0]]
@@ -40,7 +40,7 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
; CHECK-O3-NEXT: [[TMP21:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP20]], i1 false)
; CHECK-O3-NEXT: [[TMP22:%.*]] = zext <16 x i16> [[TMP21]] to <16 x i32>
; CHECK-O3-NEXT: [[TMP23:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP22]])
-; CHECK-O3-NEXT: [[OP_RDX_2:%.*]] = add i32 [[TMP23]], [[OP_RDX_1]]
+; CHECK-O3-NEXT: [[OP_RDX_2:%.*]] = add nuw nsw i32 [[TMP23]], [[OP_RDX_1]]
; CHECK-O3-NEXT: [[ADD_PTR_2:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_1]], i64 [[IDX_EXT]]
; CHECK-O3-NEXT: [[ADD_PTR9_2:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_1]], i64 [[IDX_EXT8]]
; CHECK-O3-NEXT: [[TMP24:%.*]] = load <16 x i8>, ptr [[ADD_PTR_2]], align 1, !tbaa [[CHAR_TBAA0]]
@@ -51,7 +51,7 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
; CHECK-O3-NEXT: [[TMP29:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP28]], i1 false)
; CHECK-O3-NEXT: [[TMP30:%.*]] = zext <16 x i16> [[TMP29]] to <16 x i32>
; CHECK-O3-NEXT: [[TMP31:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP30]])
-; CHECK-O3-NEXT: [[OP_RDX_3:%.*]] = add i32 [[TMP31]], [[OP_RDX_2]]
+; CHECK-O3-NEXT: [[OP_RDX_3:%.*]] = add nuw nsw i32 [[TMP31]], [[OP_RDX_2]]
; CHECK-O3-NEXT: [[ADD_PTR_3:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_2]], i64 [[IDX_EXT]]
; CHECK-O3-NEXT: [[ADD_PTR9_3:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_2]], i64 [[IDX_EXT8]]
; CHECK-O3-NEXT: [[TMP32:%.*]] = load <16 x i8>, ptr [[ADD_PTR_3]], align 1, !tbaa [[CHAR_TBAA0]]
@@ -62,7 +62,7 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
; CHECK-O3-NEXT: [[TMP37:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP36]], i1 false)
; CHECK-O3-NEXT: [[TMP38:%.*]] = zext <16 x i16> [[TMP37]] to <16 x i32>
; CHECK-O3-NEXT: [[TMP39:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP38]])
-; CHECK-O3-NEXT: [[OP_RDX_4:%.*]] = add i32 [[TMP39]], [[OP_RDX_3]]
+; CHECK-O3-NEXT: [[OP_RDX_4:%.*]] = add nuw nsw i32 [[TMP39]], [[OP_RDX_3]]
; CHECK-O3-NEXT: [[ADD_PTR_4:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_3]], i64 [[IDX_EXT]]
; CHECK-O3-NEXT: [[ADD_PTR9_4:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_3]], i64 [[IDX_EXT8]]
; CHECK-O3-NEXT: [[TMP40:%.*]] = load <16 x i8>, ptr [[ADD_PTR_4]], align 1, !tbaa [[CHAR_TBAA0]]
@@ -73,7 +73,7 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
; CHECK-O3-NEXT: [[TMP45:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP44]], i1 false)
; CHECK-O3-NEXT: [[TMP46:%.*]] = zext <16 x i16> [[TMP45]] to <16 x i32>
; CHECK-O3-NEXT: [[TMP47:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP46]])
-; CHECK-O3-NEXT: [[OP_RDX_5:%.*]] = add i32 [[TMP47]], [[OP_RDX_4]]
+; CHECK-O3-NEXT: [[OP_RDX_5:%.*]] = add nuw nsw i32 [[TMP47]], [[OP_RDX_4]]
; CHECK-O3-NEXT: [[ADD_PTR_5:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_4]], i64 [[IDX_EXT]]
; CHECK-O3-NEXT: [[ADD_PTR9_5:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_4]], i64 [[IDX_EXT8]]
; CHECK-O3-NEXT: [[TMP48:%.*]] = load <16 x i8>, ptr [[ADD_PTR_5]], align 1, !tbaa [[CHAR_TBAA0]]
@@ -209,7 +209,7 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
; CHECK-LTO-NEXT: [[TMP11:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP10]], i1 true)
; CHECK-LTO-NEXT: [[TMP52:%.*]] = zext nneg <16 x i16> [[TMP11]] to <16 x i32>
; CHECK-LTO-NEXT: [[TMP60:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP52]])
-; CHECK-LTO-NEXT: [[OP_RDX_1:%.*]] = add i32 [[TMP60]], [[TMP44]]
+; CHECK-LTO-NEXT: [[OP_RDX_1:%.*]] = add nuw nsw i32 [[TMP60]], [[TMP44]]
; CHECK-LTO-NEXT: [[ADD_PTR_1:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR]], i64 [[IDX_EXT]]
; CHECK-LTO-NEXT: [[ADD_PTR9_1:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9]], i64 [[IDX_EXT8]]
; CHECK-LTO-NEXT: [[TMP12:%.*]] = load <16 x i8>, ptr [[ADD_PTR_1]], align 1, !tbaa [[CHAR_TBAA0]]
@@ -220,7 +220,7 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
; CHECK-LTO-NEXT: [[TMP17:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP16]], i1 true)
; CHECK-LTO-NEXT: [[TMP68:%.*]] = zext nneg <16 x i16> [[TMP17]] to <16 x i32>
; CHECK-LTO-NEXT: [[TMP76:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP68]])
-; CHECK-LTO-NEXT: [[OP_RDX_2:%.*]] = add i32 [[OP_RDX_1]], [[TMP76]]
+; CHECK-LTO-NEXT: [[OP_RDX_2:%.*]] = add nuw nsw i32 [[OP_RDX_1]], [[TMP76]]
; CHECK-LTO-NEXT: [[ADD_PTR_2:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_1]], i64 [[IDX_EXT]]
; CHECK-LTO-NEXT: [[ADD_PTR9_2:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_1]], i64 [[IDX_EXT8]]
; CHECK-LTO-NEXT: [[TMP18:%.*]] = load <16 x i8>, ptr [[ADD_PTR_2]], align 1, !tbaa [[CHAR_TBAA0]]
@@ -231,7 +231,7 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
; CHECK-LTO-NEXT: [[TMP23:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP22]], i1 true)
; CHECK-LTO-NEXT: [[TMP84:%.*]] = zext nneg <16 x i16> [[TMP23]] to <16 x i32>
; CHECK-LTO-NEXT: [[TMP92:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP84]])
-; CHECK-LTO-NEXT: [[OP_RDX_3:%.*]] = add i32 [[OP_RDX_2]], [[TMP92]]
+; CHECK-LTO-NEXT: [[OP_RDX_3:%.*]] = add nuw nsw i32 [[OP_RDX_2]], [[TMP92]]
; CHECK-LTO-NEXT: [[ADD_PTR_3:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_2]], i64 [[IDX_EXT]]
; CHECK-LTO-NEXT: [[ADD_PTR9_3:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_2]], i64 [[IDX_EXT8]]
; CHECK-LTO-NEXT: [[TMP24:%.*]] = load <16 x i8>, ptr [[ADD_PTR_3]], align 1, !tbaa [[CHAR_TBAA0]]
@@ -242,7 +242,7 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
; CHECK-LTO-NEXT: [[TMP29:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP28]], i1 true)
; CHECK-LTO-NEXT: [[TMP100:%.*]] = zext nneg <16 x i16> [[TMP29]] to <16 x i32>
; CHECK-LTO-NEXT: [[TMP108:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP100]])
-; CHECK-LTO-NEXT: [[OP_RDX_4:%.*]] = add i32 [[OP_RDX_3]], [[TMP108]]
+; CHECK-LTO-NEXT: [[OP_RDX_4:%.*]] = add nuw nsw i32 [[OP_RDX_3]], [[TMP108]]
; CHECK-LTO-NEXT: [[ADD_PTR_4:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_3]], i64 [[IDX_EXT]]
; CHECK-LTO-NEXT: [[ADD_PTR9_4:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_3]], i64 [[IDX_EXT8]]
; CHECK-LTO-NEXT: [[TMP30:%.*]] = load <16 x i8>, ptr [[ADD_PTR_4]], align 1, !tbaa [[CHAR_TBAA0]]
@@ -253,7 +253,7 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
; CHECK-LTO-NEXT: [[TMP35:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP34]], i1 true)
; CHECK-LTO-NEXT: [[TMP116:%.*]] = zext nneg <16 x i16> [[TMP35]] to <16 x i32>
; CHECK-LTO-NEXT: [[TMP117:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP116]])
-; CHECK-LTO-NEXT: [[OP_RDX_5:%.*]] = add i32 [[OP_RDX_4]], [[TMP117]]
+; CHECK-LTO-NEXT: [[OP_RDX_5:%.*]] = add nuw nsw i32 [[OP_RDX_4]], [[TMP117]]
; CHECK-LTO-NEXT: [[ADD_PTR_5:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_4]], i64 [[IDX_EXT]]
; CHECK-LTO-NEXT: [[ADD_PTR9_5:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_4]], i64 [[IDX_EXT8]]
; CHECK-LTO-NEXT: [[TMP37:%.*]] = load <16 x i8>, ptr [[ADD_PTR_5]], align 1, !tbaa [[CHAR_TBAA0]]
Extracted from #173069 to be an independent change.
RKSimon left a comment:
Would it make sense to move this to KnownBits itself and add an exhaustive unit test in KnownBitsTest.cpp?
At the moment, we don't support any reductions in KnownBits.cpp itself, so I'd say it would only make sense to move this if we move all of the other intrinsics there as well, and that can get quite involved. In principle, I don't oppose that change, but if we want to do it just for the sake of better testing, maybe we could instead add a known-bits printer that we can CHECK in lit tests?
Moving non-trivial logic into KnownBits is generally a good idea, because that's how we reuse logic between the middle-end and back-end implementations of computeKnownBits(). (We could use the same logic for VECREDUCE_ADD.)
So, moving just that bit of logic into KnownBits would be enough, right? And what about the signature for it? I guess it should be something like
Maybe a class method?
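For illustration only (our guess at a shape, not the signature actually proposed in the thread), such a class method could look like:

```cpp
// Hypothetical sketch, not the PR's final API: a KnownBits class method
// computing the known bits of an N-way integer add reduction from the
// known bits common to all elements.
static KnownBits reduce_add(unsigned NumElems, const KnownBits &EltKnown);
```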
✅ With the latest revision this PR passed the C/C++ code formatter.