Skip to content

Commit 4244a91

Browse files
authored
[SPIRV][NFC] Refactor pointer creation in GlobalRegistery (#134429)
This PR adds new interfaces to create pointer type, and adds some requirements to the old interfaces. This is the first step in #134119.
1 parent 2b3aa56 commit 4244a91

6 files changed

+137
-67
lines changed

Diff for: llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp

+3-11
Original file line numberDiff line numberDiff line change
@@ -215,11 +215,8 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
215215
Argument *Arg = F.getArg(ArgIdx);
216216
Type *ArgType = Arg->getType();
217217
if (isTypedPointerTy(ArgType)) {
218-
SPIRVType *ElementType = GR->getOrCreateSPIRVType(
219-
cast<TypedPointerType>(ArgType)->getElementType(), MIRBuilder,
220-
SPIRV::AccessQualifier::ReadWrite, true);
221218
return GR->getOrCreateSPIRVPointerType(
222-
ElementType, MIRBuilder,
219+
cast<TypedPointerType>(ArgType)->getElementType(), MIRBuilder,
223220
addressSpaceToStorageClass(getPointerAddressSpace(ArgType), ST));
224221
}
225222

@@ -232,11 +229,8 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
232229
// spv_assign_ptr_type intrinsic or otherwise use default pointer element
233230
// type.
234231
if (hasPointeeTypeAttr(Arg)) {
235-
SPIRVType *ElementType =
236-
GR->getOrCreateSPIRVType(getPointeeTypeByAttr(Arg), MIRBuilder,
237-
SPIRV::AccessQualifier::ReadWrite, true);
238232
return GR->getOrCreateSPIRVPointerType(
239-
ElementType, MIRBuilder,
233+
getPointeeTypeByAttr(Arg), MIRBuilder,
240234
addressSpaceToStorageClass(getPointerAddressSpace(ArgType), ST));
241235
}
242236

@@ -259,10 +253,8 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
259253
MetadataAsValue *VMD = cast<MetadataAsValue>(II->getOperand(1));
260254
Type *ElementTy =
261255
toTypedPointer(cast<ConstantAsMetadata>(VMD->getMetadata())->getType());
262-
SPIRVType *ElementType = GR->getOrCreateSPIRVType(
263-
ElementTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
264256
return GR->getOrCreateSPIRVPointerType(
265-
ElementType, MIRBuilder,
257+
ElementTy, MIRBuilder,
266258
addressSpaceToStorageClass(
267259
cast<ConstantInt>(II->getOperand(2))->getZExtValue(), ST));
268260
}

Diff for: llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp

+74-10
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,40 @@ static unsigned typeToAddressSpace(const Type *Ty) {
5454
report_fatal_error("Unable to convert LLVM type to SPIRVType", true);
5555
}
5656

57+
static bool
58+
storageClassRequiresExplictLayout(SPIRV::StorageClass::StorageClass SC) {
59+
switch (SC) {
60+
case SPIRV::StorageClass::Uniform:
61+
case SPIRV::StorageClass::PushConstant:
62+
case SPIRV::StorageClass::StorageBuffer:
63+
case SPIRV::StorageClass::PhysicalStorageBufferEXT:
64+
return true;
65+
case SPIRV::StorageClass::UniformConstant:
66+
case SPIRV::StorageClass::Input:
67+
case SPIRV::StorageClass::Output:
68+
case SPIRV::StorageClass::Workgroup:
69+
case SPIRV::StorageClass::CrossWorkgroup:
70+
case SPIRV::StorageClass::Private:
71+
case SPIRV::StorageClass::Function:
72+
case SPIRV::StorageClass::Generic:
73+
case SPIRV::StorageClass::AtomicCounter:
74+
case SPIRV::StorageClass::Image:
75+
case SPIRV::StorageClass::CallableDataNV:
76+
case SPIRV::StorageClass::IncomingCallableDataNV:
77+
case SPIRV::StorageClass::RayPayloadNV:
78+
case SPIRV::StorageClass::HitAttributeNV:
79+
case SPIRV::StorageClass::IncomingRayPayloadNV:
80+
case SPIRV::StorageClass::ShaderRecordBufferNV:
81+
case SPIRV::StorageClass::CodeSectionINTEL:
82+
case SPIRV::StorageClass::DeviceOnlyINTEL:
83+
case SPIRV::StorageClass::HostOnlyINTEL:
84+
return false;
85+
default:
86+
llvm_unreachable("Unknown storage class");
87+
return false;
88+
}
89+
}
90+
5791
SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize)
5892
: PointerSize(PointerSize), Bound(0) {}
5993

