DAG: Move scalarizeExtractedVectorLoad to TargetLowering #122670

Merged

12 changes: 12 additions & 0 deletions llvm/include/llvm/CodeGen/TargetLowering.h
@@ -5622,6 +5622,18 @@ class TargetLowering : public TargetLoweringBase {
// joining their results. SDValue() is returned when expansion did not happen.
SDValue expandVectorNaryOpBySplitting(SDNode *Node, SelectionDAG &DAG) const;

/// Replace an extraction of a load with a narrowed load.
///
/// \param ResultVT type of the extracted result.
/// \param InVecVT type of the input vector, with bitcasts resolved.
/// \param EltNo index of the vector element to load.
/// \param OriginalLoad vector load to be replaced.
/// \returns a load of type \p ResultVT on success, SDValue() on failure.
SDValue scalarizeExtractedVectorLoad(EVT ResultVT, const SDLoc &DL,
EVT InVecVT, SDValue EltNo,
LoadSDNode *OriginalLoad,
SelectionDAG &DAG) const;

private:
SDValue foldSetCCWithAnd(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond,
const SDLoc &DL, DAGCombinerInfo &DCI) const;
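Since the helper is now a public TargetLowering method, a backend combine can call it directly instead of re-implementing the narrowing. A minimal sketch of such a caller, assuming a hypothetical target (MyTargetLowering and the combine function name are illustrative, not part of this patch):

// Hypothetical caller: fold (extract_vector_elt (load p), idx) into a
// narrow scalar load via the now-public hook. Surrounding names are
// illustrative.
SDValue MyTargetLowering::combineExtractOfLoad(SDNode *N,
                                               SelectionDAG &DAG) const {
  SDValue Vec = N->getOperand(0);
  SDValue Idx = N->getOperand(1);
  auto *Ld = dyn_cast<LoadSDNode>(Vec);
  // The helper asserts isSimple(), so guard here as the DAGCombiner does.
  if (!Ld || !Ld->isSimple())
    return SDValue();
  return scalarizeExtractedVectorLoad(N->getValueType(0), SDLoc(N),
                                      Vec.getValueType(), Idx, Ld, DAG);
}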
103 changes: 14 additions & 89 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -385,17 +385,6 @@ namespace {
bool getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val);
bool extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val);

/// Replace an ISD::EXTRACT_VECTOR_ELT of a load with a narrowed
/// load.
///
/// \param EVE ISD::EXTRACT_VECTOR_ELT to be replaced.
/// \param InVecVT type of the input vector to EVE with bitcasts resolved.
/// \param EltNo index of the vector element to load.
/// \param OriginalLoad load that EVE came from to be replaced.
/// \returns EVE on success SDValue() on failure.
SDValue scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
SDValue EltNo,
LoadSDNode *OriginalLoad);
void ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad);
SDValue PromoteOperand(SDValue Op, EVT PVT, bool &Replace);
SDValue SExtPromoteOperand(SDValue Op, EVT PVT);
@@ -22719,81 +22708,6 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
return SDValue();
}

SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
SDValue EltNo,
LoadSDNode *OriginalLoad) {
assert(OriginalLoad->isSimple());

EVT ResultVT = EVE->getValueType(0);
EVT VecEltVT = InVecVT.getVectorElementType();

// If the vector element type is not a multiple of a byte then we are unable
// to correctly compute an address to load only the extracted element as a
// scalar.
if (!VecEltVT.isByteSized())
return SDValue();

ISD::LoadExtType ExtTy =
ResultVT.bitsGT(VecEltVT) ? ISD::EXTLOAD : ISD::NON_EXTLOAD;
if (!TLI.isOperationLegalOrCustom(ISD::LOAD, VecEltVT) ||
!TLI.shouldReduceLoadWidth(OriginalLoad, ExtTy, VecEltVT))
return SDValue();

Align Alignment = OriginalLoad->getAlign();
MachinePointerInfo MPI;
SDLoc DL(EVE);
if (auto *ConstEltNo = dyn_cast<ConstantSDNode>(EltNo)) {
int Elt = ConstEltNo->getZExtValue();
unsigned PtrOff = VecEltVT.getSizeInBits() * Elt / 8;
MPI = OriginalLoad->getPointerInfo().getWithOffset(PtrOff);
Alignment = commonAlignment(Alignment, PtrOff);
} else {
// Discard the pointer info except the address space because the memory
// operand can't represent this new access since the offset is variable.
MPI = MachinePointerInfo(OriginalLoad->getPointerInfo().getAddrSpace());
Alignment = commonAlignment(Alignment, VecEltVT.getSizeInBits() / 8);
}

unsigned IsFast = 0;
if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VecEltVT,
OriginalLoad->getAddressSpace(), Alignment,
OriginalLoad->getMemOperand()->getFlags(),
&IsFast) ||
!IsFast)
return SDValue();

SDValue NewPtr = TLI.getVectorElementPointer(DAG, OriginalLoad->getBasePtr(),
InVecVT, EltNo);

// We are replacing a vector load with a scalar load. The new load must have
// identical memory op ordering to the original.
SDValue Load;
if (ResultVT.bitsGT(VecEltVT)) {
// If the result type of vextract is wider than the load, then issue an
// extending load instead.
ISD::LoadExtType ExtType =
TLI.isLoadExtLegal(ISD::ZEXTLOAD, ResultVT, VecEltVT) ? ISD::ZEXTLOAD
: ISD::EXTLOAD;
Load = DAG.getExtLoad(ExtType, DL, ResultVT, OriginalLoad->getChain(),
NewPtr, MPI, VecEltVT, Alignment,
OriginalLoad->getMemOperand()->getFlags(),
OriginalLoad->getAAInfo());
DAG.makeEquivalentMemoryOrdering(OriginalLoad, Load);
} else {
// The result type is narrower or the same width as the vector element
Load = DAG.getLoad(VecEltVT, DL, OriginalLoad->getChain(), NewPtr, MPI,
Alignment, OriginalLoad->getMemOperand()->getFlags(),
OriginalLoad->getAAInfo());
DAG.makeEquivalentMemoryOrdering(OriginalLoad, Load);
if (ResultVT.bitsLT(VecEltVT))
Load = DAG.getNode(ISD::TRUNCATE, DL, ResultVT, Load);
else
Load = DAG.getBitcast(ResultVT, Load);
}
++OpsNarrowed;
return Load;
}

/// Transform a vector binary operation into a scalar binary operation by moving
/// the math/logic after an extract element of a vector.
static SDValue scalarizeExtractedBinOp(SDNode *ExtElt, SelectionDAG &DAG,
@@ -23272,8 +23186,13 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
ISD::isNormalLoad(VecOp.getNode()) &&
!Index->hasPredecessor(VecOp.getNode())) {
auto *VecLoad = dyn_cast<LoadSDNode>(VecOp);
if (VecLoad && VecLoad->isSimple())
return scalarizeExtractedVectorLoad(N, VecVT, Index, VecLoad);
if (VecLoad && VecLoad->isSimple()) {
if (SDValue Scalarized = TLI.scalarizeExtractedVectorLoad(
ExtVT, SDLoc(N), VecVT, Index, VecLoad, DAG)) {
++OpsNarrowed;
return Scalarized;
}
}
}

// Perform only after legalization to ensure build_vector / vector_shuffle
@@ -23361,7 +23280,13 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
if (Elt == -1)
return DAG.getUNDEF(LVT);

return scalarizeExtractedVectorLoad(N, VecVT, Index, LN0);
if (SDValue Scalarized =
TLI.scalarizeExtractedVectorLoad(LVT, DL, VecVT, Index, LN0, DAG)) {
++OpsNarrowed;
return Scalarized;
}

return SDValue();
}

// Simplify (build_vec (ext )) to (bitcast (build_vec ))
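Both call sites now bump the OpsNarrowed statistic themselves, since the helper no longer owns it. For context, the rewrite they request from the helper, sketched roughly in SelectionDAG dump style (types, offsets, and alignments are illustrative):

  t0: v4i32,ch = load<(load (s128) from %ptr, align 16)>
  t1: i32 = extract_vector_elt t0, Constant:i64<2>
    becomes
  t1: i32,ch = load<(load (s32) from %ptr + 8, align 8)>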
74 changes: 74 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -12114,3 +12114,77 @@ SDValue TargetLowering::expandVectorNaryOpBySplitting(SDNode *Node,
SDValue SplitOpHi = DAG.getNode(Opcode, DL, HiVT, HiOps);
return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, SplitOpLo, SplitOpHi);
}

SDValue TargetLowering::scalarizeExtractedVectorLoad(EVT ResultVT,
const SDLoc &DL,
EVT InVecVT, SDValue EltNo,
LoadSDNode *OriginalLoad,
SelectionDAG &DAG) const {
assert(OriginalLoad->isSimple());

EVT VecEltVT = InVecVT.getVectorElementType();

// If the vector element type is not a multiple of a byte then we are unable
// to correctly compute an address to load only the extracted element as a
// scalar.
if (!VecEltVT.isByteSized())
return SDValue();

ISD::LoadExtType ExtTy =
ResultVT.bitsGT(VecEltVT) ? ISD::EXTLOAD : ISD::NON_EXTLOAD;
if (!isOperationLegalOrCustom(ISD::LOAD, VecEltVT) ||
!shouldReduceLoadWidth(OriginalLoad, ExtTy, VecEltVT))
return SDValue();

Align Alignment = OriginalLoad->getAlign();
MachinePointerInfo MPI;
if (auto *ConstEltNo = dyn_cast<ConstantSDNode>(EltNo)) {
int Elt = ConstEltNo->getZExtValue();
unsigned PtrOff = VecEltVT.getSizeInBits() * Elt / 8;
MPI = OriginalLoad->getPointerInfo().getWithOffset(PtrOff);
Alignment = commonAlignment(Alignment, PtrOff);
} else {
// Discard the pointer info except the address space because the memory
// operand can't represent this new access since the offset is variable.
MPI = MachinePointerInfo(OriginalLoad->getPointerInfo().getAddrSpace());
Alignment = commonAlignment(Alignment, VecEltVT.getSizeInBits() / 8);
}

unsigned IsFast = 0;
if (!allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VecEltVT,
OriginalLoad->getAddressSpace(), Alignment,
OriginalLoad->getMemOperand()->getFlags(), &IsFast) ||
!IsFast)
return SDValue();

SDValue NewPtr =
getVectorElementPointer(DAG, OriginalLoad->getBasePtr(), InVecVT, EltNo);

// We are replacing a vector load with a scalar load. The new load must have
// identical memory op ordering to the original.
SDValue Load;
if (ResultVT.bitsGT(VecEltVT)) {
// If the result type of vextract is wider than the load, then issue an
// extending load instead.
ISD::LoadExtType ExtType = isLoadExtLegal(ISD::ZEXTLOAD, ResultVT, VecEltVT)
? ISD::ZEXTLOAD
: ISD::EXTLOAD;
Load = DAG.getExtLoad(ExtType, DL, ResultVT, OriginalLoad->getChain(),
NewPtr, MPI, VecEltVT, Alignment,
OriginalLoad->getMemOperand()->getFlags(),
OriginalLoad->getAAInfo());
DAG.makeEquivalentMemoryOrdering(OriginalLoad, Load);
} else {
// The result type is narrower or the same width as the vector element
Load = DAG.getLoad(VecEltVT, DL, OriginalLoad->getChain(), NewPtr, MPI,
Alignment, OriginalLoad->getMemOperand()->getFlags(),
OriginalLoad->getAAInfo());
DAG.makeEquivalentMemoryOrdering(OriginalLoad, Load);
if (ResultVT.bitsLT(VecEltVT))
Load = DAG.getNode(ISD::TRUNCATE, DL, ResultVT, Load);
else
Load = DAG.getBitcast(ResultVT, Load);
}

return Load;
}
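The constant-index path above derives the narrowed load's pointer offset and alignment with plain byte arithmetic. A self-contained sketch of that arithmetic, with commonAlignment reduced to its effect on power-of-two alignments (a sketch under that assumption; values are illustrative):

#include <algorithm>
#include <cassert>
#include <cstdint>

// For a power-of-two alignment A, commonAlignment(A, Off) is the largest
// power of two dividing both: min(A, lowest set bit of Off), or A if Off
// is zero.
static uint64_t commonAlign(uint64_t A, uint64_t Off) {
  return Off == 0 ? A : std::min(A, Off & (~Off + 1));
}

int main() {
  // Extract element 3 of <4 x i32> from a 16-byte-aligned vector load.
  uint64_t EltBits = 32, Elt = 3, LoadAlign = 16;
  uint64_t PtrOff = EltBits * Elt / 8; // byte offset of element 3
  assert(PtrOff == 12);
  // The narrowed i32 load at offset 12 is only guaranteed 4-byte aligned.
  assert(commonAlign(LoadAlign, PtrOff) == 4);
  return 0;
}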