Skip to content

Commit c1c8201

Browse files
committed
Modify how the VGPR/SGPR occupancy limit is set at the Global Write stage.
* Add refineOccupancy to set the proper occupancy limit. * Allocate the VGPRs needed by storeRemapAddStore() (self.vgprs.storeRemapAS) before calculating occupancy.
1 parent efe1ced commit c1c8201

File tree

1 file changed

+115
-82
lines changed

1 file changed

+115
-82
lines changed

tensilelite/Tensile/KernelWriterAssembly.py

Lines changed: 115 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -10218,6 +10218,8 @@ def cleanupGlobalWrite(self, kernel):
1021810218
self.vgprPool.checkIn(self.vgprs.storeRemapCoord0)
1021910219
self.vgprPool.checkIn(self.vgprs.storeRemapCoord1)
1022010220
self.vgprPool.checkIn(self.vgprs.storeRemapOffsetCoord1)
10221+
for v in self.vgprs.storeRemapAS:
10222+
self.vgprPool.checkIn(v)
1022110223
if kernel["BufferStore"]:
1022210224
self.vgprPool.checkIn(self.vgprs.cinRowPtr)
1022310225
self.vgprPool.checkIn(self.vgprs.coutRowPtrD)
@@ -10339,10 +10341,7 @@ def storeRemapAddStore(self, kernel, tmpVgpr, tmpS01, edge, StoreRemapLastBatch)
1033910341
rpe = self.states.bpeCexternal / self.states.bpr
1034010342
rpv = rpe * gwvw
1034110343