@@ -1342,7 +1376,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateVulkanBufferType(
13421376
SPIRV::Decoration::NonWritable, 0, {});
13431377
}
13441378

1345-
SPIRVType *R = getOrCreateSPIRVPointerType(BlockType, MIRBuilder, SC);
1379+
SPIRVType *R = getOrCreateSPIRVPointerTypeInternal(BlockType, MIRBuilder, SC);
13461380
add(Key, R);
13471381
return R;
13481382
}
@@ -1524,7 +1558,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(
15241558

15251559
// Handle "type*" or "type* vector[N]".
15261560
if (TypeStr.starts_with("*")) {
1527-
SpirvTy = getOrCreateSPIRVPointerType(SpirvTy, MIRBuilder, SC);
1561+
SpirvTy = getOrCreateSPIRVPointerType(Ty, MIRBuilder, SC);
15281562
TypeStr = TypeStr.substr(strlen("*"));
15291563
}
15301564

@@ -1693,6 +1727,44 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVArrayType(
16931727
}
16941728

16951729
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
1730+
const Type *BaseType, MachineInstr &I,
1731+
SPIRV::StorageClass::StorageClass SC) {
1732+
MachineIRBuilder MIRBuilder(I);
1733+
return getOrCreateSPIRVPointerType(BaseType, MIRBuilder, SC);
1734+
}
1735+
1736+
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
1737+
const Type *BaseType, MachineIRBuilder &MIRBuilder,
1738+
SPIRV::StorageClass::StorageClass SC) {
1739+
SPIRVType *SpirvBaseType = getOrCreateSPIRVType(
1740+
BaseType, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
1741+
return getOrCreateSPIRVPointerTypeInternal(SpirvBaseType, MIRBuilder, SC);
1742+
}
1743+
1744+
SPIRVType *SPIRVGlobalRegistry::changePointerStorageClass(
1745+
SPIRVType *PtrType, SPIRV::StorageClass::StorageClass SC, MachineInstr &I) {
1746+
SPIRV::StorageClass::StorageClass OldSC = getPointerStorageClass(PtrType);
1747+
assert(storageClassRequiresExplictLayout(OldSC) ==
1748+
storageClassRequiresExplictLayout(SC));
1749+
1750+
SPIRVType *PointeeType = getPointeeType(PtrType);
1751+
MachineIRBuilder MIRBuilder(I);
1752+
return getOrCreateSPIRVPointerTypeInternal(PointeeType, MIRBuilder, SC);
1753+
}
1754+
1755+
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
1756+
SPIRVType *BaseType, MachineIRBuilder &MIRBuilder,
1757+
SPIRV::StorageClass::StorageClass SC) {
1758+
const Type *LLVMType = getTypeForSPIRVType(BaseType);
1759+
assert(!storageClassRequiresExplictLayout(SC));
1760+
SPIRVType *R = getOrCreateSPIRVPointerType(LLVMType, MIRBuilder, SC);
1761+
assert(
1762+
getPointeeType(R) == BaseType &&
1763+
"The base type was not correctly laid out for the given storage class.");
1764+
return R;
1765+
}
1766+
1767+
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerTypeInternal(
16961768
SPIRVType *BaseType, MachineIRBuilder &MIRBuilder,
16971769
SPIRV::StorageClass::StorageClass SC) {
16981770
const Type *PointerElementType = getTypeForSPIRVType(BaseType);
@@ -1714,14 +1786,6 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
17141786
return finishCreatingSPIRVType(Ty, NewMI);
17151787
}
17161788

