@@ -6873,9 +6873,8 @@ static inline bool isSETCCorConvertedSETCC(SDValue N) {
68736873// to ToMaskVT if needed with vector extension or truncation.
68746874SDValue DAGTypeLegalizer::convertMask (SDValue InMask, EVT MaskVT,
68756875 EVT ToMaskVT) {
6876- // Currently a SETCC or a AND/OR/XOR with two SETCCs are handled.
6877- // FIXME: This code seems to be too restrictive, we might consider
6878- // generalizing it or dropping it.
6876+ // Called from convertMaskTree for SETCC leaf nodes. Re-creates the SETCC with
6877+ // result type MaskVT, then sign-extends/truncates and pads to ToMaskVT.
68796878 assert (isSETCCorConvertedSETCC (InMask) && " Unexpected mask argument." );
68806879
68816880 // Make a new Mask node, with a legal result VT.
@@ -6892,9 +6891,14 @@ SDValue DAGTypeLegalizer::convertMask(SDValue InMask, EVT MaskVT,
68926891 Mask = DAG.getNode (InMask->getOpcode (), SDLoc (InMask), MaskVT, Ops,
68936892 InMask->getFlags ());
68946893
6895- // If MaskVT has smaller or bigger elements than ToMaskVT, a vector sign
6896- // extend or truncate is needed.
6894+ return adjustMaskToType (Mask, ToMaskVT);
6895+ }
6896+
6897+ // Adjust element width (sign-extend/truncate) and element count
6898+ // (extract/concat) of Mask to match ToMaskVT.
6899+ SDValue DAGTypeLegalizer::adjustMaskToType (SDValue Mask, EVT ToMaskVT) {
68976900 LLVMContext &Ctx = *DAG.getContext ();
6901+ EVT MaskVT = Mask.getValueType ();
68986902 unsigned MaskScalarBits = MaskVT.getScalarSizeInBits ();
68996903 unsigned ToMaskScalBits = ToMaskVT.getScalarSizeInBits ();
69006904 if (MaskScalarBits < ToMaskScalBits) {
@@ -6929,6 +6933,125 @@ SDValue DAGTypeLegalizer::convertMask(SDValue InMask, EVT MaskVT,
69296933 return Mask;
69306934}
69316935
6936+ // Adjust both operands to a common intermediate mask type, picking a scalar
6937+ // width that minimizes extend/truncate overhead given the final target ToVT.
6938+ EVT DAGTypeLegalizer::unifyMaskTypes (SDValue &Op0, SDValue &Op1, EVT ToVT) {
6939+ assert (Op0.getValueType ().getVectorNumElements () ==
6940+ Op1.getValueType ().getVectorNumElements () &&
6941+ " unifyMaskTypes only handles scalar width differences" );
6942+ unsigned Bits0 = Op0.getValueType ().getScalarSizeInBits ();
6943+ unsigned Bits1 = Op1.getValueType ().getScalarSizeInBits ();
6944+ unsigned NarrowBits = std::min (Bits0, Bits1);
6945+ unsigned WideBits = std::max (Bits0, Bits1);
6946+ unsigned ToBits = ToVT.getScalarSizeInBits ();
6947+ unsigned IntBits = NarrowBits == WideBits ? NarrowBits
6948+ : ToBits >= WideBits ? WideBits
6949+ : ToBits <= NarrowBits ? NarrowBits
6950+ : ToBits;
6951+ EVT OpVT = EVT::getVectorVT (*DAG.getContext (), MVT::getIntegerVT (IntBits),
6952+ Op0.getValueType ().getVectorNumElements ());
6953+ Op0 = adjustMaskToType (Op0, OpVT);
6954+ Op1 = adjustMaskToType (Op1, OpVT);
6955+ return OpVT;
6956+ }
6957+
6958+ SDValue DAGTypeLegalizer::convertMaskTree (SDValue V, EVT ToVT, unsigned Depth) {
6959+ if (Depth >= DAG.MaxRecursionDepth )
6960+ return SDValue ();
6961+
6962+ SDValue Result = [&]() -> SDValue {
6963+ unsigned Opcode = V.getOpcode ();
6964+
6965+ // Base case: SETCC produces the mask at its natural type.
6966+ if (isSETCCOp (Opcode)) {
6967+ EVT MaskVT = getSetCCResultType (getSETCCOperandType (V));
6968+ return convertMask (V, MaskVT, MaskVT);
6969+ }
6970+
6971+ // Base case: all-zeros or all-ones BUILD_VECTOR. Use ToVT directly since
6972+ // these are invariant under sign-extend/truncate.
6973+ if (ISD::isBuildVectorAllZeros (V.getNode ()))
6974+ return DAG.getConstant (0 , SDLoc (V), ToVT);
6975+ if (ISD::isBuildVectorAllOnes (V.getNode ()))
6976+ return DAG.getAllOnesConstant (SDLoc (V), ToVT);
6977+
6978+ SDLoc DL (V);
6979+
6980+ // Logical operations (AND/OR/XOR): try picking the best fitting width out
6981+ // of children's element widths.
6982+ if (isLogicalMaskOp (Opcode)) {
6983+ SDValue Op0 = convertMaskTree (V.getOperand (0 ), ToVT, Depth + 1 );
6984+ if (!Op0)
6985+ return SDValue ();
6986+ SDValue Op1 = convertMaskTree (V.getOperand (1 ), ToVT, Depth + 1 );
6987+ if (!Op1)
6988+ return SDValue ();
6989+ EVT OpVT = unifyMaskTypes (Op0, Op1, ToVT);
6990+ return DAG.getNode (Opcode, DL, OpVT, Op0, Op1);
6991+ }
6992+
6993+ // FREEZE: widen the operand and re-wrap.
6994+ if (Opcode == ISD::FREEZE) {
6995+ SDValue Inner = convertMaskTree (V.getOperand (0 ), ToVT, Depth + 1 );
6996+ if (!Inner)
6997+ return SDValue ();
6998+ return DAG.getNode (ISD::FREEZE, DL, Inner.getValueType (), Inner);
6999+ }
7000+
7001+ // Vector shuffle: try inferring the best fitting width from operands.
7002+ if (Opcode == ISD::VECTOR_SHUFFLE) {
7003+ // Bail out when the number of elements is different, we can't
7004+ // simply reuse shuffle mask in this case.
7005+ if (V.getValueType ().getVectorNumElements () !=
7006+ ToVT.getVectorNumElements ())
7007+ return SDValue ();
7008+
7009+ auto *Shuf = cast<ShuffleVectorSDNode>(V);
7010+ SDValue Op0 = convertMaskTree (V.getOperand (0 ), ToVT, Depth + 1 );
7011+ if (!Op0)
7012+ return SDValue ();
7013+ if (V.getOperand (1 ).isUndef ()) {
7014+ EVT OpVT = Op0.getValueType ();
7015+ return DAG.getVectorShuffle (OpVT, DL, Op0, DAG.getUNDEF (OpVT),
7016+ Shuf->getMask ());
7017+ }
7018+ SDValue Op1 = convertMaskTree (V.getOperand (1 ), ToVT, Depth + 1 );
7019+ if (!Op1)
7020+ return SDValue ();
7021+ EVT OpVT = unifyMaskTypes (Op0, Op1, ToVT);
7022+ return DAG.getVectorShuffle (OpVT, DL, Op0, Op1, Shuf->getMask ());
7023+ }
7024+
7025+ // SELECT/VSELECT: try inferring the best fitting width from operands.
7026+ if (Opcode == ISD::SELECT || Opcode == ISD::VSELECT) {
7027+ SDValue Op1 = convertMaskTree (V.getOperand (1 ), ToVT, Depth + 1 );
7028+ if (!Op1)
7029+ return SDValue ();
7030+ SDValue Op2 = convertMaskTree (V.getOperand (2 ), ToVT, Depth + 1 );
7031+ if (!Op2)
7032+ return SDValue ();
7033+ EVT OpVT = unifyMaskTypes (Op1, Op2, ToVT);
7034+
7035+ SDValue Cond = V.getOperand (0 );
7036+ if (Opcode == ISD::VSELECT) {
7037+ Cond = convertMaskTree (Cond, ToVT, Depth + 1 );
7038+ if (!Cond)
7039+ return SDValue ();
7040+ Cond = adjustMaskToType (Cond, OpVT);
7041+ }
7042+ return DAG.getNode (Opcode, DL, OpVT, Cond, Op1, Op2);
7043+ }
7044+
7045+ return SDValue ();
7046+ }();
7047+
7048+ if (!Result)
7049+ return SDValue ();
7050+ if (Depth == 0 )
7051+ Result = adjustMaskToType (Result, ToVT);
7052+ return Result;
7053+ }
7054+
69327055// This method tries to handle some special cases for the vselect mask
69337056// and if needed adjusting the mask vector type to match that of the VSELECT.
69347057// Without it, many cases end up with scalarization of the SETCC, with many
@@ -6940,9 +7063,6 @@ SDValue DAGTypeLegalizer::WidenVSELECTMask(SDNode *N) {
69407063 if (N->getOpcode () != ISD::VSELECT)
69417064 return SDValue ();
69427065
6943- if (!isSETCCOp (Cond->getOpcode ()) && !isLogicalMaskOp (Cond->getOpcode ()))
6944- return SDValue ();
6945-
69467066 // If this is a splitted VSELECT that was previously already handled, do
69477067 // nothing.
69487068 EVT CondVT = Cond->getValueType (0 );
@@ -6995,49 +7115,8 @@ SDValue DAGTypeLegalizer::WidenVSELECTMask(SDNode *N) {
69957115 if (!ToMaskVT.getScalarType ().isInteger ())
69967116 ToMaskVT = ToMaskVT.changeVectorElementTypeToInteger ();
69977117
6998- SDValue Mask;
6999- if (isSETCCOp (Cond->getOpcode ())) {
7000- EVT MaskVT = getSetCCResultType (getSETCCOperandType (Cond));
7001- Mask = convertMask (Cond, MaskVT, ToMaskVT);
7002- } else if (isLogicalMaskOp (Cond->getOpcode ()) &&
7003- isSETCCOp (Cond->getOperand (0 ).getOpcode ()) &&
7004- isSETCCOp (Cond->getOperand (1 ).getOpcode ())) {
7005- // Cond is (AND/OR/XOR (SETCC, SETCC))
7006- SDValue SETCC0 = Cond->getOperand (0 );
7007- SDValue SETCC1 = Cond->getOperand (1 );
7008- EVT VT0 = getSetCCResultType (getSETCCOperandType (SETCC0));
7009- EVT VT1 = getSetCCResultType (getSETCCOperandType (SETCC1));
7010- unsigned ScalarBits0 = VT0.getScalarSizeInBits ();
7011- unsigned ScalarBits1 = VT1.getScalarSizeInBits ();
7012- unsigned ScalarBits_ToMask = ToMaskVT.getScalarSizeInBits ();
7013- EVT MaskVT;
7014- // If the two SETCCs have different VTs, either extend/truncate one of
7015- // them to the other "towards" ToMaskVT, or truncate one and extend the
7016- // other to ToMaskVT.
7017- if (ScalarBits0 != ScalarBits1) {
7018- EVT NarrowVT = ((ScalarBits0 < ScalarBits1) ? VT0 : VT1);
7019- EVT WideVT = ((NarrowVT == VT0) ? VT1 : VT0);
7020- if (ScalarBits_ToMask >= WideVT.getScalarSizeInBits ())
7021- MaskVT = WideVT;
7022- else if (ScalarBits_ToMask <= NarrowVT.getScalarSizeInBits ())
7023- MaskVT = NarrowVT;
7024- else
7025- MaskVT = ToMaskVT;
7026- } else
7027- // If the two SETCCs have the same VT, don't change it.
7028- MaskVT = VT0;
7029-
7030- // Make new SETCCs and logical nodes.
7031- SETCC0 = convertMask (SETCC0, VT0, MaskVT);
7032- SETCC1 = convertMask (SETCC1, VT1, MaskVT);
7033- Cond = DAG.getNode (Cond->getOpcode (), SDLoc (Cond), MaskVT, SETCC0, SETCC1);
7034-
7035- // Convert the logical op for VSELECT if needed.
7036- Mask = convertMask (Cond, MaskVT, ToMaskVT);
7037- } else
7038- return SDValue ();
7039-
7040- return Mask;
7118+ // Try to recursively widen the mask expression tree to the target type.
7119+ return convertMaskTree (Cond, ToMaskVT);
70417120}
70427121
70437122SDValue DAGTypeLegalizer::WidenVecRes_Select (SDNode *N) {
0 commit comments