Skip to content

Commit c55a765

Browse files
authored
DAG: Move scalarizeExtractedVectorLoad to TargetLowering (#122670)
SimplifyDemandedVectorElts should be able to use this on loads
1 parent 8fdd982 commit c55a765

File tree

3 files changed

+100
-89
lines changed

3 files changed

+100
-89
lines changed

Diff for: llvm/include/llvm/CodeGen/TargetLowering.h

+12
Original file line numberDiff line numberDiff line change
@@ -5622,6 +5622,18 @@ class TargetLowering : public TargetLoweringBase {
56225622
// joining their results. SDValue() is returned when expansion did not happen.
56235623
SDValue expandVectorNaryOpBySplitting(SDNode *Node, SelectionDAG &DAG) const;
56245624

5625+
/// Replace an extraction of a load with a narrowed load.
5626+
///
5627+
/// \param ResultVT type of the result extraction.
5628+
/// \param InVecVT type of the input vector with bitcasts resolved.
5629+
/// \param EltNo index of the vector element to load.
5630+
/// \param OriginalLoad vector load that is to be replaced.
5631+
/// \returns a \p ResultVT value built from the narrowed load on success, or SDValue() on failure.
5632+
SDValue scalarizeExtractedVectorLoad(EVT ResultVT, const SDLoc &DL,
5633+
EVT InVecVT, SDValue EltNo,
5634+
LoadSDNode *OriginalLoad,
5635+
SelectionDAG &DAG) const;
5636+
56255637
private:
56265638
SDValue foldSetCCWithAnd(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond,
56275639
const SDLoc &DL, DAGCombinerInfo &DCI) const;

Diff for: llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

+14-89
Original file line numberDiff line numberDiff line change
@@ -385,17 +385,6 @@ namespace {
385385
bool getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val);
386386
bool extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val);
387387

388-
/// Replace an ISD::EXTRACT_VECTOR_ELT of a load with a narrowed
389-
/// load.
390-
///
391-
/// \param EVE ISD::EXTRACT_VECTOR_ELT to be replaced.
392-
/// \param InVecVT type of the input vector to EVE with bitcasts resolved.
393-
/// \param EltNo index of the vector element to load.
394-
/// \param OriginalLoad load that EVE came from to be replaced.
395-
/// \returns EVE on success SDValue() on failure.
396-
SDValue scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
397-
SDValue EltNo,
398-
LoadSDNode *OriginalLoad);
399388
void ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad);
400389
SDValue PromoteOperand(SDValue Op, EVT PVT, bool &Replace);
401390
SDValue SExtPromoteOperand(SDValue Op, EVT PVT);
@@ -22719,81 +22708,6 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
2271922708
return SDValue();
2272022709
}
2272122710