1717-
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
1718-
SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &,
1719-
SPIRV::StorageClass::StorageClass SC) {
1720-
MachineInstr *DepMI = const_cast<MachineInstr *>(BaseType);
1721-
MachineIRBuilder MIRBuilder(*DepMI->getParent(), DepMI->getIterator());
1722-
return getOrCreateSPIRVPointerType(BaseType, MIRBuilder, SC);
1723-
}
1724-
17251789
Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I,
17261790
SPIRVType *SpvType,
17271791
const SPIRVInstrInfo &TII) {

Diff for: llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h

+33-6
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,15 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
466466
Constant *CA, unsigned BitWidth,
467467
unsigned ElemCnt);
468468

469+
// Returns a pointer to a SPIR-V pointer type with the given base type and
470+
// storage class. It is the responsibility of the caller to make sure the
471+
// decorations on the base type are valid for the given storage class. For
472+
// example, it has the correct offset and stride decorations.
473+
SPIRVType *
474+
getOrCreateSPIRVPointerTypeInternal(SPIRVType *BaseType,
475+
MachineIRBuilder &MIRBuilder,
476+
SPIRV::StorageClass::StorageClass SC);
477+
469478
public:
470479
Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder,
471480
SPIRVType *SpvType, bool EmitIR,
@@ -540,12 +549,30 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
540549
unsigned NumElements, MachineInstr &I,
541550
const SPIRVInstrInfo &TII);
542551

543-
SPIRVType *getOrCreateSPIRVPointerType(
544-
SPIRVType *BaseType, MachineIRBuilder &MIRBuilder,
545-
SPIRV::StorageClass::StorageClass SClass = SPIRV::StorageClass::Function);
546-
SPIRVType *getOrCreateSPIRVPointerType(
547-
SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &TII,
548-
SPIRV::StorageClass::StorageClass SClass = SPIRV::StorageClass::Function);
552+
// Returns a pointer to a SPIR-V pointer type with the given base type and
553+
// storage class. The base type will be translated to a SPIR-V type, and the
554+
// appropriate layout decorations will be added to the base type.
555+
SPIRVType *getOrCreateSPIRVPointerType(const Type *BaseType,
556+
MachineIRBuilder &MIRBuilder,
557+
SPIRV::StorageClass::StorageClass SC);
558+
SPIRVType *getOrCreateSPIRVPointerType(const Type *BaseType, MachineInstr &I,
559+
SPIRV::StorageClass::StorageClass SC);
560+
561+
// Returns a pointer to a SPIR-V pointer type with the given base type and
562+
// storage class. It is the responsibility of the caller to make sure the
563+
// decorations on the base type are valid for the given storage class. For
564+
// example, it has the correct offset and stride decorations.
565+
SPIRVType *getOrCreateSPIRVPointerType(SPIRVType *BaseType,
566+
MachineIRBuilder &MIRBuilder,
567+
SPIRV::StorageClass::StorageClass SC);
568+
569+
// Returns a pointer to a SPIR-V pointer type that is the same as `PtrType`
570+
// except the stroage class has been changed to `SC`. It is the responsibility
571+
// of the caller to be sure that the original and new storage class have the
572+
// same layout requirements.
573+
SPIRVType *changePointerStorageClass(SPIRVType *PtrType,
574+
SPIRV::StorageClass::StorageClass SC,
575+
MachineInstr &I);
549576

550577
SPIRVType *getOrCreateVulkanBufferType(MachineIRBuilder &MIRBuilder,
551578
Type *ElemType,

Diff for: llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp

+2-4
Original file line numberDiff line numberDiff line change
@@ -214,10 +214,8 @@ static void validateLifetimeStart(const SPIRVSubtarget &STI,
214214
PtrType->getOperand(1).getImm());
215215
MachineIRBuilder MIB(I);
216216
LLVMContext &Context = MF->getFunction().getContext();
217-
SPIRVType *ElemType =
218-
GR.getOrCreateSPIRVType(IntegerType::getInt8Ty(Context), MIB,
219-
SPIRV::AccessQualifier::ReadWrite, false);
220-
SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(ElemType, MIB, SC);
217+
SPIRVType *NewPtrType =
218+
GR.getOrCreateSPIRVPointerType(IntegerType::getInt8Ty(Context), MIB, SC);
221219
doInsertBitcast(STI, MRI, GR, I, PtrReg, 0, NewPtrType);
222220
}
223221

