Skip to content

Commit 66454fe

Browse files
committed
[GlobalISel] Add constant matcher for APInt
1 parent c2548a8 commit 66454fe

File tree

4 files changed

+122
-19
lines changed

4 files changed

+122
-19
lines changed

llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h

Lines changed: 48 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -192,63 +192,92 @@ m_GFCstOrSplat(std::optional<FPValueAndVReg> &FPValReg) {
192192

193193
/// Matcher for a specific constant value.
194194
struct SpecificConstantMatch {
195-
int64_t RequestedVal;
196-
SpecificConstantMatch(int64_t RequestedVal) : RequestedVal(RequestedVal) {}
195+
APInt RequestedVal;
196+
SpecificConstantMatch(APInt RequestedVal) : RequestedVal(RequestedVal) {}
197197
bool match(const MachineRegisterInfo &MRI, Register Reg) {
198-
int64_t MatchedVal;
199-
return mi_match(Reg, MRI, m_ICst(MatchedVal)) && MatchedVal == RequestedVal;
198+
APInt MatchedVal;
199+
if (mi_match(Reg, MRI, m_ICst(MatchedVal))) {
200+
if (MatchedVal.getBitWidth() > RequestedVal.getBitWidth())
201+
RequestedVal = RequestedVal.sext(MatchedVal.getBitWidth());
202+
else
203+
MatchedVal = MatchedVal.sext(RequestedVal.getBitWidth());
204+
205+
return APInt::isSameValue(MatchedVal, RequestedVal);
206+
}
207+
return false;
200208
}
201209
};
202210

203211
/// Matches a constant equal to \p RequestedValue.
212+
inline SpecificConstantMatch m_SpecificICst(APInt RequestedValue) {
213+
return SpecificConstantMatch(std::move(RequestedValue));
214+
}
215+
204216
inline SpecificConstantMatch m_SpecificICst(int64_t RequestedValue) {
205-
return SpecificConstantMatch(RequestedValue);
217+
return SpecificConstantMatch(APInt(64, RequestedValue, /* isSigned */ true));
206218
}
207219

208220
/// Matcher for a specific constant splat.
209221
struct SpecificConstantSplatMatch {
210-
int64_t RequestedVal;
211-
SpecificConstantSplatMatch(int64_t RequestedVal)
212-
: RequestedVal(RequestedVal) {}
222+
APInt RequestedVal;
223+
SpecificConstantSplatMatch(APInt RequestedVal) : RequestedVal(RequestedVal) {}
213224
bool match(const MachineRegisterInfo &MRI, Register Reg) {
214225
return isBuildVectorConstantSplat(Reg, MRI, RequestedVal,
215226
/* AllowUndef */ false);
216227
}
217228
};
218229

219230
/// Matches a constant splat of \p RequestedValue.
231+
inline SpecificConstantSplatMatch m_SpecificICstSplat(APInt RequestedValue) {
232+
return SpecificConstantSplatMatch(std::move(RequestedValue));
233+
}
234+
220235
inline SpecificConstantSplatMatch m_SpecificICstSplat(int64_t RequestedValue) {
221-
return SpecificConstantSplatMatch(RequestedValue);
236+
return SpecificConstantSplatMatch(
237+
APInt(64, RequestedValue, /* isSigned */ true));
222238
}
223239

224240
/// Matcher for a specific constant or constant splat.
225241
struct SpecificConstantOrSplatMatch {
226-
int64_t RequestedVal;
227-
SpecificConstantOrSplatMatch(int64_t RequestedVal)
242+
APInt RequestedVal;
243+
SpecificConstantOrSplatMatch(APInt RequestedVal)
228244
: RequestedVal(RequestedVal) {}
229245
bool match(const MachineRegisterInfo &MRI, Register Reg) {
230-
int64_t MatchedVal;
231-
if (mi_match(Reg, MRI, m_ICst(MatchedVal)) && MatchedVal == RequestedVal)
232-
return true;
246+
APInt MatchedVal;
247+
if (mi_match(Reg, MRI, m_ICst(MatchedVal))) {
248+
if (MatchedVal.getBitWidth() > RequestedVal.getBitWidth())
249+
RequestedVal = RequestedVal.sext(MatchedVal.getBitWidth());
250+
else
251+
MatchedVal = MatchedVal.sext(RequestedVal.getBitWidth());
252+
253+
if (APInt::isSameValue(MatchedVal, RequestedVal))
254+
return true;
255+
}
233256
return isBuildVectorConstantSplat(Reg, MRI, RequestedVal,
234257
/* AllowUndef */ false);
235258
}
236259
};
237260

238261
/// Matches a \p RequestedValue constant or a constant splat of \p
239262
/// RequestedValue.
263+
inline SpecificConstantOrSplatMatch
264+
m_SpecificICstOrSplat(APInt RequestedValue) {
265+
return SpecificConstantOrSplatMatch(std::move(RequestedValue));
266+
}
267+
240268
inline SpecificConstantOrSplatMatch
241269
m_SpecificICstOrSplat(int64_t RequestedValue) {
242-
return SpecificConstantOrSplatMatch(RequestedValue);
270+
return SpecificConstantOrSplatMatch(
271+
APInt(64, RequestedValue, /* isSigned */ true));
243272
}
244273

245-
///{
246274
/// Convenience matchers for specific integer values.
247-
inline SpecificConstantMatch m_ZeroInt() { return SpecificConstantMatch(0); }
275+
inline SpecificConstantMatch m_ZeroInt() {
276+
return SpecificConstantMatch(APInt(64, 0));
277+
}
248278
inline SpecificConstantMatch m_AllOnesInt() {
249-
return SpecificConstantMatch(-1);
279+
return SpecificConstantMatch(APInt(64, -1, /* isSigned */ true));
250280
}
251-
///}
252281

253282
/// Matcher for a specific register.
254283
struct SpecificRegisterMatch {

llvm/include/llvm/CodeGen/GlobalISel/Utils.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,12 +459,24 @@ LLVM_ABI bool isBuildVectorConstantSplat(const Register Reg,
459459
const MachineRegisterInfo &MRI,
460460
int64_t SplatValue, bool AllowUndef);
461461

462+
/// Return true if the specified register is defined by G_BUILD_VECTOR or
463+
/// G_BUILD_VECTOR_TRUNC where all of the elements are \p SplatValue or undef.
464+
LLVM_ABI bool isBuildVectorConstantSplat(const Register Reg,
465+
const MachineRegisterInfo &MRI,
466+
APInt SplatValue, bool AllowUndef);
467+
462468
/// Return true if the specified instruction is a G_BUILD_VECTOR or
463469
/// G_BUILD_VECTOR_TRUNC where all of the elements are \p SplatValue or undef.
464470
LLVM_ABI bool isBuildVectorConstantSplat(const MachineInstr &MI,
465471
const MachineRegisterInfo &MRI,
466472
int64_t SplatValue, bool AllowUndef);
467473

474+
/// Return true if the specified instruction is a G_BUILD_VECTOR or
475+
/// G_BUILD_VECTOR_TRUNC where all of the elements are \p SplatValue or undef.
476+
LLVM_ABI bool isBuildVectorConstantSplat(const MachineInstr &MI,
477+
const MachineRegisterInfo &MRI,
478+
APInt SplatValue, bool AllowUndef);
479+
468480
/// Return true if the specified instruction is a G_BUILD_VECTOR or
469481
/// G_BUILD_VECTOR_TRUNC where all of the elements are 0 or undef.
470482
LLVM_ABI bool isBuildVectorAllZeros(const MachineInstr &MI,

llvm/lib/CodeGen/GlobalISel/Utils.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1401,13 +1401,35 @@ bool llvm::isBuildVectorConstantSplat(const Register Reg,
14011401
return false;
14021402
}
14031403

1404+
bool llvm::isBuildVectorConstantSplat(const Register Reg,
1405+
const MachineRegisterInfo &MRI,
1406+
APInt SplatValue, bool AllowUndef) {
1407+
if (auto SplatValAndReg = getAnyConstantSplat(Reg, MRI, AllowUndef)) {
1408+
if (SplatValAndReg->Value.getBitWidth() < SplatValue.getBitWidth())
1409+
return APInt::isSameValue(
1410+
SplatValAndReg->Value.sext(SplatValue.getBitWidth()), SplatValue);
1411+
return APInt::isSameValue(
1412+
SplatValAndReg->Value,
1413+
SplatValue.sext(SplatValAndReg->Value.getBitWidth()));
1414+
}
1415+
1416+
return false;
1417+
}
1418+
14041419
bool llvm::isBuildVectorConstantSplat(const MachineInstr &MI,
14051420
const MachineRegisterInfo &MRI,
14061421
int64_t SplatValue, bool AllowUndef) {
14071422
return isBuildVectorConstantSplat(MI.getOperand(0).getReg(), MRI, SplatValue,
14081423
AllowUndef);
14091424
}
14101425

1426+
bool llvm::isBuildVectorConstantSplat(const MachineInstr &MI,
1427+
const MachineRegisterInfo &MRI,
1428+
APInt SplatValue, bool AllowUndef) {
1429+
return isBuildVectorConstantSplat(MI.getOperand(0).getReg(), MRI, SplatValue,
1430+
AllowUndef);
1431+
}
1432+
14111433
std::optional<APInt>
14121434
llvm::getIConstantSplatVal(const Register Reg, const MachineRegisterInfo &MRI) {
14131435
if (auto SplatValAndReg =

llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -634,17 +634,25 @@ TEST_F(AArch64GISelMITest, MatchSpecificConstant) {
634634
auto FortyTwo = B.buildConstant(LLT::scalar(64), 42);
635635
EXPECT_TRUE(mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICst(42)));
636636
EXPECT_FALSE(mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICst(123)));
637+
EXPECT_TRUE(
638+
mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICst(APInt(64, 42))));
639+
EXPECT_FALSE(
640+
mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICst(APInt(64, 123))));
637641

638642
// Test that this works inside of a more complex pattern.
639643
LLT s64 = LLT::scalar(64);
640644
auto MIBAdd = B.buildAdd(s64, Copies[0], FortyTwo);
641645
EXPECT_TRUE(mi_match(MIBAdd.getReg(2), *MRI, m_SpecificICst(42)));
646+
EXPECT_TRUE(mi_match(MIBAdd.getReg(2), *MRI, m_SpecificICst(APInt(64, 42))));
642647

643648
// Wrong constant.
644649
EXPECT_FALSE(mi_match(MIBAdd.getReg(2), *MRI, m_SpecificICst(123)));
650+
EXPECT_FALSE(
651+
mi_match(MIBAdd.getReg(2), *MRI, m_SpecificICst(APInt(64, 123))));
645652

646653
// No constant on the LHS.
647654
EXPECT_FALSE(mi_match(MIBAdd.getReg(1), *MRI, m_SpecificICst(42)));
655+
EXPECT_FALSE(mi_match(MIBAdd.getReg(1), *MRI, m_SpecificICst(APInt(64, 42))));
648656
}
649657

650658
TEST_F(AArch64GISelMITest, MatchSpecificConstantSplat) {
@@ -664,6 +672,13 @@ TEST_F(AArch64GISelMITest, MatchSpecificConstantSplat) {
664672
mi_match(FortyTwoSplat.getReg(0), *MRI, m_SpecificICstSplat(43)));
665673
EXPECT_FALSE(mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICstSplat(42)));
666674

675+
EXPECT_TRUE(mi_match(FortyTwoSplat.getReg(0), *MRI,
676+
m_SpecificICstSplat(APInt(64, 42))));
677+
EXPECT_FALSE(mi_match(FortyTwoSplat.getReg(0), *MRI,
678+
m_SpecificICstSplat(APInt(64, 43))));
679+
EXPECT_FALSE(
680+
mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICstSplat(APInt(64, 42))));
681+
667682
MachineInstrBuilder NonConstantSplat =
668683
B.buildBuildVector(v4s64, {Copies[0], Copies[0], Copies[0], Copies[0]});
669684