22722-
SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
22723-
SDValue EltNo,
22724-
LoadSDNode *OriginalLoad) {
22725-
assert(OriginalLoad->isSimple());
22726-
22727-
EVT ResultVT = EVE->getValueType(0);
22728-
EVT VecEltVT = InVecVT.getVectorElementType();
22729-
22730-
// If the vector element type is not a multiple of a byte then we are unable
22731-
// to correctly compute an address to load only the extracted element as a
22732-
// scalar.
22733-
if (!VecEltVT.isByteSized())
22734-
return SDValue();
22735-
22736-
ISD::LoadExtType ExtTy =
22737-
ResultVT.bitsGT(VecEltVT) ? ISD::EXTLOAD : ISD::NON_EXTLOAD;
22738-
if (!TLI.isOperationLegalOrCustom(ISD::LOAD, VecEltVT) ||
22739-
!TLI.shouldReduceLoadWidth(OriginalLoad, ExtTy, VecEltVT))
22740-
return SDValue();
22741-
22742-
Align Alignment = OriginalLoad->getAlign();
22743-
MachinePointerInfo MPI;
22744-
SDLoc DL(EVE);
22745-
if (auto *ConstEltNo = dyn_cast<ConstantSDNode>(EltNo)) {
22746-
int Elt = ConstEltNo->getZExtValue();
22747-
unsigned PtrOff = VecEltVT.getSizeInBits() * Elt / 8;
22748-
MPI = OriginalLoad->getPointerInfo().getWithOffset(PtrOff);
22749-
Alignment = commonAlignment(Alignment, PtrOff);
22750-
} else {
22751-
// Discard the pointer info except the address space because the memory
22752-
// operand can't represent this new access since the offset is variable.
22753-
MPI = MachinePointerInfo(OriginalLoad->getPointerInfo().getAddrSpace());
22754-
Alignment = commonAlignment(Alignment, VecEltVT.getSizeInBits() / 8);
22755-
}
22756-
22757-
unsigned IsFast = 0;
22758-
if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VecEltVT,
22759-
OriginalLoad->getAddressSpace(), Alignment,
22760-
OriginalLoad->getMemOperand()->getFlags(),
22761-
&IsFast) ||
22762-
!IsFast)
22763-
return SDValue();
22764-
22765-
SDValue NewPtr = TLI.getVectorElementPointer(DAG, OriginalLoad->getBasePtr(),
22766-
InVecVT, EltNo);
22767-
22768-
// We are replacing a vector load with a scalar load. The new load must have
22769-
// identical memory op ordering to the original.
22770-
SDValue Load;
22771-
if (ResultVT.bitsGT(VecEltVT)) {
22772-
// If the result type of vextract is wider than the load, then issue an
22773-
// extending load instead.
22774-
ISD::LoadExtType ExtType =
22775-
TLI.isLoadExtLegal(ISD::ZEXTLOAD, ResultVT, VecEltVT) ? ISD::ZEXTLOAD
22776-
: ISD::EXTLOAD;
22777-
Load = DAG.getExtLoad(ExtType, DL, ResultVT, OriginalLoad->getChain(),
22778-
NewPtr, MPI, VecEltVT, Alignment,
22779-
OriginalLoad->getMemOperand()->getFlags(),
22780-
OriginalLoad->getAAInfo());
22781-
DAG.makeEquivalentMemoryOrdering(OriginalLoad, Load);
22782-
} else {
22783-
// The result type is narrower or the same width as the vector element
22784-
Load = DAG.getLoad(VecEltVT, DL, OriginalLoad->getChain(), NewPtr, MPI,
22785-
Alignment, OriginalLoad->getMemOperand()->getFlags(),
22786-
OriginalLoad->getAAInfo());
22787-
DAG.makeEquivalentMemoryOrdering(OriginalLoad, Load);
22788-
if (ResultVT.bitsLT(VecEltVT))
22789-
Load = DAG.getNode(ISD::TRUNCATE, DL, ResultVT, Load);
22790-
else
22791-
Load = DAG.getBitcast(ResultVT, Load);
22792-
}
22793-
++OpsNarrowed;
22794-
return Load;
22795-
}
22796-
2279722711
/// Transform a vector binary operation into a scalar binary operation by moving
2279822712
/// the math/logic after an extract element of a vector.
2279922713
static SDValue scalarizeExtractedBinOp(SDNode *ExtElt, SelectionDAG &DAG,
@@ -23272,8 +23186,13 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
2327223186
ISD::isNormalLoad(VecOp.getNode()) &&
2327323187
!Index->hasPredecessor(VecOp.getNode())) {
2327423188
auto *VecLoad = dyn_cast<LoadSDNode>(VecOp);
23275-
if (VecLoad && VecLoad->isSimple())
23276-
return scalarizeExtractedVectorLoad(N, VecVT, Index, VecLoad);
23189+
if (VecLoad && VecLoad->isSimple()) {
23190+
if (SDValue Scalarized = TLI.scalarizeExtractedVectorLoad(
23191+
ExtVT, SDLoc(N), VecVT, Index, VecLoad, DAG)) {
23192+
++OpsNarrowed;
23193+
return Scalarized;
23194+
}
23195+
}
2327723196
}
2327823197

2327923198
// Perform only after legalization to ensure build_vector / vector_shuffle
@@ -23361,7 +23280,13 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
2336123280
if (Elt == -1)
2336223281
return DAG.getUNDEF(LVT);
2336323282

23364-
return scalarizeExtractedVectorLoad(N, VecVT, Index, LN0);
23283+
if (SDValue Scalarized =
23284+
TLI.scalarizeExtractedVectorLoad(LVT, DL, VecVT, Index, LN0, DAG)) {
23285+
++OpsNarrowed;
23286+
return Scalarized;
23287+
}
23288+
23289+
return SDValue();
2336523290
}
2336623291

2336723292
// Simplify (build_vec (ext )) to (bitcast (build_vec ))

Diff for: llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