10342-
# num registers to check out
10343-
storeRegs = []
10344-
for i in range(0, nElements, gwvw):
10345-
storeRegs.append(self.vgprPool.checkOutAligned(int(rpv), int(rpv), "store element d"))
10344+
storeRegs = self.vgprs.storeRemapAS
1034610345
src = vgpr(self.vgprs.storeRemapLR)
1034710346
for rIdx, i in enumerate(range(0, nElements, gwvw)):
1034810347
offset = self.storeRemapLrOffset * bpe * (i//gwvw)
@@ -10447,8 +10446,6 @@ def storeRemapAddStore(self, kernel, tmpVgpr, tmpS01, edge, StoreRemapLastBatch)
1044710446

1044810447
module.addSpaceLine()
1044910448
self.vgprPool.checkIn(vTmp)
10450-
for v in storeRegs:
10451-
self.vgprPool.checkIn(v)
1045210449

1045310450
#Data exchange between different waves
1045410451
#Make sure LDS reads are finished of all waves
@@ -10593,6 +10590,12 @@ def storeRemapComputeStoreVgprs(self, kernel):
1059310590

1059410591
self.vgprPool.checkIn(tmpV0)
1059510592

10593+
nElements = kernel["MacroTile0"]*kernel["MatrixInstN"]//kernel["MIWaveGroup"][0]//self.states.kernel["WavefrontSize"]
10594+
rpe = self.states.bpeCexternal / self.states.bpr
10595+
rpv = rpe * gwvw
10596+
self.vgprs.storeRemapAS = []
10597+
for i in range(0, nElements, gwvw):
10598+
self.vgprs.storeRemapAS.append(self.vgprPool.checkOutAligned(int(rpv), int(rpv), "store element d"))
1059610599
return module
1059710600

1059810601
##############################################################################
@@ -11585,77 +11588,19 @@ def getMBSKGSUTotal(self, kernel):
1158511588
GSUtotal = max(2,GSUtotal)
1158611589
return GSUtotal
1158711590

11588-
##############################################################################
11589-
# globalWriteElementBatch :
11590-
##############################################################################
11591-
def globalWriteElementBatch(self, kernel, tPA, tPB, activation, \
11592-
applyAlpha, beta, edge, atomic, \
11593-
vectorWidths, elements, activationLabelList, \
11594-
tmpVgpr, cvtVgprStruct, activationSetPCStruct, activationEnumStrList, \
11595-
actPCMaxTempSgpr, isInsertActFunctionCallAddrCalc, toActModuleList, \
11596-
edgeModule, writeLabels, endLabel, \
11597-
edge_mode_pos, currentInstLength, \
11598-
idx0, idx1, idx2, idxMN, vectorDataTypes, factorDims):
11599-
factorDim = factorDims[idx2]
11600-
edgeModule.add(writeLabels[beta][edge][factorDim][idxMN])
11601-
if idx2 == 0:
11602-
edge_mode_pos = len(edgeModule.items())
11603-
11604-
# for storeRemap edge case, non-beta still can enable vector stores
11605-
if kernel["StoreRemapVectorWidth"] and not beta:
11606-
edgeI = False
11607-
else:
11608-
edgeI = edge
11609-
#edgeI = True # set to True to disable vector stores
11610-
gwvw = vectorWidths[edgeI]
11611-
11612-
#print "globalWriteElements: edge=", edge, "beta=", beta, "atomic=", atomic
11613-
11614-
########################################
11615-
# Calculate Vgprs for Write Batching
11616-
########################################
11617-
self.vgprPool.resetOccupancyLimit()
11618-
self.sgprPool.resetOccupancyLimit()
11619-
11620-
# Temporarily grow pool for sgpr
11621-
sgprList = []
11622-
if kernel["_GlobalAccumulation"] == 'MultipleBufferSingleKernel':
11623-
sgprList.append(self.sgprPool.checkOut(1, preventOverflow=False))
11624-
sgprList.append(self.sgprPool.checkOut(1, preventOverflow=False))
11625-
sgprList.append(self.sgprPool.checkOut(1, preventOverflow=False))
11626-
sgprList.append(self.sgprPool.checkOutAligned(2, 2, preventOverflow=False))
11627-
sgprList.append(self.sgprPool.checkOutAligned(2, 2, preventOverflow=False))
11628-
sgprList.append(self.sgprPool.checkOutAligned(4, 4, preventOverflow=False))
11629-
for s in sgprList:
11630-
self.sgprPool.checkIn(s)
11631-
if actPCMaxTempSgpr > 0:
11632-
self.sgprPool.checkIn(self.sgprPool.checkOutAligned(actPCMaxTempSgpr, 2 if actPCMaxTempSgpr > 1 else 1, preventOverflow=False))
11633-
11634-
tmpVgprDynamic = None
11635-
tmpVgprDynamicSize = 0
11636-
tmpVgprDynamicAlign = 0
11637-
if kernel["_GlobalAccumulation"] == 'MultipleBufferSingleKernel':
11638-
GSUTotal = self.getMBSKGSUTotal(kernel)
11639-
vgprMbsk = (GSUTotal-1) * gwvw * max(1, kernel["ProblemType"]["DestDataType"].numRegisters())
11640-
tmpVgprDynamicSize = vgprMbsk
11641-
tmpVgprDynamicAlign = 4
11642-
if tmpVgprDynamicSize > 0:
11643-
tmpVgprDynamic = ContinuousRegister(idx=self.vgprPool.checkOutAligned(tmpVgprDynamicSize, tmpVgprDynamicAlign), size=tmpVgprDynamicSize)
11644-
11645-
ss = StoreState(self, kernel, gwvw, edge, beta, atomic, elements[edgeI], vectorDataTypes, dim=factorDim)
11646-
11647-
def setOccupancy():
11648-
# Use VGPR up to next occupancy threshold:
11649-
maxVgprs, occupancy = self.getMaxRegsForOccupancy(kernel["NumThreads"], self.vgprPool.size(), self.sgprPool.size(), \
11650-
self.getLdsSize(kernel), self.agprPool.size(), self.states.doubleVgpr)
11651-
# Set occupancy limit for register pools
11652-
# TODO: Support gfx12
11653-
if kernel["ISA"][0] != 12:
11654-
self.vgprPool.setOccupancyLimit(self.states.regCaps["MaxVgpr"], self.states.regCaps["PhysicalMaxVgpr"] // occupancy)
11655-
self.sgprPool.setOccupancyLimit(self.states.regCaps["MaxSgpr"], self.states.regCaps["PhysicalMaxSgpr"] // occupancy)
11656-
return maxVgprs, occupancy
11657-
11658-
maxVgprs, occupancy = setOccupancy()
11591+
def setOccupancy(self, kernel):
11592+
# Use VGPR up to next occupancy threshold:
11593+
maxVgprs, occupancy = self.getMaxRegsForOccupancy(kernel["NumThreads"], self.vgprPool.size(), self.sgprPool.size(), \
11594+
self.getLdsSize(kernel), self.agprPool.size(), self.states.doubleVgpr)
11595+
# Set occupancy limit for register pools
11596+
# TODO: Support gfx12
11597+
if kernel["ISA"][0] != 12:
11598+
self.vgprPool.setOccupancyLimit(self.states.regCaps["MaxVgpr"], self.states.regCaps["PhysicalMaxVgpr"] // occupancy)
11599+
self.sgprPool.setOccupancyLimit(self.states.regCaps["MaxSgpr"], self.states.regCaps["PhysicalMaxSgpr"] // occupancy)
11600+
return maxVgprs, occupancy
11601+
11602+
def refineOccupancy(self, kernel, atomic, elements, actPCMaxTempSgpr, \
11603+
edgeI, gwvw, maxVgprs, ss):
1165911604
# Get estimated numVgprAvailable
1166011605
# print("Max vgprs =", maxVgprs, self.vgprPool.size(), self.vgprPool.availableBlock(ss.numVgprsPerElement, ss.align))
1166111606
numVgprAvailable = self.vgprPool.availableBlockMaxVgpr(maxVgprs, ss.numVgprsPerElement, ss.align)
@@ -11676,12 +11621,12 @@ def setOccupancy():
1167611621
% (minElements,ss.numVgprsPerElement))
1167711622
self.vgprPool.growPool(0, minElements, ss.numVgprsPerElement, \
1167811623
"grow-pool for GlobalWrite")
11679-
maxVgprs, occupancy = setOccupancy()
11624+
maxVgprsN, occupancy = self.setOccupancy(kernel)
11625+
if maxVgprs != maxVgprsN:
11626+
#print("refineOccupancy maxVgprs, new", maxVgprsN, "old", maxVgprs)
11627+
return self.refineOccupancy(kernel, atomic, elements, actPCMaxTempSgpr, edgeI, gwvw, maxVgprsN, ss)
1168011628
numVgprAvailable = self.vgprPool.available()
1168111629

11682-
# set atomicW after we potentially resize GWVW
11683-
atomicW = min(gwvw, self.getVectorAtomicWidth(kernel))
11684-
1168511630
# print("NumVgprAvailable", numVgprAvailable)
1168611631
if ss.numVgprsPerElement:
1168711632
numElementsPerBatch = numVgprAvailable // ss.numVgprsPerElement
@@ -11726,6 +11671,7 @@ def setOccupancy():
1172611671

1172711672
# check best numElementsPerBatch to handle a column block
1172811673
# elements of column block must be multiple size of numElementsPerBatch
11674+
nBatchesPerRow = 0
1172911675
if kernel["StoreRemapVectorWidth"]:
1173011676
firstRow = [e for e in elements[edgeI] if e[0]==0 and e[2]==0] # format for element = (tt1, tt0, vc1, vc0)
1173111677
# find the largest factor and smaller than numElementPerBatch
@@ -11750,19 +11696,106 @@ def setOccupancy():
1175011696
totalNeededVgpr = ss.numVgprsPerElement * numElementsPerBatch
1175111697
# print("Available vgprs =", numVgprAvailable, "Needed vgprs =", totalNeededVgpr, "pool size =", self.vgprPool.size())
1175211698
if numVgprAvailable < totalNeededVgpr:
11699+
self.vgprPool.resetOccupancyLimit()
1175311700
print2("info: growing pool += %d * %d for GlobalWrite\n" \
1175411701
% (numBatches,ss.numVgprsPerElement))
1175511702
availableBlock = min(0, self.vgprPool.available() - numVgprAvailable)
1175611703
self.vgprPool.growPool(0, totalNeededVgpr + availableBlock, 1, "grow-pool for GlobalWrite")
11704+
maxVgprsN, occupancy = self.setOccupancy(kernel)
11705+
if maxVgprs != maxVgprsN:
11706+
#print("refineOccupancy maxVgprs, new", maxVgprsN, "old", maxVgprs)
11707+
return self.refineOccupancy(kernel, atomic, elements, actPCMaxTempSgpr, edgeI, gwvw, maxVgprsN, ss)
11708+
1175711709
# # Get true numVgprAvailable
1175811710
# numVgprAvailable = self.vgprPool.availableBlock(ss.numVgprsPerElement, ss.align)
1175911711
# print("Available vgprs =", numVgprAvailable, "pool size =", self.vgprPool.size())
1176011712

1176111713
numSgprs = ss.cfg.numTempSgprPerBatch + ss.cfg.numMaskSgprPerBatch + ss.cfg.numMaskSgprPerElement * numElementsPerBatch
1176211714

11715+
if actPCMaxTempSgpr:
11716+
numSgprs = max(actPCMaxTempSgpr, numSgprs)
11717+
11718+
self.sgprPool.resetOccupancyLimit()
11719+
self.sgprPool.checkIn(self.sgprPool.checkOutAligned(numSgprs, 2, preventOverflow=False))
11720+
maxVgprsN, occupancy = self.setOccupancy(kernel)
11721+
if maxVgprs != maxVgprsN:
11722+
#print("refineOccupancy maxVgprs, new", maxVgprsN, "old", maxVgprs)
11723+
return self.refineOccupancy(kernel, atomic, elements, actPCMaxTempSgpr, edgeI, gwvw, maxVgprsN, ss)
11724+
return numElementsPerBatch, nBatchesPerRow, numBatches, numSgprs
11725+
11726+
11727+
##############################################################################
11728+
# globalWriteElementBatch :
11729+
##############################################################################
11730+
def globalWriteElementBatch(self, kernel, tPA, tPB, activation, \
11731+
applyAlpha, beta, edge, atomic, \
11732+
vectorWidths, elements, activationLabelList, \
11733+
tmpVgpr, cvtVgprStruct, activationSetPCStruct, activationEnumStrList, \
11734+
actPCMaxTempSgpr, isInsertActFunctionCallAddrCalc, toActModuleList, \
11735+
edgeModule, writeLabels, endLabel, \
11736+
edge_mode_pos, currentInstLength, \
11737+
idx0, idx1, idx2, idxMN, vectorDataTypes, factorDims):
11738+
factorDim = factorDims[idx2]
11739+
edgeModule.add(writeLabels[beta][edge][factorDim][idxMN])
11740+
if idx2 == 0:
11741+
edge_mode_pos = len(edgeModule.items())
11742+
11743+
# for storeRemap edge case, non-beta still can enable vector stores
11744+
if kernel["StoreRemapVectorWidth"] and not beta:
11745+
edgeI = False
11746+
else:
11747+
edgeI = edge
11748+
#edgeI = True # set to True to disable vector stores
11749+
gwvw = vectorWidths[edgeI]
11750+
11751+
#print "globalWriteElements: edge=", edge, "beta=", beta, "atomic=", atomic
11752+
11753+
########################################
11754+
# Calculate Vgprs for Write Batching
11755+
########################################
11756+
self.vgprPool.resetOccupancyLimit()
11757+
self.sgprPool.resetOccupancyLimit()
11758+
11759+
# Temporarily grow pool for sgpr
11760+
sgprList = []
11761+
if kernel["_GlobalAccumulation"] == 'MultipleBufferSingleKernel':
11762+
sgprList.append(self.sgprPool.checkOut(1, preventOverflow=False))
11763+
sgprList.append(self.sgprPool.checkOut(1, preventOverflow=False))
11764+
sgprList.append(self.sgprPool.checkOut(1, preventOverflow=False))
11765+
sgprList.append(self.sgprPool.checkOutAligned(2, 2, preventOverflow=False))
11766+
sgprList.append(self.sgprPool.checkOutAligned(2, 2, preventOverflow=False))
11767+
sgprList.append(self.sgprPool.checkOutAligned(4, 4, preventOverflow=False))
11768+
for s in sgprList:
11769+
self.sgprPool.checkIn(s)
11770+
if actPCMaxTempSgpr > 0:
11771+
self.sgprPool.checkIn(self.sgprPool.checkOutAligned(actPCMaxTempSgpr, 2 if actPCMaxTempSgpr > 1 else 1, preventOverflow=False))
11772+
11773+
tmpVgprDynamic = None
11774+
tmpVgprDynamicSize = 0
11775+
tmpVgprDynamicAlign = 0
11776+
if kernel["_GlobalAccumulation"] == 'MultipleBufferSingleKernel':
11777+
GSUTotal = self.getMBSKGSUTotal(kernel)
11778+
vgprMbsk = (GSUTotal-1) * gwvw * max(1, kernel["ProblemType"]["DestDataType"].numRegisters())
11779+
tmpVgprDynamicSize = vgprMbsk
11780+
tmpVgprDynamicAlign = 4
11781+
if tmpVgprDynamicSize > 0:
11782+
tmpVgprDynamic = ContinuousRegister(idx=self.vgprPool.checkOutAligned(tmpVgprDynamicSize, tmpVgprDynamicAlign), size=tmpVgprDynamicSize)
11783+
11784+
maxVgprs, occupancy = self.setOccupancy(kernel)
11785+
11786+
ss = StoreState(self, kernel, gwvw, edge, beta, atomic, elements[edgeI], vectorDataTypes, dim=factorDim)
11787+
11788+
actPCMaxTempSgpr_ = None
1176311789
if activationLabelList and isInsertActFunctionCallAddrCalc:
1176411790
assert activationSetPCStruct, activationEnumStrList and activationLabelList and toActModuleList
11765-
numSgprs = max(actPCMaxTempSgpr, numSgprs)
11791+
actPCMaxTempSgpr_ = actPCMaxTempSgpr
11792+
11793+
numElementsPerBatch, nBatchesPerRow, numBatches, numSgprs = self.refineOccupancy(kernel, atomic, elements, actPCMaxTempSgpr_, edgeI, gwvw, maxVgprs, ss)
11794+
11795+
# set atomicW after we potentially resize GWVW
11796+
atomicW = min(gwvw, self.getVectorAtomicWidth(kernel))
11797+
11798+
if activationLabelList and isInsertActFunctionCallAddrCalc:
1176611799
edgeModule.add(self.insertActFunctionCallAddrCalc(activationSetPCStruct.sgprOffsetActivation, \
1176711800
gwvw, toActModuleList, activationEnumStrList, activationLabelList, \
1176811801
idx0, idx1))

0 commit comments

Comments
 (0)