@@ -673,8 +688,17 @@ TEST_F(AArch64GISelMITest, MatchSpecificConstantSplat) {
673688
EXPECT_FALSE(mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstSplat(43)));
674689
EXPECT_FALSE(mi_match(AddSplat.getReg(1), *MRI, m_SpecificICstSplat(42)));
675690

691+
EXPECT_TRUE(
692+
mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstSplat(APInt(64, 42))));
693+
EXPECT_FALSE(
694+
mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstSplat(APInt(64, 43))));
695+
EXPECT_FALSE(
696+
mi_match(AddSplat.getReg(1), *MRI, m_SpecificICstSplat(APInt(64, 42))));
697+
676698
MachineInstrBuilder Add = B.buildAdd(s64, Copies[0], FortyTwo);
677699
EXPECT_FALSE(mi_match(Add.getReg(2), *MRI, m_SpecificICstSplat(42)));
700+
EXPECT_FALSE(
701+
mi_match(Add.getReg(2), *MRI, m_SpecificICstSplat(APInt(64, 42))));
678702
}
679703

680704
TEST_F(AArch64GISelMITest, MatchSpecificConstantOrSplat) {
@@ -695,6 +719,13 @@ TEST_F(AArch64GISelMITest, MatchSpecificConstantOrSplat) {
695719
mi_match(FortyTwoSplat.getReg(0), *MRI, m_SpecificICstOrSplat(43)));
696720
EXPECT_TRUE(mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICstOrSplat(42)));
697721

