Skip to content

Commit 8837898

Browse files
v01dXYZv01dxyz
and
v01dxyz
authored
[DAGCombine] Count leading ones: refine post DAG/Type Legalisation if promotion (#102877)
This PR is related to #99591. In this PR, instead of modifying how the legalisation occurs depending on surrounding instructions, we refine the result after legalisation. This PR has two parts: * `SDPatternMatch/MatchContext`: Slightly modify the code that matches Operands (used by `m_Node(...)`) and Unary/Binary/Ternary Patterns to make it compatible with `VPMatchContext`, where previously only `m_Opc` was supported. Some tests were added to ensure no regressions. * `DAGCombiner`: Add a `foldSubCtlzNot` which detects and rewrites the patterns using a matching context. Remaining Tasks: - [ ] GlobalISel - [ ] Currently the pattern matching will occur even before legalisation. Should I restrict it to specific stages instead? - [ ] Style: Add a visitVP_SUB? Move `foldSubCtlzNot` to another location for style consistency? @topperc --------- Co-authored-by: v01dxyz <[email protected]>
1 parent 5910e8d commit 8837898

File tree

8 files changed

+380
-24
lines changed

8 files changed

+380
-24
lines changed

Diff for: llvm/include/llvm/CodeGen/SDPatternMatch.h

+4
Original file line numberDiff line numberDiff line change
@@ -793,6 +793,10 @@ template <typename Opnd> inline UnaryOpc_match<Opnd> m_FPToSI(const Opnd &Op) {
793793
return UnaryOpc_match<Opnd>(ISD::FP_TO_SINT, Op);
794794
}
795795

796+
template <typename Opnd> inline UnaryOpc_match<Opnd> m_Ctlz(const Opnd &Op) {
797+
return UnaryOpc_match<Opnd>(ISD::CTLZ, Op);
798+
}
799+
796800
// === Constants ===
797801
struct ConstantInt_match {
798802
APInt *BindVal;

Diff for: llvm/include/llvm/IR/VPIntrinsics.def

+1
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,7 @@ VP_PROPERTY_FUNCTIONAL_INTRINSIC(ctlz)
241241
VP_PROPERTY_FUNCTIONAL_SDOPC(CTLZ)
242242
END_REGISTER_VP_SDNODE(VP_CTLZ)
243243
BEGIN_REGISTER_VP_SDNODE(VP_CTLZ_ZERO_UNDEF, -1, vp_ctlz_zero_undef, 1, 2)
244+
VP_PROPERTY_FUNCTIONAL_SDOPC(CTLZ_ZERO_UNDEF)
244245
END_REGISTER_VP_SDNODE(VP_CTLZ_ZERO_UNDEF)
245246
END_REGISTER_VP_INTRINSIC(vp_ctlz)
246247

Diff for: llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

+78
Original file line numberDiff line numberDiff line change
@@ -3764,6 +3764,79 @@ SDValue DAGCombiner::foldSubToUSubSat(EVT DstVT, SDNode *N, const SDLoc &DL) {
37643764
return SDValue();
37653765
}
37663766

3767+
// Refinement of DAG/Type Legalisation (promotion) when CTLZ is used for
3768+
// counting leading ones. Broadly, it replaces the subtraction with a left
3769+
// shift.
3770+
//
3771+
// * DAG Legalisation Pattern:
3772+
//
3773+
// (sub (ctlz (zeroextend (not Src)))
3774+
// BitWidthDiff)
3775+
//
3776+
// if BitWidthDiff == BitWidth(Node) - BitWidth(Src)
3777+
// -->
3778+
//
3779+
// (ctlz_zero_undef (not (shl (anyextend Src)
3780+
// BitWidthDiff)))
3781+
//
3782+
// * Type Legalisation Pattern:
3783+
//
3784+
// (sub (ctlz (and (xor Src XorMask)
3785+
// AndMask))
3786+
// BitWidthDiff)
3787+
//
3788+
// if AndMask has only trailing ones
3789+
// and MaskBitWidth(AndMask) == BitWidth(Node) - BitWidthDiff
3790+
// and XorMask has at least as many trailing ones as AndMask
3791+
// -->
3792+
//
3793+
// (ctlz_zero_undef (not (shl Src BitWidthDiff)))
3794+
template <class MatchContextClass>
3795+
static SDValue foldSubCtlzNot(SDNode *N, SelectionDAG &DAG) {
3796+
const SDLoc DL(N);
3797+
SDValue N0 = N->getOperand(0);
3798+
EVT VT = N0.getValueType();
3799+
unsigned BitWidth = VT.getScalarSizeInBits();
3800+
3801+
MatchContextClass Matcher(DAG, DAG.getTargetLoweringInfo(), N);
3802+
3803+
APInt AndMask;
3804+
APInt XorMask;
3805+
APInt BitWidthDiff;
3806+
3807+
SDValue CtlzOp;
3808+
SDValue Src;
3809+
3810+
if (!sd_context_match(
3811+
N, Matcher, m_Sub(m_Ctlz(m_Value(CtlzOp)), m_ConstInt(BitWidthDiff))))
3812+
return SDValue();
3813+
3814+
if (sd_context_match(CtlzOp, Matcher, m_ZExt(m_Not(m_Value(Src))))) {
3815+
// DAG Legalisation Pattern:
3816+
// (sub (ctlz (zero_extend (not Op))) BitWidthDiff)
3817+
if ((BitWidth - Src.getValueType().getScalarSizeInBits()) != BitWidthDiff)
3818+
return SDValue();
3819+
3820+
Src = DAG.getNode(ISD::ANY_EXTEND, DL, VT, Src);
3821+
} else if (sd_context_match(CtlzOp, Matcher,
3822+
m_And(m_Xor(m_Value(Src), m_ConstInt(XorMask)),
3823+
m_ConstInt(AndMask)))) {
3824+
// Type Legalisation Pattern:
3825+
// (sub (ctlz (and (xor Op XorMask) AndMask)) BitWidthDiff)
3826+
unsigned AndMaskWidth = BitWidth - BitWidthDiff.getZExtValue();
3827+
if (!(AndMask.isMask(AndMaskWidth) && XorMask.countr_one() >= AndMaskWidth))
3828+
return SDValue();
3829+
} else
3830+
return SDValue();
3831+
3832+
SDValue ShiftConst = DAG.getShiftAmountConstant(BitWidthDiff, VT, DL);
3833+
SDValue LShift = Matcher.getNode(ISD::SHL, DL, VT, Src, ShiftConst);
3834+
SDValue Not =
3835+
Matcher.getNode(ISD::XOR, DL, VT, LShift, DAG.getAllOnesConstant(DL, VT));
3836+
3837+
return Matcher.getNode(ISD::CTLZ_ZERO_UNDEF, DL, VT, Not);
3838+
}
3839+
37673840
// Since it may not be valid to emit a fold to zero for vector initializers
37683841
// check if we can before folding.
37693842
static SDValue tryFoldToZero(const SDLoc &DL, const TargetLowering &TLI, EVT VT,
@@ -3788,6 +3861,9 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
37883861
return N;
37893862
};
37903863

3864+
if (SDValue V = foldSubCtlzNot<EmptyMatchContext>(N, DAG))
3865+
return V;
3866+
37913867
// fold (sub x, x) -> 0
37923868
// FIXME: Refactor this and xor and other similar operations together.
37933869
if (PeekThroughFreeze(N0) == PeekThroughFreeze(N1))
@@ -26989,6 +27065,8 @@ SDValue DAGCombiner::visitVPOp(SDNode *N) {
2698927065
return visitVP_SELECT(N);
2699027066
case ISD::VP_MUL:
2699127067
return visitMUL<VPMatchContext>(N);
27068+
case ISD::VP_SUB:
27069+
return foldSubCtlzNot<VPMatchContext>(N, DAG);
2699227070
default:
2699327071
break;
2699427072
}

Diff for: llvm/test/CodeGen/AArch64/ctlo.ll

+132
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
2+
; RUN: llc < %s --mtriple=aarch64 -verify-machineinstrs | FileCheck %s --check-prefixes=CHECK,CHECK-SD
3+
; RUN: llc < %s --mtriple=aarch64 -global-isel -verify-machineinstrs | FileCheck %s --check-prefixes=CHECK,CHECK-GI
4+
5+
declare i8 @llvm.ctlz.i8(i8, i1)
6+
declare i16 @llvm.ctlz.i16(i16, i1)
7+
declare i32 @llvm.ctlz.i32(i32, i1)
8+
declare i64 @llvm.ctlz.i64(i64, i1)
9+
10+
define i8 @ctlo_i8(i8 %x) {
11+
; CHECK-SD-LABEL: ctlo_i8:
12+
; CHECK-SD: // %bb.0:
13+
; CHECK-SD-NEXT: mov w8, #-1 // =0xffffffff
14+
; CHECK-SD-NEXT: eor w8, w8, w0, lsl #24
15+
; CHECK-SD-NEXT: clz w0, w8
16+
; CHECK-SD-NEXT: ret
17+
;
18+
; CHECK-GI-LABEL: ctlo_i8:
19+
; CHECK-GI: // %bb.0:
20+
; CHECK-GI-NEXT: mov w8, #255 // =0xff
21+
; CHECK-GI-NEXT: bic w8, w8, w0
22+
; CHECK-GI-NEXT: clz w8, w8
23+
; CHECK-GI-NEXT: sub w0, w8, #24
24+
; CHECK-GI-NEXT: ret
25+
%tmp1 = xor i8 %x, -1
26+
%tmp2 = call i8 @llvm.ctlz.i8( i8 %tmp1, i1 false )
27+
ret i8 %tmp2
28+
}
29+
30+
define i8 @ctlo_i8_undef(i8 %x) {
31+
; CHECK-SD-LABEL: ctlo_i8_undef:
32+
; CHECK-SD: // %bb.0:
33+
; CHECK-SD-NEXT: mvn w8, w0
34+
; CHECK-SD-NEXT: lsl w8, w8, #24
35+
; CHECK-SD-NEXT: clz w0, w8
36+
; CHECK-SD-NEXT: ret
37+
;
38+
; CHECK-GI-LABEL: ctlo_i8_undef:
39+
; CHECK-GI: // %bb.0:
40+
; CHECK-GI-NEXT: mov w8, #255 // =0xff
41+
; CHECK-GI-NEXT: bic w8, w8, w0
42+
; CHECK-GI-NEXT: clz w8, w8
43+
; CHECK-GI-NEXT: sub w0, w8, #24
44+
; CHECK-GI-NEXT: ret
45+
%tmp1 = xor i8 %x, -1
46+
%tmp2 = call i8 @llvm.ctlz.i8( i8 %tmp1, i1 true )
47+
ret i8 %tmp2
48+
}
49+
50+
define i16 @ctlo_i16(i16 %x) {
51+
; CHECK-SD-LABEL: ctlo_i16:
52+
; CHECK-SD: // %bb.0:
53+
; CHECK-SD-NEXT: mov w8, #-1 // =0xffffffff
54+
; CHECK-SD-NEXT: eor w8, w8, w0, lsl #16
55+
; CHECK-SD-NEXT: clz w0, w8
56+
; CHECK-SD-NEXT: ret
57+
;
58+
; CHECK-GI-LABEL: ctlo_i16:
59+
; CHECK-GI: // %bb.0:
60+
; CHECK-GI-NEXT: mov w8, #65535 // =0xffff
61+
; CHECK-GI-NEXT: bic w8, w8, w0
62+
; CHECK-GI-NEXT: clz w8, w8
63+
; CHECK-GI-NEXT: sub w0, w8, #16
64+
; CHECK-GI-NEXT: ret
65+
%tmp1 = xor i16 %x, -1
66+
%tmp2 = call i16 @llvm.ctlz.i16( i16 %tmp1, i1 false )
67+
ret i16 %tmp2
68+
}
69+
70+
define i16 @ctlo_i16_undef(i16 %x) {
71+
; CHECK-SD-LABEL: ctlo_i16_undef:
72+
; CHECK-SD: // %bb.0:
73+
; CHECK-SD-NEXT: mvn w8, w0
74+
; CHECK-SD-NEXT: lsl w8, w8, #16
75+
; CHECK-SD-NEXT: clz w0, w8
76+
; CHECK-SD-NEXT: ret
77+
;
78+
; CHECK-GI-LABEL: ctlo_i16_undef:
79+
; CHECK-GI: // %bb.0:
80+
; CHECK-GI-NEXT: mov w8, #65535 // =0xffff
81+
; CHECK-GI-NEXT: bic w8, w8, w0
82+
; CHECK-GI-NEXT: clz w8, w8
83+
; CHECK-GI-NEXT: sub w0, w8, #16
84+
; CHECK-GI-NEXT: ret
85+
%tmp1 = xor i16 %x, -1
86+
%tmp2 = call i16 @llvm.ctlz.i16( i16 %tmp1, i1 true )
87+
ret i16 %tmp2
88+
}
89+
90+
define i32 @ctlo_i32(i32 %x) {
91+
; CHECK-LABEL: ctlo_i32:
92+
; CHECK: // %bb.0:
93+
; CHECK-NEXT: mvn w8, w0
94+
; CHECK-NEXT: clz w0, w8
95+
; CHECK-NEXT: ret
96+
%tmp1 = xor i32 %x, -1
97+
%tmp2 = call i32 @llvm.ctlz.i32( i32 %tmp1, i1 false )
98+
ret i32 %tmp2
99+
}
100+
101+
define i32 @ctlo_i32_undef(i32 %x) {
102+
; CHECK-LABEL: ctlo_i32_undef:
103+
; CHECK: // %bb.0:
104+
; CHECK-NEXT: mvn w8, w0
105+
; CHECK-NEXT: clz w0, w8
106+
; CHECK-NEXT: ret
107+
%tmp1 = xor i32 %x, -1
108+
%tmp2 = call i32 @llvm.ctlz.i32( i32 %tmp1, i1 true )
109+
ret i32 %tmp2
110+
}
111+
112+
define i64 @ctlo_i64(i64 %x) {
113+
; CHECK-LABEL: ctlo_i64:
114+
; CHECK: // %bb.0:
115+
; CHECK-NEXT: mvn x8, x0
116+
; CHECK-NEXT: clz x0, x8
117+
; CHECK-NEXT: ret
118+
%tmp1 = xor i64 %x, -1
119+
%tmp2 = call i64 @llvm.ctlz.i64( i64 %tmp1, i1 false )
120+
ret i64 %tmp2
121+
}
122+
123+
define i64 @ctlo_i64_undef(i64 %x) {
124+
; CHECK-LABEL: ctlo_i64_undef:
125+
; CHECK: // %bb.0:
126+
; CHECK-NEXT: mvn x8, x0
127+
; CHECK-NEXT: clz x0, x8
128+
; CHECK-NEXT: ret
129+
%tmp1 = xor i64 %x, -1
130+
%tmp2 = call i64 @llvm.ctlz.i64( i64 %tmp1, i1 true )
131+
ret i64 %tmp2
132+
}

Diff for: llvm/test/CodeGen/LoongArch/ctlz-cttz-ctpop.ll

+8-16
Original file line numberDiff line numberDiff line change
@@ -89,18 +89,14 @@ define i64 @test_ctlz_i64(i64 %a) nounwind {
8989
define i8 @test_not_ctlz_i8(i8 %a) nounwind {
9090
; LA32-LABEL: test_not_ctlz_i8:
9191
; LA32: # %bb.0:
92-
; LA32-NEXT: ori $a1, $zero, 255
93-
; LA32-NEXT: andn $a0, $a1, $a0
94-
; LA32-NEXT: clz.w $a0, $a0
95-
; LA32-NEXT: addi.w $a0, $a0, -24
92+
; LA32-NEXT: slli.w $a0, $a0, 24
93+
; LA32-NEXT: clo.w $a0, $a0
9694
; LA32-NEXT: ret
9795
;
9896
; LA64-LABEL: test_not_ctlz_i8:
9997
; LA64: # %bb.0:
100-
; LA64-NEXT: ori $a1, $zero, 255
101-
; LA64-NEXT: andn $a0, $a1, $a0
102-
; LA64-NEXT: clz.d $a0, $a0
103-
; LA64-NEXT: addi.d $a0, $a0, -56
98+
; LA64-NEXT: slli.d $a0, $a0, 56
99+
; LA64-NEXT: clo.d $a0, $a0
104100
; LA64-NEXT: ret
105101
%neg = xor i8 %a, -1
106102
%tmp = call i8 @llvm.ctlz.i8(i8 %neg, i1 false)
@@ -110,18 +106,14 @@ define i8 @test_not_ctlz_i8(i8 %a) nounwind {
110106
define i16 @test_not_ctlz_i16(i16 %a) nounwind {
111107
; LA32-LABEL: test_not_ctlz_i16:
112108
; LA32: # %bb.0:
113-
; LA32-NEXT: nor $a0, $a0, $zero
114-
; LA32-NEXT: bstrpick.w $a0, $a0, 15, 0
115-
; LA32-NEXT: clz.w $a0, $a0
116-
; LA32-NEXT: addi.w $a0, $a0, -16
109+
; LA32-NEXT: slli.w $a0, $a0, 16
110+
; LA32-NEXT: clo.w $a0, $a0
117111
; LA32-NEXT: ret
118112
;
119113
; LA64-LABEL: test_not_ctlz_i16:
120114
; LA64: # %bb.0:
121-
; LA64-NEXT: nor $a0, $a0, $zero
122-
; LA64-NEXT: bstrpick.d $a0, $a0, 15, 0
123-
; LA64-NEXT: clz.d $a0, $a0
124-
; LA64-NEXT: addi.d $a0, $a0, -48
115+
; LA64-NEXT: slli.d $a0, $a0, 48
116+
; LA64-NEXT: clo.d $a0, $a0
125117
; LA64-NEXT: ret
126118
%neg = xor i16 %a, -1
127119
%tmp = call i16 @llvm.ctlz.i16(i16 %neg, i1 false)

0 commit comments

Comments
 (0)