Skip to content

Commit 815edc3

Browse files
[SelectionDAG] Recurse through mask expression trees in WidenVSELECTMask (llvm#188085)
WidenVSELECTMask currently handles only two mask shapes: a bare SETCC or a single AND/OR/XOR of exactly two SETCCs. Anything deeper bails out to the generic condition widening path, which often introduces unnecessary narrow/widen roundtrips (xtn+sshll on AArch64, packssdw+vpmovsxwd on X86). Replace the hand-coded cases with a recursive widenMaskTree that walks through SETCC, AND/OR/XOR, FREEZE, VECTOR_SHUFFLE, SELECT/VSELECT, and all-ones/all-zeros BUILD_VECTORs.
1 parent 8c88fae commit 815edc3

File tree

5 files changed

+387
-103
lines changed

5 files changed

+387
-103
lines changed

llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1151,10 +1151,24 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
11511151
/// By default, the vector will be widened with undefined values.
11521152
SDValue ModifyToType(SDValue InOp, EVT NVT, bool FillWithZeroes = false);
11531153

1154+
/// Adjust element width (sign-extend/truncate) and element count
1155+
/// (extract/concat) of Mask to match ToMaskVT.
1156+
SDValue adjustMaskToType(SDValue Mask, EVT ToMaskVT);
1157+
1158+
/// Pick an intermediate VT and adjust both operands to it, minimizing
1159+
/// extend/truncate overhead given the final target ToVT.
1160+
EVT unifyMaskTypes(SDValue &Op0, SDValue &Op1, EVT ToVT);
1161+
11541162
/// Return a mask of vector type MaskVT to replace InMask. Also adjust
11551163
/// MaskVT to ToMaskVT if needed with vector extension or truncation.
11561164
SDValue convertMask(SDValue InMask, EVT MaskVT, EVT ToMaskVT);
11571165

1166+
/// Recursively convert a mask expression tree to ToVT, walking through
1167+
/// mask-preserving operations down to SETCC leaves. Avoids redundant
1168+
/// extend/truncate chains that arise when each node is converted
1169+
/// independently. Returns SDValue() if the tree cannot be converted.
1170+
SDValue convertMaskTree(SDValue V, EVT ToVT, unsigned Depth = 0);
1171+
11581172
//===--------------------------------------------------------------------===//
11591173
// Generic Splitting: LegalizeTypesGeneric.cpp
11601174
//===--------------------------------------------------------------------===//

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp

Lines changed: 130 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -6873,9 +6873,8 @@ static inline bool isSETCCorConvertedSETCC(SDValue N) {
68736873
// to ToMaskVT if needed with vector extension or truncation.
68746874
SDValue DAGTypeLegalizer::convertMask(SDValue InMask, EVT MaskVT,
68756875
EVT ToMaskVT) {
6876-
// Currently a SETCC or a AND/OR/XOR with two SETCCs are handled.
6877-
// FIXME: This code seems to be too restrictive, we might consider
6878-
// generalizing it or dropping it.
6876+
// Called from convertMaskTree for SETCC leaf nodes. Re-creates the SETCC with
6877+
// result type MaskVT, then sign-extends/truncates and pads to ToMaskVT.
68796878
assert(isSETCCorConvertedSETCC(InMask) && "Unexpected mask argument.");
68806879

68816880
// Make a new Mask node, with a legal result VT.
@@ -6892,9 +6891,14 @@ SDValue DAGTypeLegalizer::convertMask(SDValue InMask, EVT MaskVT,
68926891
Mask = DAG.getNode(InMask->getOpcode(), SDLoc(InMask), MaskVT, Ops,
68936892
InMask->getFlags());
68946893

6895-
// If MaskVT has smaller or bigger elements than ToMaskVT, a vector sign
6896-
// extend or truncate is needed.
6894+
return adjustMaskToType(Mask, ToMaskVT);
6895+
}
6896+
6897+
// Adjust element width (sign-extend/truncate) and element count
6898+
// (extract/concat) of Mask to match ToMaskVT.
6899+
SDValue DAGTypeLegalizer::adjustMaskToType(SDValue Mask, EVT ToMaskVT) {
68976900
LLVMContext &Ctx = *DAG.getContext();
6901+
EVT MaskVT = Mask.getValueType();
68986902
unsigned MaskScalarBits = MaskVT.getScalarSizeInBits();
68996903
unsigned ToMaskScalBits = ToMaskVT.getScalarSizeInBits();
69006904
if (MaskScalarBits < ToMaskScalBits) {
@@ -6929,6 +6933,125 @@ SDValue DAGTypeLegalizer::convertMask(SDValue InMask, EVT MaskVT,
69296933
return Mask;
69306934
}
69316935

6936+
// Adjust both operands to a common intermediate mask type, picking a scalar
6937+
// width that minimizes extend/truncate overhead given the final target ToVT.
6938+
EVT DAGTypeLegalizer::unifyMaskTypes(SDValue &Op0, SDValue &Op1, EVT ToVT) {
6939+
assert(Op0.getValueType().getVectorNumElements() ==
6940+
Op1.getValueType().getVectorNumElements() &&
6941+
"unifyMaskTypes only handles scalar width differences");
6942+
unsigned Bits0 = Op0.getValueType().getScalarSizeInBits();
6943+
unsigned Bits1 = Op1.getValueType().getScalarSizeInBits();
6944+
unsigned NarrowBits = std::min(Bits0, Bits1);
6945+
unsigned WideBits = std::max(Bits0, Bits1);
6946+
unsigned ToBits = ToVT.getScalarSizeInBits();
6947+
unsigned IntBits = NarrowBits == WideBits ? NarrowBits
6948+
: ToBits >= WideBits ? WideBits
6949+
: ToBits <= NarrowBits ? NarrowBits
6950+
: ToBits;
6951+
EVT OpVT = EVT::getVectorVT(*DAG.getContext(), MVT::getIntegerVT(IntBits),
6952+
Op0.getValueType().getVectorNumElements());
6953+
Op0 = adjustMaskToType(Op0, OpVT);
6954+
Op1 = adjustMaskToType(Op1, OpVT);
6955+
return OpVT;
6956+
}
6957+
6958+
SDValue DAGTypeLegalizer::convertMaskTree(SDValue V, EVT ToVT, unsigned Depth) {
6959+
if (Depth >= DAG.MaxRecursionDepth)
6960+
return SDValue();
6961+
6962+
SDValue Result = [&]() -> SDValue {
6963+
unsigned Opcode = V.getOpcode();
6964+
6965+
// Base case: SETCC produces the mask at its natural type.
6966+
if (isSETCCOp(Opcode)) {
6967+
EVT MaskVT = getSetCCResultType(getSETCCOperandType(V));
6968+
return convertMask(V, MaskVT, MaskVT);
6969+
}
6970+
6971+
// Base case: all-zeros or all-ones BUILD_VECTOR. Use ToVT directly since
6972+
// these are invariant under sign-extend/truncate.
6973+
if (ISD::isBuildVectorAllZeros(V.getNode()))
6974+
return DAG.getConstant(0, SDLoc(V), ToVT);
6975+
if (ISD::isBuildVectorAllOnes(V.getNode()))
6976+
return DAG.getAllOnesConstant(SDLoc(V), ToVT);
6977+
6978+
SDLoc DL(V);
6979+
6980+
// Logical operations (AND/OR/XOR): try picking the best fitting width out
6981+
// of children's element widths.
6982+
if (isLogicalMaskOp(Opcode)) {
6983+
SDValue Op0 = convertMaskTree(V.getOperand(0), ToVT, Depth + 1);
6984+
if (!Op0)
6985+
return SDValue();
6986+
SDValue Op1 = convertMaskTree(V.getOperand(1), ToVT, Depth + 1);
6987+
if (!Op1)
6988+
return SDValue();
6989+
EVT OpVT = unifyMaskTypes(Op0, Op1, ToVT);
6990+
return DAG.getNode(Opcode, DL, OpVT, Op0, Op1);
6991+
}
6992+
6993+
// FREEZE: widen the operand and re-wrap.
6994+
if (Opcode == ISD::FREEZE) {
6995+
SDValue Inner = convertMaskTree(V.getOperand(0), ToVT, Depth + 1);
6996+
if (!Inner)
6997+
return SDValue();
6998+
return DAG.getNode(ISD::FREEZE, DL, Inner.getValueType(), Inner);
6999+
}
7000+
7001+
// Vector shuffle: try inferring the best fitting width from operands.
7002+
if (Opcode == ISD::VECTOR_SHUFFLE) {
7003+
// Bail out when the number of elements is different, we can't
7004+
// simply reuse shuffle mask in this case.
7005+
if (V.getValueType().getVectorNumElements() !=
7006+
ToVT.getVectorNumElements())
7007+
return SDValue();
7008+
7009+
auto *Shuf = cast<ShuffleVectorSDNode>(V);
7010+
SDValue Op0 = convertMaskTree(V.getOperand(0), ToVT, Depth + 1);
7011+
if (!Op0)
7012+
return SDValue();
7013+
if (V.getOperand(1).isUndef()) {
7014+
EVT OpVT = Op0.getValueType();
7015+
return DAG.getVectorShuffle(OpVT, DL, Op0, DAG.getUNDEF(OpVT),
7016+
Shuf->getMask());
7017+
}
7018+
SDValue Op1 = convertMaskTree(V.getOperand(1), ToVT, Depth + 1);
7019+
if (!Op1)
7020+
return SDValue();
7021+
EVT OpVT = unifyMaskTypes(Op0, Op1, ToVT);
7022+
return DAG.getVectorShuffle(OpVT, DL, Op0, Op1, Shuf->getMask());
7023+
}
7024+
7025+
// SELECT/VSELECT: try inferring the best fitting width from operands.
7026+
if (Opcode == ISD::SELECT || Opcode == ISD::VSELECT) {
7027+
SDValue Op1 = convertMaskTree(V.getOperand(1), ToVT, Depth + 1);
7028+
if (!Op1)
7029+
return SDValue();
7030+
SDValue Op2 = convertMaskTree(V.getOperand(2), ToVT, Depth + 1);
7031+
if (!Op2)
7032+
return SDValue();
7033+
EVT OpVT = unifyMaskTypes(Op1, Op2, ToVT);
7034+
7035+
SDValue Cond = V.getOperand(0);
7036+
if (Opcode == ISD::VSELECT) {
7037+
Cond = convertMaskTree(Cond, ToVT, Depth + 1);
7038+
if (!Cond)
7039+
return SDValue();
7040+
Cond = adjustMaskToType(Cond, OpVT);
7041+
}
7042+
return DAG.getNode(Opcode, DL, OpVT, Cond, Op1, Op2);
7043+
}
7044+
7045+
return SDValue();
7046+
}();
7047+
7048+
if (!Result)
7049+
return SDValue();
7050+
if (Depth == 0)
7051+
Result = adjustMaskToType(Result, ToVT);
7052+
return Result;
7053+
}
7054+
69327055
// This method tries to handle some special cases for the vselect mask
69337056
// and if needed adjusting the mask vector type to match that of the VSELECT.
69347057
// Without it, many cases end up with scalarization of the SETCC, with many
@@ -6940,9 +7063,6 @@ SDValue DAGTypeLegalizer::WidenVSELECTMask(SDNode *N) {
69407063
if (N->getOpcode() != ISD::VSELECT)
69417064
return SDValue();
69427065

6943-
if (!isSETCCOp(Cond->getOpcode()) && !isLogicalMaskOp(Cond->getOpcode()))
6944-
return SDValue();
6945-
69467066
// If this is a split VSELECT that was previously already handled, do
69477067
// nothing.
69487068
EVT CondVT = Cond->getValueType(0);
@@ -6995,49 +7115,8 @@ SDValue DAGTypeLegalizer::WidenVSELECTMask(SDNode *N) {
69957115
if (!ToMaskVT.getScalarType().isInteger())
69967116
ToMaskVT = ToMaskVT.changeVectorElementTypeToInteger();
69977117

6998-
SDValue Mask;
6999-
if (isSETCCOp(Cond->getOpcode())) {
7000-
EVT MaskVT = getSetCCResultType(getSETCCOperandType(Cond));
7001-
Mask = convertMask(Cond, MaskVT, ToMaskVT);
7002-
} else if (isLogicalMaskOp(Cond->getOpcode()) &&
7003-
isSETCCOp(Cond->getOperand(0).getOpcode()) &&
7004-
isSETCCOp(Cond->getOperand(1).getOpcode())) {
7005-
// Cond is (AND/OR/XOR (SETCC, SETCC))
7006-
SDValue SETCC0 = Cond->getOperand(0);
7007-
SDValue SETCC1 = Cond->getOperand(1);
7008-
EVT VT0 = getSetCCResultType(getSETCCOperandType(SETCC0));
7009-
EVT VT1 = getSetCCResultType(getSETCCOperandType(SETCC1));
7010-
unsigned ScalarBits0 = VT0.getScalarSizeInBits();
7011-
unsigned ScalarBits1 = VT1.getScalarSizeInBits();
7012-
unsigned ScalarBits_ToMask = ToMaskVT.getScalarSizeInBits();
7013-
EVT MaskVT;
7014-
// If the two SETCCs have different VTs, either extend/truncate one of
7015-
// them to the other "towards" ToMaskVT, or truncate one and extend the
7016-
// other to ToMaskVT.
7017-
if (ScalarBits0 != ScalarBits1) {
7018-
EVT NarrowVT = ((ScalarBits0 < ScalarBits1) ? VT0 : VT1);
7019-
EVT WideVT = ((NarrowVT == VT0) ? VT1 : VT0);
7020-
if (ScalarBits_ToMask >= WideVT.getScalarSizeInBits())
7021-
MaskVT = WideVT;
7022-
else if (ScalarBits_ToMask <= NarrowVT.getScalarSizeInBits())
7023-
MaskVT = NarrowVT;
7024-
else
7025-
MaskVT = ToMaskVT;
7026-
} else
7027-
// If the two SETCCs have the same VT, don't change it.
7028-
MaskVT = VT0;
7029-
7030-
// Make new SETCCs and logical nodes.
7031-
SETCC0 = convertMask(SETCC0, VT0, MaskVT);
7032-
SETCC1 = convertMask(SETCC1, VT1, MaskVT);
7033-
Cond = DAG.getNode(Cond->getOpcode(), SDLoc(Cond), MaskVT, SETCC0, SETCC1);
7034-
7035-
// Convert the logical op for VSELECT if needed.
7036-
Mask = convertMask(Cond, MaskVT, ToMaskVT);
7037-
} else
7038-
return SDValue();
7039-
7040-
return Mask;
7118+
// Try to recursively widen the mask expression tree to the target type.
7119+
return convertMaskTree(Cond, ToMaskVT);
70417120
}
70427121

70437122
SDValue DAGTypeLegalizer::WidenVecRes_Select(SDNode *N) {

llvm/test/CodeGen/AArch64/arm64-zip.ll

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -382,13 +382,10 @@ define <4 x float> @shuffle_zip1(<4 x float> %arg) {
382382
; CHECK-LABEL: shuffle_zip1:
383383
; CHECK: // %bb.0: // %bb
384384
; CHECK-NEXT: fcmgt.4s v0, v0, #0.0
385-
; CHECK-NEXT: uzp1.8h v1, v0, v0
386-
; CHECK-NEXT: xtn.4h v0, v0
387-
; CHECK-NEXT: xtn.4h v1, v1
388-
; CHECK-NEXT: zip2.4h v0, v0, v1
389385
; CHECK-NEXT: fmov.4s v1, #1.00000000
390-
; CHECK-NEXT: zip1.4h v0, v0, v0
391-
; CHECK-NEXT: sshll.4s v0, v0, #0
386+
; CHECK-NEXT: uzp1.4s v2, v0, v0
387+
; CHECK-NEXT: zip2.4s v0, v0, v2
388+
; CHECK-NEXT: zip1.4s v0, v0, v0
392389
; CHECK-NEXT: and.16b v0, v0, v1
393390
; CHECK-NEXT: ret
394391
bb:
@@ -403,13 +400,10 @@ define <4 x i32> @shuffle_zip2(<4 x i32> %arg) {
403400
; CHECK-LABEL: shuffle_zip2:
404401
; CHECK: // %bb.0: // %bb
405402
; CHECK-NEXT: cmtst.4s v0, v0, v0
406-
; CHECK-NEXT: uzp1.8h v1, v0, v0
407-
; CHECK-NEXT: xtn.4h v0, v0
408-
; CHECK-NEXT: xtn.4h v1, v1
409-
; CHECK-NEXT: zip2.4h v0, v0, v1
410403
; CHECK-NEXT: movi.4s v1, #1
411-
; CHECK-NEXT: zip1.4h v0, v0, v0
412-
; CHECK-NEXT: ushll.4s v0, v0, #0
404+
; CHECK-NEXT: uzp1.4s v2, v0, v0
405+
; CHECK-NEXT: zip2.4s v0, v0, v2
406+
; CHECK-NEXT: zip1.4s v0, v0, v0
413407
; CHECK-NEXT: and.16b v0, v0, v1
414408
; CHECK-NEXT: ret
415409
bb:
@@ -424,13 +418,10 @@ define <4 x i32> @shuffle_zip3(<4 x i32> %arg) {
424418
; CHECK-LABEL: shuffle_zip3:
425419
; CHECK: // %bb.0: // %bb
426420
; CHECK-NEXT: cmgt.4s v0, v0, #0
427-
; CHECK-NEXT: uzp1.8h v1, v0, v0
428-
; CHECK-NEXT: xtn.4h v0, v0
429-
; CHECK-NEXT: xtn.4h v1, v1
430-
; CHECK-NEXT: zip2.4h v0, v0, v1
431421
; CHECK-NEXT: movi.4s v1, #1
432-
; CHECK-NEXT: zip1.4h v0, v0, v0
433-
; CHECK-NEXT: ushll.4s v0, v0, #0
422+
; CHECK-NEXT: uzp1.4s v2, v0, v0
423+
; CHECK-NEXT: zip2.4s v0, v0, v2
424+
; CHECK-NEXT: zip1.4s v0, v0, v0
434425
; CHECK-NEXT: and.16b v0, v0, v1
435426
; CHECK-NEXT: ret
436427
bb:

0 commit comments

Comments
 (0)