+74
Original file line numberDiff line numberDiff line change
@@ -12114,3 +12114,77 @@ SDValue TargetLowering::expandVectorNaryOpBySplitting(SDNode *Node,
1211412114
SDValue SplitOpHi = DAG.getNode(Opcode, DL, HiVT, HiOps);
1211512115
return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, SplitOpLo, SplitOpHi);
1211612116
}
12117+
12118+
SDValue TargetLowering::scalarizeExtractedVectorLoad(EVT ResultVT,
12119+
const SDLoc &DL,
12120+
EVT InVecVT, SDValue EltNo,
12121+
LoadSDNode *OriginalLoad,
12122+
SelectionDAG &DAG) const {
12123+
assert(OriginalLoad->isSimple());
12124+
12125+
EVT VecEltVT = InVecVT.getVectorElementType();
12126+
12127+
// If the vector element type is not a multiple of a byte then we are unable
12128+
// to correctly compute an address to load only the extracted element as a
12129+
// scalar.
12130+
if (!VecEltVT.isByteSized())
12131+
return SDValue();
12132+
12133+
ISD::LoadExtType ExtTy =
12134+
ResultVT.bitsGT(VecEltVT) ? ISD::EXTLOAD : ISD::NON_EXTLOAD;
12135+
if (!isOperationLegalOrCustom(ISD::LOAD, VecEltVT) ||
12136+
!shouldReduceLoadWidth(OriginalLoad, ExtTy, VecEltVT))
12137+
return SDValue();
12138+
12139+
Align Alignment = OriginalLoad->getAlign();
12140+
MachinePointerInfo MPI;
12141+
if (auto *ConstEltNo = dyn_cast<ConstantSDNode>(EltNo)) {
12142+
int Elt = ConstEltNo->getZExtValue();
12143+
unsigned PtrOff = VecEltVT.getSizeInBits() * Elt / 8;
12144+
MPI = OriginalLoad->getPointerInfo().getWithOffset(PtrOff);
12145+
Alignment = commonAlignment(Alignment, PtrOff);
12146+
} else {
12147+
// Discard the pointer info except the address space because the memory
12148+
// operand can't represent this new access since the offset is variable.
12149+
MPI = MachinePointerInfo(OriginalLoad->getPointerInfo().getAddrSpace());
12150+
Alignment = commonAlignment(Alignment, VecEltVT.getSizeInBits() / 8);
12151+
}
12152+
12153+
unsigned IsFast = 0;
12154+
if (!allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VecEltVT,
12155+
OriginalLoad->getAddressSpace(), Alignment,
12156+
OriginalLoad->getMemOperand()->getFlags(), &IsFast) ||
12157+
!IsFast)
12158+
return SDValue();
12159+
12160+
SDValue NewPtr =
12161+
getVectorElementPointer(DAG, OriginalLoad->getBasePtr(), InVecVT, EltNo);
12162+
12163+
// We are replacing a vector load with a scalar load. The new load must have
12164+
// identical memory op ordering to the original.
12165+
SDValue Load;
12166+
if (ResultVT.bitsGT(VecEltVT)) {
12167+
// If the result type of vextract is wider than the load, then issue an
12168+
// extending load instead.
12169+
ISD::LoadExtType ExtType = isLoadExtLegal(ISD::ZEXTLOAD, ResultVT, VecEltVT)
12170+
? ISD::ZEXTLOAD
12171+
: ISD::EXTLOAD;
12172+
Load = DAG.getExtLoad(ExtType, DL, ResultVT, OriginalLoad->getChain(),
12173+
NewPtr, MPI, VecEltVT, Alignment,
12174+
OriginalLoad->getMemOperand()->getFlags(),
12175+
OriginalLoad->getAAInfo());
12176+
DAG.makeEquivalentMemoryOrdering(OriginalLoad, Load);
12177+
} else {
12178+
// The result type is narrower or the same width as the vector element
12179+
Load = DAG.getLoad(VecEltVT, DL, OriginalLoad->getChain(), NewPtr, MPI,
12180+
Alignment, OriginalLoad->getMemOperand()->getFlags(),
12181+
OriginalLoad->getAAInfo());
12182+
DAG.makeEquivalentMemoryOrdering(OriginalLoad, Load);
12183+
if (ResultVT.bitsLT(VecEltVT))
12184+
Load = DAG.getNode(ISD::TRUNCATE, DL, ResultVT, Load);
12185+
else
12186+
Load = DAG.getBitcast(ResultVT, Load);
12187+
}
12188+
12189+
return Load;
12190+
}

0 commit comments

Comments
 (0)