Diff for: llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

+21-27
Original file line numberDiff line numberDiff line change
@@ -1257,14 +1257,18 @@ bool SPIRVInstructionSelector::selectMemOperation(Register ResVReg,
12571257
Register SrcReg = I.getOperand(1).getReg();
12581258
bool Result = true;
12591259
if (I.getOpcode() == TargetOpcode::G_MEMSET) {
1260+
MachineIRBuilder MIRBuilder(I);
12601261
assert(I.getOperand(1).isReg() && I.getOperand(2).isReg());
12611262
unsigned Val = getIConstVal(I.getOperand(1).getReg(), MRI);
12621263
unsigned Num = getIConstVal(I.getOperand(2).getReg(), MRI);
1263-
SPIRVType *ValTy = GR.getOrCreateSPIRVIntegerType(8, I, TII);
1264-
SPIRVType *ArrTy = GR.getOrCreateSPIRVArrayType(ValTy, Num, I, TII);
1265-
Register Const = GR.getOrCreateConstIntArray(Val, Num, I, ArrTy, TII);
1264+
Type *ValTy = Type::getInt8Ty(I.getMF()->getFunction().getContext());
1265+
Type *ArrTy = ArrayType::get(ValTy, Num);
12661266
SPIRVType *VarTy = GR.getOrCreateSPIRVPointerType(
1267-
ArrTy, I, TII, SPIRV::StorageClass::UniformConstant);
1267+
ArrTy, MIRBuilder, SPIRV::StorageClass::UniformConstant);
1268+
1269+
SPIRVType *SpvArrTy = GR.getOrCreateSPIRVType(
1270+
ArrTy, MIRBuilder, SPIRV::AccessQualifier::None, false);
1271+
Register Const = GR.getOrCreateConstIntArray(Val, Num, I, SpvArrTy, TII);
12681272
// TODO: check if we have such GV, add init, use buildGlobalVariable.
12691273
Function &CurFunction = GR.CurMF->getFunction();
12701274
Type *LLVMArrTy =
@@ -1287,7 +1291,7 @@ bool SPIRVInstructionSelector::selectMemOperation(Register ResVReg,
12871291

12881292
buildOpDecorate(VarReg, I, TII, SPIRV::Decoration::Constant, {});
12891293
SPIRVType *SourceTy = GR.getOrCreateSPIRVPointerType(
1290-
ValTy, I, TII, SPIRV::StorageClass::UniformConstant);
1294+
ValTy, I, SPIRV::StorageClass::UniformConstant);
12911295
SrcReg = MRI->createGenericVirtualRegister(LLT::scalar(64));
12921296
selectOpWithSrcs(SrcReg, SourceTy, I, {VarReg}, SPIRV::OpBitcast);
12931297
}
@@ -1588,7 +1592,7 @@ static bool isASCastInGVar(MachineRegisterInfo *MRI, Register ResVReg) {
15881592
Register SPIRVInstructionSelector::getUcharPtrTypeReg(
15891593
MachineInstr &I, SPIRV::StorageClass::StorageClass SC) const {
15901594
return GR.getSPIRVTypeID(GR.getOrCreateSPIRVPointerType(
1591-
GR.getOrCreateSPIRVIntegerType(8, I, TII), I, TII, SC));
1595+
Type::getInt8Ty(I.getMF()->getFunction().getContext()), I, SC));
15921596
}
15931597