722+
EXPECT_TRUE(mi_match(FortyTwoSplat.getReg(0), *MRI,
723+
m_SpecificICstOrSplat(APInt(64, 42))));
724+
EXPECT_FALSE(mi_match(FortyTwoSplat.getReg(0), *MRI,
725+
m_SpecificICstOrSplat(APInt(64, 43))));
726+
EXPECT_TRUE(
727+
mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICstOrSplat(APInt(64, 42))));
728+
698729
MachineInstrBuilder NonConstantSplat =
699730
B.buildBuildVector(v4s64, {Copies[0], Copies[0], Copies[0], Copies[0]});
700731

@@ -704,8 +735,17 @@ TEST_F(AArch64GISelMITest, MatchSpecificConstantOrSplat) {
704735
EXPECT_FALSE(mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstOrSplat(43)));
705736
EXPECT_FALSE(mi_match(AddSplat.getReg(1), *MRI, m_SpecificICstOrSplat(42)));
706737

738+
EXPECT_TRUE(
739+
mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstOrSplat(APInt(64, 42))));
740+
EXPECT_FALSE(
741+
mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstOrSplat(APInt(64, 43))));
742+
EXPECT_FALSE(
743+
mi_match(AddSplat.getReg(1), *MRI, m_SpecificICstOrSplat(APInt(64, 42))));
744+
707745
MachineInstrBuilder Add = B.buildAdd(s64, Copies[0], FortyTwo);
708746
EXPECT_TRUE(mi_match(Add.getReg(2), *MRI, m_SpecificICstOrSplat(42)));
747+
EXPECT_TRUE(
748+
mi_match(Add.getReg(2), *MRI, m_SpecificICstOrSplat(APInt(64, 42))));
709749
}
710750

711751
TEST_F(AArch64GISelMITest, MatchZeroInt) {

0 commit comments

Comments
 (0)