Skip to content

Commit 16b5990

Browse files
committed
[AIE2P] Combine shuffle to broadcast + insertions
1 parent 4e36470 commit 16b5990

File tree

6 files changed

+439
-83
lines changed

6 files changed

+439
-83
lines changed

llvm/lib/Target/AIE/AIECombine.td

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,12 @@ def combine_paired_extracts : GICombineRule<
126126
[{ return matchPairedExtracts(*${root}, MRI, Helper, (const AIEBaseInstrInfo &)B.getTII(), ${matchinfo}); }]),
127127
(apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;
128128

129+
// Combine rule rooted at a G_SHUFFLE_VECTOR instruction. The match step
// delegates to matchShuffleToExtractInsertEltToBroadcast, which fills
// $matchinfo with a build function on success; the apply step runs that build
// function through Helper.applyBuildFn.
def combine_vector_shuffle_to_extract_insert_elt_to_broadcast : GICombineRule<
130+
(defs root:$root, build_fn_matchinfo:$matchinfo),
131+
(match (wip_match_opcode G_SHUFFLE_VECTOR): $root,
132+
[{ return matchShuffleToExtractInsertEltToBroadcast(*${root}, MRI, ${matchinfo}); }]),
133+
(apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;
134+
129135
def AIE2PreLegalizerCombiner
130136
: GICombiner<"AIE2PreLegalizerCombinerImpl", [ combine_unpad_vector, combine_pad_vector,
131137
all_combines, combine_S20NarrowingOpt,
@@ -152,7 +158,8 @@ def AIE2PPreLegalizerCombiner
152158
combine_single_diff_build_vector,
153159
combine_vector_shuffle_to_extract_insert_elt,
154160
combine_vector_shuffle_concat_extracted_subvectors,
155-
combine_paired_extracts]> {
161+
combine_paired_extracts,
162+
combine_vector_shuffle_to_extract_insert_elt_to_broadcast]> {
156163
let CombineAllMethodName = "tryCombineAllImpl";
157164
}
158165

llvm/lib/Target/AIE/AIECombinerHelper.cpp

Lines changed: 170 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,17 @@ cl::opt<bool> CombineVecShiftByZero(
5757
"aie-combine-vec-shift-by-zero", cl::init(true), cl::Hidden,
5858
cl::desc("Combine vectors shift by zero into copies."));
5959

60+
static unsigned getNumMaskUndefs(const ArrayRef<int> &Mask,
61+
unsigned StartIndex) {
62+
unsigned Count = 0;
63+
for (unsigned I = StartIndex; I < Mask.size(); ++I) {
64+
if (Mask[I] == -1) {
65+
++Count;
66+
}
67+
}
68+
return Count;
69+
}
70+
6071
bool MaskMatch::isValidMask(const ArrayRef<int> Mask) const {
6172
for (unsigned Idx = 0; Idx < Mask.size(); ++Idx) {
6273
if (Mask[Idx] == -1)
@@ -136,6 +147,55 @@ std::optional<int> MaskMatch::getUniqueIndex(ArrayRef<int> Mask) {
136147
return UniqOpIdx;
137148
}
138149

150+
static std::unordered_map<int, unsigned>
151+
getMaskFrequencyMap(const ArrayRef<int> Mask) {
152+
assert(!MaskMatch::isMaskWithAllUndefs(Mask));
153+
std::unordered_map<int, unsigned> FrequencyMap;
154+
for (int Idx : Mask) {
155+
if (Idx == -1)
156+
continue;
157+
FrequencyMap[Idx]++;
158+
}
159+
return FrequencyMap;
160+
}
161+
162+
std::optional<FrequentIndexResult>
163+
MaskMatch::getFrequentIndexResult(const ArrayRef<int> Mask,
164+
unsigned MinFrequency = 0) {
165+
166+
// Set the default value for MinFrequency
167+
if (MinFrequency == 0) {
168+
MinFrequency = Mask.size() / 2;
169+
}
170+
171+
std::unordered_map<int, unsigned> FrequencyMap = getMaskFrequencyMap(Mask);
172+
unsigned DontCareCount = getNumMaskUndefs(Mask, 0);
173+
174+
auto [FrequentValue, HighestFrequency] = *std::max_element(
175+
FrequencyMap.begin(), FrequencyMap.end(),
176+
[](const std::pair<int, unsigned> p1, const std::pair<int, unsigned> p2) {
177+
return p1.second < p2.second;
178+
});
179+
180+
unsigned HighestAdjustedFrequency = HighestFrequency + DontCareCount;
181+
if (HighestAdjustedFrequency < MinFrequency) {
182+
return std::nullopt;
183+
}
184+
185+
unsigned NonMatchingCount = Mask.size() - HighestAdjustedFrequency;
186+
187+
unsigned FrequentIdx = 0;
188+
for (unsigned I = 0; I < Mask.size(); I++) {
189+
int MaskValue = Mask[I];
190+
if (MaskValue == FrequentValue) {
191+
FrequentIdx = I;
192+
break;
193+
}
194+
}
195+
196+
return FrequentIndexResult{FrequentIdx, NonMatchingCount};
197+
}
198+
139199
MachineInstr *findPreIncMatch(MachineInstr &MemI, MachineRegisterInfo &MRI,
140200
CombinerHelper &Helper,
141201
AIELoadStoreCombineMatchData &MatchData,
@@ -1903,17 +1963,6 @@ static bool isPowerOfTwoOrZero(unsigned Height) {
19031963
return Height == 0 || (Height > 1 && has_single_bit(Height));
19041964
}
19051965

1906-
static unsigned getNumMaskUndefs(const ArrayRef<int> &Mask,
1907-
unsigned StartIndex) {
1908-
unsigned Count = 0;
1909-
for (unsigned I = StartIndex; I < Mask.size(); ++I) {
1910-
if (Mask[I] == -1) {
1911-
++Count;
1912-
}
1913-
}
1914-
return Count;
1915-
}
1916-
19171966
/// \returns true if it is possible to combine the below sequence of MIRs
19181967
/// into a COPY.
19191968
/// From : %1:_(<64 x s8>) = G_IMPLICIT_DEF
@@ -2592,6 +2641,114 @@ bool llvm::matchShuffleToCopy(MachineInstr &MI, MachineRegisterInfo &MRI,
25922641
return true;
25932642
}
25942643

2644+
/// Match something like this:
2645+
/// %2:_(<16 x s32>) = G_SHUFFLE_VECTOR %0(<16 x s32>), %1(<16 x s32>),
2646+
/// shufflemask(16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)
2647+
2648+
/// To convert to:
2649+
// %4:_(s32) = G_CONSTANT i32 0
2650+
// %5:_(s32) = G_EXTRACT_VECTOR_ELT %0(<16 x s32>), %4(s32)
2651+
// %3:_(<16 x s32>) = G_AIE_BROADCAST_VECTOR %5(s32)ab
2652+
// %6:_(s32) = G_EXTRACT_VECTOR_ELT %1(<16 x s32>), %4(s32)
2653+
// %2:_(<16 x s32>) = G_INSERT_VECTOR_ELT %3, %6(s32), %4(s32)
2654+
bool llvm::matchShuffleToExtractInsertEltToBroadcast(MachineInstr &MI,
2655+
MachineRegisterInfo &MRI,
2656+
BuildFnTy &MatchInfo) {
2657+
2658+
assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
2659+
2660+
const Register DstReg = MI.getOperand(0).getReg();
2661+
const Register Src1Reg = MI.getOperand(1).getReg();
2662+
const Register Src2Reg = MI.getOperand(2).getReg();
2663+
ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();
2664+
2665+
const LLT DstTy = MRI.getType(DstReg);
2666+
const LLT Src1Ty = MRI.getType(Src1Reg);
2667+
if (DstTy != Src1Ty)
2668+
return false;
2669+
2670+
if (!DstTy.isVector() || !Src1Ty.isVector())
2671+
return false;
2672+
2673+
if (DstTy.getSizeInBits() < 128)
2674+
return false;
2675+
2676+
const unsigned NumSrcElems = Src1Ty.getNumElements();
2677+
const LLT DstElemTy = MRI.getType(Src1Reg).getElementType();
2678+
2679+
if (Mask.size() != NumSrcElems)
2680+
return false;
2681+
2682+
if (MaskMatch::isMaskWithAllUndefs(Mask))
2683+
return false;
2684+
2685+
unsigned MinFrequency;
2686+
if (MI.getMF()->getTarget().getTargetTriple().isAIE2P())
2687+
// The scalarization of G_SHUFFLE_VECTOR in the legalizer is more beneficial
2688+
// if there are more exceptions than NumSrcElems / 2 as AIE2P's VINSERT
2689+
// instrutions require a move to a register used for the index unlike VPUSH.
2690+
MinFrequency = (ShuffleMaxNumInsertions != 0) ? ShuffleMaxNumInsertions
2691+
: NumSrcElems / 2;
2692+
else
2693+
llvm_unreachable("MinFrequency unimplemented for target.");
2694+
2695+
std::optional<FrequentIndexResult> FrequentIdxResult =
2696+
MaskMatch::getFrequentIndexResult(Mask, MinFrequency);
2697+
2698+
if (!FrequentIdxResult)
2699+
return false;
2700+
2701+
unsigned FrequentIdx = FrequentIdxResult->FrequentIdx;
2702+
unsigned NonMatchingCount = FrequentIdxResult->NonMatchingCount;
2703+
2704+
// This is a pure broadcast pattern. Should be handled by
2705+
// matchShuffleToVecEltBroadcast combine
2706+
if (NonMatchingCount == 0)
2707+
return false;
2708+
2709+
int BcstValue = Mask[FrequentIdx];
2710+
2711+
MatchInfo = [=, &MRI](MachineIRBuilder &B) {
2712+
Register BroadcastVecReg = MRI.createGenericVirtualRegister(Src1Ty);
2713+
Register VecToExtract = BcstValue < (int)NumSrcElems ? Src1Reg : Src2Reg;
2714+
auto Extr = B.buildExtractVectorElementConstant(DstElemTy, VecToExtract,
2715+
BcstValue % NumSrcElems);
2716+
buildBroadcastVector(B, MRI, Extr.getReg(0), BroadcastVecReg);
2717+
2718+
Register InsertSrc = BroadcastVecReg;
2719+
Register InsertDst;
2720+
2721+
unsigned InsertionCount = 0;
2722+
for (unsigned Idx = 0; Idx < Mask.size(); ++Idx) {
2723+
2724+
if (Mask[Idx] == BcstValue)
2725+
continue;
2726+
2727+
if (Mask[Idx] == -1)
2728+
continue;
2729+
2730+
Register VecToExtract = Mask[Idx] < (int)NumSrcElems ? Src1Reg : Src2Reg;
2731+
2732+
int ExtractIdx = Mask[Idx] % NumSrcElems;
2733+
auto ExtrElt = B.buildExtractVectorElementConstant(
2734+
DstElemTy, VecToExtract, ExtractIdx);
2735+
2736+
auto NonMatchingIdxReg = B.buildConstant(LLT::scalar(32), Idx);
2737+
2738+
InsertDst = (InsertionCount == NonMatchingCount - 1)
2739+
? DstReg
2740+
: MRI.createGenericVirtualRegister(Src1Ty);
2741+
2742+
B.buildInsertVectorElement(InsertDst, InsertSrc, ExtrElt,
2743+
NonMatchingIdxReg);
2744+
InsertSrc = InsertDst;
2745+
InsertionCount++;
2746+
}
2747+
};
2748+
2749+
return true;
2750+
}
2751+
25952752
/// Match something like this:
25962753
/// %0:_(<32 x s16>) = COPY $x0
25972754
/// %1:_(<32 x s16>) = COPY $x1
@@ -2625,7 +2782,8 @@ bool llvm::matchShuffleToExtractInsertElt(MachineInstr &MI,
26252782
if (MI.getMF()->getTarget().getTargetTriple().isAIE2P())
26262783
// The scalarization of G_SHUFFLE_VECTOR in the legalizer is more beneficial
26272784
// if there are more exceptions than NumSrcElems / 2 as AIE2P's VINSERT
2628-
// instrutions require a move to a register used for the index unlike VPUSH.
2785+
// instructions require a move to a register used for the index unlike
2786+
// VPUSH.
26292787
MaxNumInsertions = (ShuffleMaxNumInsertions != 0) ? ShuffleMaxNumInsertions
26302788
: NumSrcElems / 2;
26312789
else

llvm/lib/Target/AIE/AIECombinerHelper.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ struct ShuffleMaskValidity {
3939
SmallVector<unsigned, 4> MaskExceptions;
4040
};
4141

42+
/// Result type of MaskMatch::getFrequentIndexResult.
struct FrequentIndexResult {
43+
/// Position in the mask of the first entry holding the most frequent value.
unsigned FrequentIdx;
44+
/// Number of mask entries that are neither that value nor undef (-1).
unsigned NonMatchingCount;
45+
};
46+
4247
/// The mask is represented by a sawtooth function F with Period, Height and
4348
/// Amplitude, i.e., F(idx + Period) = F(idx) = Height + idx * Amplitude, where
4449
/// idx >= 0.
@@ -58,6 +63,8 @@ class MaskMatch {
5863
static std::optional<int> getUniqueIndex(ArrayRef<int> Mask);
5964
static bool isMaskWithinRangeOrUndef(ArrayRef<int> Mask, int MinValue,
6065
int MaxValue);
66+
/// Finds the most frequent defined value of \p Mask. Undef entries (-1) are
/// added to that value's frequency before it is compared against
/// \p MinFrequency; a \p MinFrequency of 0 selects Mask.size() / 2.
/// \returns std::nullopt when the threshold is not reached, otherwise the
/// position of the value's first occurrence and the number of entries that
/// are neither that value nor undef.
static std::optional<FrequentIndexResult>
67+
getFrequentIndexResult(ArrayRef<int> Mask, unsigned MinFrequency);
6168

6269
unsigned getMaskValue(unsigned Idx) const {
6370
unsigned BaseIdx = Period == 0 ? Idx : Idx % Period;
@@ -283,6 +290,9 @@ bool matchPairedExtracts(MachineInstr &MI, MachineRegisterInfo &MRI,
283290
CombinerHelper &Helper, const TargetInstrInfo &TII,
284291
BuildFnTy &MatchInfo);
285292

293+
/// Matches a G_SHUFFLE_VECTOR \p MI whose mask is dominated by one source
/// lane. On success, stores in \p MatchInfo a build function that broadcasts
/// that lane and inserts the remaining defined lanes one by one.
/// \returns true when the pattern matched.
bool matchShuffleToExtractInsertEltToBroadcast(MachineInstr &MI,
294+
MachineRegisterInfo &MRI,
295+
BuildFnTy &MatchInfo);
286296
} // namespace llvm
287297

288298
#endif

0 commit comments

Comments
 (0)