15941598
MachineInstrBuilder
@@ -1606,8 +1610,8 @@ SPIRVInstructionSelector::buildSpecConstantOp(MachineInstr &I, Register Dest,
16061610
MachineInstrBuilder
16071611
SPIRVInstructionSelector::buildConstGenericPtr(MachineInstr &I, Register SrcPtr,
16081612
SPIRVType *SrcPtrTy) const {
1609-
SPIRVType *GenericPtrTy = GR.getOrCreateSPIRVPointerType(
1610-
GR.getPointeeType(SrcPtrTy), I, TII, SPIRV::StorageClass::Generic);
1613+
SPIRVType *GenericPtrTy =
1614+
GR.changePointerStorageClass(SrcPtrTy, SPIRV::StorageClass::Generic, I);
16111615
Register Tmp = MRI->createVirtualRegister(&SPIRV::pIDRegClass);
16121616
MRI->setType(Tmp, LLT::pointer(storageClassToAddressSpace(
16131617
SPIRV::StorageClass::Generic),
@@ -1692,8 +1696,8 @@ bool SPIRVInstructionSelector::selectAddrSpaceCast(Register ResVReg,
16921696
return selectUnOp(ResVReg, ResType, I, SPIRV::OpGenericCastToPtr);
16931697
// Casting between 2 eligible pointers using Generic as an intermediary.
16941698
if (isGenericCastablePtr(SrcSC) && isGenericCastablePtr(DstSC)) {
1695-
SPIRVType *GenericPtrTy = GR.getOrCreateSPIRVPointerType(
1696-
GR.getPointeeType(SrcPtrTy), I, TII, SPIRV::StorageClass::Generic);
1699+
SPIRVType *GenericPtrTy =
1700+
GR.changePointerStorageClass(SrcPtrTy, SPIRV::StorageClass::Generic, I);
16971701
Register Tmp = createVirtualRegister(GenericPtrTy, &GR, MRI, MRI->getMF());
16981702
bool Result = BuildMI(BB, I, DL, TII.get(SPIRV::OpPtrCastToGeneric))
16991703
.addDef(Tmp)
@@ -3364,18 +3368,20 @@ bool SPIRVInstructionSelector::selectImageWriteIntrinsic(
33643368
}
33653369

33663370
Register SPIRVInstructionSelector::buildPointerToResource(
3367-
const SPIRVType *ResType, SPIRV::StorageClass::StorageClass SC,
3371+
const SPIRVType *SpirvResType, SPIRV::StorageClass::StorageClass SC,
33683372
uint32_t Set, uint32_t Binding, uint32_t ArraySize, Register IndexReg,
33693373
bool IsNonUniform, MachineIRBuilder MIRBuilder) const {
3374+
const Type *ResType = GR.getTypeForSPIRVType(SpirvResType);
33703375
if (ArraySize == 1) {
33713376
SPIRVType *PtrType =
33723377
GR.getOrCreateSPIRVPointerType(ResType, MIRBuilder, SC);
3378+
assert(GR.getPointeeType(PtrType) == SpirvResType &&
3379+
"SpirvResType did not have an explicit layout.");
33733380
return GR.getOrCreateGlobalVariableWithBinding(PtrType, Set, Binding,
33743381
MIRBuilder);
33753382
}
33763383

3377-
const SPIRVType *VarType = GR.getOrCreateSPIRVArrayType(
3378-
ResType, ArraySize, *MIRBuilder.getInsertPt(), TII);
3384+
const Type *VarType = ArrayType::get(const_cast<Type *>(ResType), ArraySize);
33793385
SPIRVType *VarPointerType =
33803386
GR.getOrCreateSPIRVPointerType(VarType, MIRBuilder, SC);
33813387
Register VarReg = GR.getOrCreateGlobalVariableWithBinding(
@@ -3805,17 +3811,6 @@ bool SPIRVInstructionSelector::selectGlobalValue(
38053811
MachineIRBuilder MIRBuilder(I);
38063812
const GlobalValue *GV = I.getOperand(1).getGlobal();
38073813
Type *GVType = toTypedPointer(GR.getDeducedGlobalValueType(GV));
3808-
SPIRVType *PointerBaseType;
3809-
if (GVType->isArrayTy()) {
3810-
SPIRVType *ArrayElementType =
3811-
GR.getOrCreateSPIRVType(GVType->getArrayElementType(), MIRBuilder,
3812-
SPIRV::AccessQualifier::ReadWrite, false);
3813-
PointerBaseType = GR.getOrCreateSPIRVArrayType(
3814-
ArrayElementType, GVType->getArrayNumElements(), I, TII);
3815-
} else {
3816-
PointerBaseType = GR.getOrCreateSPIRVType(
3817-
GVType, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, false);
3818-
}
38193814

38203815
std::string GlobalIdent;
38213816
if (!GV->hasName()) {
@@ -3848,7 +3843,7 @@ bool SPIRVInstructionSelector::selectGlobalValue(
38483843
? dyn_cast<Function>(GV)
38493844
: nullptr;
38503845
SPIRVType *ResType = GR.getOrCreateSPIRVPointerType(
3851-
PointerBaseType, I, TII,
3846+
GVType, I,
38523847
GVFun ? SPIRV::StorageClass::CodeSectionINTEL
38533848
: addressSpaceToStorageClass(GV->getAddressSpace(), STI));
38543849
if (GVFun) {
@@ -3906,8 +3901,7 @@ bool SPIRVInstructionSelector::selectGlobalValue(
39063901
const unsigned AddrSpace = GV->getAddressSpace();
39073902
SPIRV::StorageClass::StorageClass StorageClass =
39083903
addressSpaceToStorageClass(AddrSpace, STI);
3909-
SPIRVType *ResType =
3910-
GR.getOrCreateSPIRVPointerType(PointerBaseType, I, TII, StorageClass);
3904+
SPIRVType *ResType = GR.getOrCreateSPIRVPointerType(GVType, I, StorageClass);
39113905
Register Reg = GR.buildGlobalVariable(
39123906
ResVReg, ResType, GlobalIdent, GV, StorageClass, Init,
39133907
GlobalVar->isConstant(), HasLnkTy, LnkType, MIRBuilder, true);

Diff for: llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp

+4-9
Original file line numberDiff line numberDiff line change
@@ -251,10 +251,8 @@ static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
251251
Register Def = MI.getOperand(0).getReg();
252252
Register Source = MI.getOperand(2).getReg();
253253
Type *ElemTy = getMDOperandAsType(MI.getOperand(3).getMetadata(), 0);
254-
SPIRVType *BaseTy = GR->getOrCreateSPIRVType(
255-
ElemTy, MIB, SPIRV::AccessQualifier::ReadWrite, true);
256254
SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
257-
BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
255+
ElemTy, MI,
258256
addressSpaceToStorageClass(MI.getOperand(4).getImm(), *ST));
259257

260258
// If the ptrcast would be redundant, replace all uses with the source
@@ -366,9 +364,8 @@ static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
366364
RegType.getAddressSpace()) {
367365
const SPIRVSubtarget &ST =
368366
MI->getParent()->getParent()->getSubtarget<SPIRVSubtarget>();
369-
SpvType = GR->getOrCreateSPIRVPointerType(
370-
GR->getPointeeType(SpvType), *MI, *ST.getInstrInfo(),
371-
addressSpaceToStorageClass(RegType.getAddressSpace(), ST));
367+
auto TSC = addressSpaceToStorageClass(RegType.getAddressSpace(), ST);
368+
SpvType = GR->changePointerStorageClass(SpvType, TSC, *MI);
372369
}
373370
GR->assignSPIRVTypeToVReg(SpvType, Reg, MIB.getMF());
374371
}
@@ -518,10 +515,8 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
518515
Register Reg = MI.getOperand(1).getReg();
519516
MIB.setInsertPt(*MI.getParent(), MI.getIterator());
520517
Type *ElementTy = getMDOperandAsType(MI.getOperand(2).getMetadata(), 0);
521-
SPIRVType *BaseTy = GR->getOrCreateSPIRVType(
522-
ElementTy, MIB, SPIRV::AccessQualifier::ReadWrite, true);
523518
SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
524-
BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
519+
ElementTy, MI,
525520
addressSpaceToStorageClass(MI.getOperand(3).getImm(), *ST));
526521
MachineInstr *Def = MRI.getVRegDef(Reg);
527522
assert(Def && "Expecting an instruction that defines the register");

0 commit comments

Comments
 (0)