Commit c1528eb

[ValueTracking] Support horizontal vector add in computeKnownBits
1 parent 07b8aa8 · commit c1528eb

File tree: 6 files changed, +141 -10 lines changed

llvm/include/llvm/Support/KnownBits.h

Lines changed: 4 additions & 0 deletions

@@ -511,6 +511,10 @@ struct KnownBits {
   /// Compute known bits for the absolute value.
   LLVM_ABI KnownBits abs(bool IntMinIsPoison = false) const;

+  /// Compute known bits for a horizontal add over a vector with NumElts
+  /// elements, where each element has the known bits represented by this object.
+  LLVM_ABI KnownBits reduceAdd(unsigned NumElts) const;
+
   KnownBits byteSwap() const {
     return KnownBits(Zero.byteSwap(), One.byteSwap());
   }
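
As a quick illustration of the new interface (a minimal sketch, not part of the patch: the free function name and the element pattern are invented, and the numbers mirror the @reduce_add_simplify_comparison test added further down):

    // Hypothetical standalone example; assumes only the reduceAdd() declaration above.
    #include "llvm/Support/KnownBits.h"
    using namespace llvm;

    KnownBits knownBitsOfMaskedReduce() {
      // Each element of an <8 x i32> vector has been masked with 0xFFFFFF,
      // so its top 8 bits are known to be zero.
      KnownBits Elt(32);
      Elt.Zero.setHighBits(8);

      // Summing 8 such elements can consume at most bit_width(8 - 1) = 3 of
      // those leading zeros, so the reduction still has at least 5 known-zero
      // high bits, i.e. the result is provably below 2^27.
      return Elt.reduceAdd(8);
    }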

llvm/lib/Analysis/ValueTracking.cpp

Lines changed: 8 additions & 0 deletions

@@ -2132,6 +2132,14 @@ static void computeKnownBitsFromOperator(const Operator *I,
       Known.One.clearAllBits();
       break;
     }
+    case Intrinsic::vector_reduce_add: {
+      auto *VecTy = dyn_cast<FixedVectorType>(I->getOperand(0)->getType());
+      if (!VecTy)
+        break;
+      computeKnownBits(I->getOperand(0), Known, Q, Depth + 1);
+      Known = Known.reduceAdd(VecTy->getNumElements());
+      break;
+    }
     case Intrinsic::umin:
       computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
       computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);

llvm/lib/Support/KnownBits.cpp

Lines changed: 40 additions & 0 deletions

@@ -601,6 +601,46 @@ KnownBits KnownBits::abs(bool IntMinIsPoison) const {
   return KnownAbs;
 }

+KnownBits KnownBits::reduceAdd(unsigned NumElts) const {
+  if (NumElts == 0)
+    return KnownBits(getBitWidth());
+
+  unsigned BitWidth = getBitWidth();
+  KnownBits Result(BitWidth);
+
+  if (isConstant())
+    // If all elements are the same constant, we can simply compute it.
+    return KnownBits::makeConstant(NumElts * getConstant());
+
+  // The main idea is as follows.
+  //
+  // If KnownBits for each element has L leading zeros then
+  // X_i < 2^(W - L) for every i from [1, N].
+  //
+  // ADD X_i <= ADD max(X_i) = N * max(X_i)
+  //         <  N * 2^(W - L)
+  //         <= 2^(W - L + ceil(log2(N)))
+  //
+  // As a result, we can conclude that
+  //
+  // L' = L - ceil(log2(N)) = L - bit_width(N - 1)
+  //
+  // Similar logic can be applied to leading ones.
+  unsigned LostBits = NumElts > 1 ? llvm::bit_width(NumElts - 1) : 0;
+
+  if (isNonNegative()) {
+    unsigned LeadingZeros = countMinLeadingZeros();
+    LeadingZeros = LeadingZeros > LostBits ? LeadingZeros - LostBits : 0;
+    Result.Zero.setHighBits(LeadingZeros);
+  } else if (isNegative()) {
+    unsigned LeadingOnes = countMinLeadingOnes();
+    LeadingOnes = LeadingOnes > LostBits ? LeadingOnes - LostBits : 0;
+    Result.One.setHighBits(LeadingOnes);
+  }
+
+  return Result;
+}
+
 static KnownBits computeForSatAddSub(bool Add, bool Signed,
                                      const KnownBits &LHS,
                                      const KnownBits &RHS) {
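
To make the derivation in the comment concrete, here is a worked instance (a sketch using the figures from the first InstCombine test below, not text from the patch): W = 32, each element masked to 28 bits so L = 4, and N = 4 elements cost at most bit_width(4 - 1) = 2 leading zeros:

    \sum_{i=1}^{4} X_i \;<\; 4 \cdot 2^{32-4} \;=\; 2^{30},
    \qquad
    L' \;=\; 4 - \lceil \log_2 4 \rceil \;=\; 2 .

The reduction therefore fits in 30 bits, which is why the trailing and i32 %sum, 1073741823 (a 30-bit mask) in @reduce_add_eliminate_mask below can be folded away.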
(new InstCombine test file)

Lines changed: 45 additions & 0 deletions

@@ -0,0 +1,45 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define i32 @reduce_add_eliminate_mask(ptr %p) {
+; CHECK-LABEL: define i32 @reduce_add_eliminate_mask(
+; CHECK-SAME: ptr [[P:%.*]]) {
+; CHECK-NEXT: [[VEC:%.*]] = load <4 x i32>, ptr [[P]], align 16
+; CHECK-NEXT: [[AND:%.*]] = and <4 x i32> [[VEC]], splat (i32 268435455)
+; CHECK-NEXT: [[SUM:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[AND]])
+; CHECK-NEXT: ret i32 [[SUM]]
+;
+  %vec = load <4 x i32>, ptr %p
+  %and = and <4 x i32> %vec, splat (i32 268435455)
+  %sum = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %and)
+  %masked = and i32 %sum, 1073741823
+  ret i32 %masked
+}
+
+define i1 @reduce_add_simplify_comparison(ptr %p) {
+; CHECK-LABEL: define i1 @reduce_add_simplify_comparison(
+; CHECK-SAME: ptr [[P:%.*]]) {
+; CHECK-NEXT: ret i1 true
+;
+  %vec = load <8 x i32>, ptr %p
+  %and = and <8 x i32> %vec, splat (i32 16777215)
+  %sum = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %and)
+  %cmp = icmp ult i32 %sum, 134217728
+  ret i1 %cmp
+}
+
+define i64 @reduce_add_sext(ptr %p) {
+; CHECK-LABEL: define i64 @reduce_add_sext(
+; CHECK-SAME: ptr [[P:%.*]]) {
+; CHECK-NEXT: [[VEC:%.*]] = load <2 x i32>, ptr [[P]], align 8
+; CHECK-NEXT: [[AND:%.*]] = and <2 x i32> [[VEC]], splat (i32 4194303)
+; CHECK-NEXT: [[SUM:%.*]] = call i32 @llvm.vector.reduce.add.v2i32(<2 x i32> [[AND]])
+; CHECK-NEXT: [[EXT:%.*]] = zext nneg i32 [[SUM]] to i64
+; CHECK-NEXT: ret i64 [[EXT]]
+;
+  %vec = load <2 x i32>, ptr %p
+  %and = and <2 x i32> %vec, splat (i32 4194303)
+  %sum = call i32 @llvm.vector.reduce.add.v2i32(<2 x i32> %and)
+  %ext = sext i32 %sum to i64
+  ret i64 %ext
+}

llvm/test/Transforms/PhaseOrdering/AArch64/udotabd.ll

Lines changed: 10 additions & 10 deletions

@@ -29,7 +29,7 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
 ; CHECK-O3-NEXT: [[TMP13:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP12]], i1 false)
 ; CHECK-O3-NEXT: [[TMP14:%.*]] = zext <16 x i16> [[TMP13]] to <16 x i32>
 ; CHECK-O3-NEXT: [[TMP15:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP14]])
-; CHECK-O3-NEXT: [[OP_RDX_1:%.*]] = add i32 [[TMP15]], [[TMP7]]
+; CHECK-O3-NEXT: [[OP_RDX_1:%.*]] = add nuw nsw i32 [[TMP15]], [[TMP7]]
 ; CHECK-O3-NEXT: [[ADD_PTR_1:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR]], i64 [[IDX_EXT]]
 ; CHECK-O3-NEXT: [[ADD_PTR9_1:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9]], i64 [[IDX_EXT8]]
 ; CHECK-O3-NEXT: [[TMP16:%.*]] = load <16 x i8>, ptr [[ADD_PTR_1]], align 1, !tbaa [[CHAR_TBAA0]]
@@ -40,7 +40,7 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
 ; CHECK-O3-NEXT: [[TMP21:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP20]], i1 false)
 ; CHECK-O3-NEXT: [[TMP22:%.*]] = zext <16 x i16> [[TMP21]] to <16 x i32>
 ; CHECK-O3-NEXT: [[TMP23:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP22]])
-; CHECK-O3-NEXT: [[OP_RDX_2:%.*]] = add i32 [[TMP23]], [[OP_RDX_1]]
+; CHECK-O3-NEXT: [[OP_RDX_2:%.*]] = add nuw nsw i32 [[TMP23]], [[OP_RDX_1]]
 ; CHECK-O3-NEXT: [[ADD_PTR_2:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_1]], i64 [[IDX_EXT]]
 ; CHECK-O3-NEXT: [[ADD_PTR9_2:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_1]], i64 [[IDX_EXT8]]
 ; CHECK-O3-NEXT: [[TMP24:%.*]] = load <16 x i8>, ptr [[ADD_PTR_2]], align 1, !tbaa [[CHAR_TBAA0]]
@@ -51,7 +51,7 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
 ; CHECK-O3-NEXT: [[TMP29:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP28]], i1 false)
 ; CHECK-O3-NEXT: [[TMP30:%.*]] = zext <16 x i16> [[TMP29]] to <16 x i32>
 ; CHECK-O3-NEXT: [[TMP31:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP30]])
-; CHECK-O3-NEXT: [[OP_RDX_3:%.*]] = add i32 [[TMP31]], [[OP_RDX_2]]
+; CHECK-O3-NEXT: [[OP_RDX_3:%.*]] = add nuw nsw i32 [[TMP31]], [[OP_RDX_2]]
 ; CHECK-O3-NEXT: [[ADD_PTR_3:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_2]], i64 [[IDX_EXT]]
 ; CHECK-O3-NEXT: [[ADD_PTR9_3:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_2]], i64 [[IDX_EXT8]]
 ; CHECK-O3-NEXT: [[TMP32:%.*]] = load <16 x i8>, ptr [[ADD_PTR_3]], align 1, !tbaa [[CHAR_TBAA0]]
@@ -62,7 +62,7 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
 ; CHECK-O3-NEXT: [[TMP37:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP36]], i1 false)
 ; CHECK-O3-NEXT: [[TMP38:%.*]] = zext <16 x i16> [[TMP37]] to <16 x i32>
 ; CHECK-O3-NEXT: [[TMP39:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP38]])
-; CHECK-O3-NEXT: [[OP_RDX_4:%.*]] = add i32 [[TMP39]], [[OP_RDX_3]]
+; CHECK-O3-NEXT: [[OP_RDX_4:%.*]] = add nuw nsw i32 [[TMP39]], [[OP_RDX_3]]
 ; CHECK-O3-NEXT: [[ADD_PTR_4:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_3]], i64 [[IDX_EXT]]
 ; CHECK-O3-NEXT: [[ADD_PTR9_4:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_3]], i64 [[IDX_EXT8]]
 ; CHECK-O3-NEXT: [[TMP40:%.*]] = load <16 x i8>, ptr [[ADD_PTR_4]], align 1, !tbaa [[CHAR_TBAA0]]
@@ -73,7 +73,7 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
 ; CHECK-O3-NEXT: [[TMP45:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP44]], i1 false)
 ; CHECK-O3-NEXT: [[TMP46:%.*]] = zext <16 x i16> [[TMP45]] to <16 x i32>
 ; CHECK-O3-NEXT: [[TMP47:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP46]])
-; CHECK-O3-NEXT: [[OP_RDX_5:%.*]] = add i32 [[TMP47]], [[OP_RDX_4]]
+; CHECK-O3-NEXT: [[OP_RDX_5:%.*]] = add nuw nsw i32 [[TMP47]], [[OP_RDX_4]]
 ; CHECK-O3-NEXT: [[ADD_PTR_5:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_4]], i64 [[IDX_EXT]]
 ; CHECK-O3-NEXT: [[ADD_PTR9_5:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_4]], i64 [[IDX_EXT8]]
 ; CHECK-O3-NEXT: [[TMP48:%.*]] = load <16 x i8>, ptr [[ADD_PTR_5]], align 1, !tbaa [[CHAR_TBAA0]]
@@ -209,7 +209,7 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
 ; CHECK-LTO-NEXT: [[TMP11:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP10]], i1 true)
 ; CHECK-LTO-NEXT: [[TMP52:%.*]] = zext nneg <16 x i16> [[TMP11]] to <16 x i32>
 ; CHECK-LTO-NEXT: [[TMP60:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP52]])
-; CHECK-LTO-NEXT: [[OP_RDX_1:%.*]] = add i32 [[TMP60]], [[TMP44]]
+; CHECK-LTO-NEXT: [[OP_RDX_1:%.*]] = add nuw nsw i32 [[TMP60]], [[TMP44]]
 ; CHECK-LTO-NEXT: [[ADD_PTR_1:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR]], i64 [[IDX_EXT]]
 ; CHECK-LTO-NEXT: [[ADD_PTR9_1:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9]], i64 [[IDX_EXT8]]
 ; CHECK-LTO-NEXT: [[TMP12:%.*]] = load <16 x i8>, ptr [[ADD_PTR_1]], align 1, !tbaa [[CHAR_TBAA0]]
@@ -220,7 +220,7 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
 ; CHECK-LTO-NEXT: [[TMP17:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP16]], i1 true)
 ; CHECK-LTO-NEXT: [[TMP68:%.*]] = zext nneg <16 x i16> [[TMP17]] to <16 x i32>
 ; CHECK-LTO-NEXT: [[TMP76:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP68]])
-; CHECK-LTO-NEXT: [[OP_RDX_2:%.*]] = add i32 [[OP_RDX_1]], [[TMP76]]
+; CHECK-LTO-NEXT: [[OP_RDX_2:%.*]] = add nuw nsw i32 [[OP_RDX_1]], [[TMP76]]
 ; CHECK-LTO-NEXT: [[ADD_PTR_2:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_1]], i64 [[IDX_EXT]]
 ; CHECK-LTO-NEXT: [[ADD_PTR9_2:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_1]], i64 [[IDX_EXT8]]
 ; CHECK-LTO-NEXT: [[TMP18:%.*]] = load <16 x i8>, ptr [[ADD_PTR_2]], align 1, !tbaa [[CHAR_TBAA0]]
@@ -231,7 +231,7 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
 ; CHECK-LTO-NEXT: [[TMP23:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP22]], i1 true)
 ; CHECK-LTO-NEXT: [[TMP84:%.*]] = zext nneg <16 x i16> [[TMP23]] to <16 x i32>
 ; CHECK-LTO-NEXT: [[TMP92:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP84]])
-; CHECK-LTO-NEXT: [[OP_RDX_3:%.*]] = add i32 [[OP_RDX_2]], [[TMP92]]
+; CHECK-LTO-NEXT: [[OP_RDX_3:%.*]] = add nuw nsw i32 [[OP_RDX_2]], [[TMP92]]
 ; CHECK-LTO-NEXT: [[ADD_PTR_3:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_2]], i64 [[IDX_EXT]]
 ; CHECK-LTO-NEXT: [[ADD_PTR9_3:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_2]], i64 [[IDX_EXT8]]
 ; CHECK-LTO-NEXT: [[TMP24:%.*]] = load <16 x i8>, ptr [[ADD_PTR_3]], align 1, !tbaa [[CHAR_TBAA0]]
@@ -242,7 +242,7 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
 ; CHECK-LTO-NEXT: [[TMP29:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP28]], i1 true)
 ; CHECK-LTO-NEXT: [[TMP100:%.*]] = zext nneg <16 x i16> [[TMP29]] to <16 x i32>
 ; CHECK-LTO-NEXT: [[TMP108:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP100]])
-; CHECK-LTO-NEXT: [[OP_RDX_4:%.*]] = add i32 [[OP_RDX_3]], [[TMP108]]
+; CHECK-LTO-NEXT: [[OP_RDX_4:%.*]] = add nuw nsw i32 [[OP_RDX_3]], [[TMP108]]
 ; CHECK-LTO-NEXT: [[ADD_PTR_4:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_3]], i64 [[IDX_EXT]]
 ; CHECK-LTO-NEXT: [[ADD_PTR9_4:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_3]], i64 [[IDX_EXT8]]
 ; CHECK-LTO-NEXT: [[TMP30:%.*]] = load <16 x i8>, ptr [[ADD_PTR_4]], align 1, !tbaa [[CHAR_TBAA0]]
@@ -253,7 +253,7 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
 ; CHECK-LTO-NEXT: [[TMP35:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP34]], i1 true)
 ; CHECK-LTO-NEXT: [[TMP116:%.*]] = zext nneg <16 x i16> [[TMP35]] to <16 x i32>
 ; CHECK-LTO-NEXT: [[TMP117:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP116]])
-; CHECK-LTO-NEXT: [[OP_RDX_5:%.*]] = add i32 [[OP_RDX_4]], [[TMP117]]
+; CHECK-LTO-NEXT: [[OP_RDX_5:%.*]] = add nuw nsw i32 [[OP_RDX_4]], [[TMP117]]
 ; CHECK-LTO-NEXT: [[ADD_PTR_5:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_4]], i64 [[IDX_EXT]]
 ; CHECK-LTO-NEXT: [[ADD_PTR9_5:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_4]], i64 [[IDX_EXT8]]
 ; CHECK-LTO-NEXT: [[TMP37:%.*]] = load <16 x i8>, ptr [[ADD_PTR_5]], align 1, !tbaa [[CHAR_TBAA0]]

llvm/unittests/Support/KnownBitsTest.cpp

Lines changed: 34 additions & 0 deletions

@@ -845,4 +845,38 @@ TEST(KnownBitsTest, MulExhaustive) {
   }
 }

+TEST(KnownBitsTest, ReduceAddExhaustive) {
+  unsigned Bits = 4;
+  for (unsigned NumElts : {2, 4}) {
+    ForeachKnownBits(Bits, [&](const KnownBits &EltKnown) {
+      KnownBits Computed = EltKnown.reduceAdd(NumElts);
+      KnownBits Exact(Bits);
+      Exact.Zero.setAllBits();
+      Exact.One.setAllBits();
+
+      llvm::function_ref<void(unsigned, APInt)> EnumerateCombinations;
+      auto EnumerateCombinationsImpl = [&](unsigned Depth, APInt CurrentSum) {
+        if (Depth == NumElts) {
+          Exact.One &= CurrentSum;
+          Exact.Zero &= ~CurrentSum;
+          return;
+        }
+        ForeachNumInKnownBits(EltKnown, [&](const APInt &Elt) {
+          EnumerateCombinations(Depth + 1, CurrentSum + Elt);
+        });
+      };
+      EnumerateCombinations = EnumerateCombinationsImpl;
+
+      // Here we recursively generate NumElts elements matching the known bits
+      // and collect the exact known bits over all possible combinations.
+      EnumerateCombinations(0, APInt(Bits, 0));
+
+      if (!Exact.hasConflict()) {
+        EXPECT_TRUE(checkResult("reduceAdd", Exact, Computed, {EltKnown},
+                                /*CheckOptimality=*/false));
+      }
+    });
+  }
+}
+
 } // end anonymous namespace
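
For intuition about what the exhaustive check compares, here is one hypothetical 4-bit data point worked by hand (the element pattern is made up for illustration and does not appear in the test):

    // Hypothetical standalone example; not part of this commit.
    #include "llvm/ADT/APInt.h"
    #include "llvm/Support/KnownBits.h"
    using namespace llvm;

    void reduceAddDataPoint() {
      // Element pattern: bit 0 known one, bits 3..2 known zero, bit 1 unknown.
      KnownBits Elt(4);
      Elt.One = APInt(4, 0b0001);
      Elt.Zero = APInt(4, 0b1100);

      // Matching element values: {1, 3}; all two-element sums: {2, 4, 6}.
      // The exact known bits of the sum are Zero = 0b1001, One = 0b0000.
      //
      // reduceAdd(2) keeps countMinLeadingZeros() - bit_width(2 - 1) = 2 - 1
      // = 1 leading zero (Zero = 0b1000): sound but weaker than the exact
      // answer, which is why the test passes /*CheckOptimality=*/false.
      KnownBits Computed = Elt.reduceAdd(2);
      (void)Computed;
    }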
