Skip to content

Commit c189be1

Browse files
committed
Modified vgpr/sgpr occupancy limit method at the Global Write stage.
* Add refineOccupancy to set the proper occupancy limit. * Allocated the required SGPRs needed by storeRemapAddStore() before calculating occupancy.
1 parent 45f9712 commit c189be1

File tree

1 file changed

+115
-82
lines changed

1 file changed

+115
-82
lines changed

tensilelite/Tensile/KernelWriterAssembly.py

Lines changed: 115 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -10253,6 +10253,8 @@ def cleanupGlobalWrite(self, kernel):
1025310253
self.vgprPool.checkIn(self.vgprs.storeRemapCoord0)
1025410254
self.vgprPool.checkIn(self.vgprs.storeRemapCoord1)
1025510255
self.vgprPool.checkIn(self.vgprs.storeRemapOffsetCoord1)
10256+
for v in self.vgprs.storeRemapAS:
10257+
self.vgprPool.checkIn(v)
1025610258
if kernel["BufferStore"]:
1025710259
self.vgprPool.checkIn(self.vgprs.cinRowPtr)
1025810260
self.vgprPool.checkIn(self.vgprs.coutRowPtrD)
@@ -10374,10 +10376,7 @@ def storeRemapAddStore(self, kernel, tmpVgpr, tmpS01, edge, StoreRemapLastBatch)
1037410376
rpe = self.states.bpeCexternal / self.states.bpr
1037510377
rpv = rpe * gwvw
1037610378

10377-
# num registers to check out
10378-
storeRegs = []
10379-
for i in range(0, nElements, gwvw):
10380-
storeRegs.append(self.vgprPool.checkOutAligned(int(rpv), int(rpv), "store element d"))
10379+
storeRegs = self.vgprs.storeRemapAS
1038110380
src = vgpr(self.vgprs.storeRemapLR)
1038210381
for rIdx, i in enumerate(range(0, nElements, gwvw)):
1038310382
offset = self.storeRemapLrOffset * bpe * (i//gwvw)
@@ -10482,8 +10481,6 @@ def storeRemapAddStore(self, kernel, tmpVgpr, tmpS01, edge, StoreRemapLastBatch)
1048210481

1048310482
module.addSpaceLine()
1048410483
self.vgprPool.checkIn(vTmp)
10485-
for v in storeRegs:
10486-
self.vgprPool.checkIn(v)
1048710484

1048810485
#Data exchange between different waves
1048910486
#Make sure LDS reads are finished of all waves
@@ -10628,6 +10625,12 @@ def storeRemapComputeStoreVgprs(self, kernel):
1062810625

1062910626
self.vgprPool.checkIn(tmpV0)
1063010627

10628+
nElements = kernel["MacroTile0"]*kernel["MatrixInstN"]//kernel["MIWaveGroup"][0]//self.states.kernel["WavefrontSize"]
10629+
rpe = self.states.bpeCexternal / self.states.bpr
10630+
rpv = rpe * gwvw
10631+
self.vgprs.storeRemapAS = []
10632+
for i in range(0, nElements, gwvw):
10633+
self.vgprs.storeRemapAS.append(self.vgprPool.checkOutAligned(int(rpv), int(rpv), "store element d"))
1063110634
return module
1063210635

1063310636
##############################################################################
@@ -11620,77 +11623,19 @@ def getMBSKGSUTotal(self, kernel):
1162011623
GSUtotal = max(2,GSUtotal)
1162111624
return GSUtotal
1162211625

11623-
##############################################################################
11624-
# globalWriteElementBatch :
11625-
##############################################################################
11626-
def globalWriteElementBatch(self, kernel, tPA, tPB, activation, \
11627-
applyAlpha, beta, edge, atomic, \
11628-
vectorWidths, elements, activationLabelList, \
11629-
tmpVgpr, cvtVgprStruct, activationSetPCStruct, activationEnumStrList, \
11630-
actPCMaxTempSgpr, isInsertActFunctionCallAddrCalc, toActModuleList, \
11631-
edgeModule, writeLabels, endLabel, \
11632-
edge_mode_pos, currentInstLength, \
11633-
idx0, idx1, idx2, idxMN, vectorDataTypes, factorDims):
11634-
factorDim = factorDims[idx2]
11635-
edgeModule.add(writeLabels[beta][edge][factorDim][idxMN])
11636-
if idx2 == 0:
11637-
edge_mode_pos = len(edgeModule.items())
11638-
11639-
# for storeRemap edge case, non-beta still can enable vector stores
11640-
if kernel["StoreRemapVectorWidth"] and not beta:
11641-
edgeI = False
11642-
else:
11643-
edgeI = edge
11644-
#edgeI = True # set to True to disable vector stores
11645-
gwvw = vectorWidths[edgeI]
11646-
11647-
#print "globalWriteElements: edge=", edge, "beta=", beta, "atomic=", atomic
11648-
11649-
########################################
11650-
# Calculate Vgprs for Write Batching
11651-
########################################
11652-
self.vgprPool.resetOccupancyLimit()
11653-
self.sgprPool.resetOccupancyLimit()
11654-
11655-
# Temporarily grow pool for sgpr
11656-
sgprList = []
11657-
if kernel["_GlobalAccumulation"] == 'MultipleBufferSingleKernel':
11658-
sgprList.append(self.sgprPool.checkOut(1, preventOverflow=False))
11659-
sgprList.append(self.sgprPool.checkOut(1, preventOverflow=False))
11660-
sgprList.append(self.sgprPool.checkOut(1, preventOverflow=False))
11661-
sgprList.append(self.sgprPool.checkOutAligned(2, 2, preventOverflow=False))
11662-
sgprList.append(self.sgprPool.checkOutAligned(2, 2, preventOverflow=False))
11663-
sgprList.append(self.sgprPool.checkOutAligned(4, 4, preventOverflow=False))
11664-
for s in sgprList:
11665-
self.sgprPool.checkIn(s)
11666-
if actPCMaxTempSgpr > 0:
11667-
self.sgprPool.checkIn(self.sgprPool.checkOutAligned(actPCMaxTempSgpr, 2 if actPCMaxTempSgpr > 1 else 1, preventOverflow=False))
11668-
11669-
tmpVgprDynamic = None
11670-
tmpVgprDynamicSize = 0
11671-
tmpVgprDynamicAlign = 0
11672-
if kernel["_GlobalAccumulation"] == 'MultipleBufferSingleKernel':
11673-
GSUTotal = self.getMBSKGSUTotal(kernel)
11674-
vgprMbsk = (GSUTotal-1) * gwvw * max(1, kernel["ProblemType"]["DestDataType"].numRegisters())
11675-
tmpVgprDynamicSize = vgprMbsk
11676-
tmpVgprDynamicAlign = 4
11677-
if tmpVgprDynamicSize > 0:
11678-
tmpVgprDynamic = ContinuousRegister(idx=self.vgprPool.checkOutAligned(tmpVgprDynamicSize, tmpVgprDynamicAlign), size=tmpVgprDynamicSize)
11679-
11680-
ss = StoreState(self, kernel, gwvw, edge, beta, atomic, elements[edgeI], vectorDataTypes, dim=factorDim)
11681-
11682-
def setOccupancy():
11683-
# Use VGPR up to next occupancy threshold:
11684-
maxVgprs, occupancy = self.getMaxRegsForOccupancy(kernel["NumThreads"], self.vgprPool.size(), self.sgprPool.size(), \
11685-
self.getLdsSize(kernel), self.agprPool.size(), self.states.doubleVgpr)
11686-
# Set occupancy limit for register pools
11687-
# TODO: Support gfx12
11688-
if kernel["ISA"][0] != 12:
11689-
self.vgprPool.setOccupancyLimit(self.states.regCaps["MaxVgpr"], self.states.regCaps["PhysicalMaxVgpr"] // occupancy)
11690-
self.sgprPool.setOccupancyLimit(self.states.regCaps["MaxSgpr"], self.states.regCaps["PhysicalMaxSgpr"] // occupancy)
11691-
return maxVgprs, occupancy
11692-
11693-
maxVgprs, occupancy = setOccupancy()
11626+
def setOccupancy(self, kernel):
11627+
# Use VGPR up to next occupancy threshold:
11628+
maxVgprs, occupancy = self.getMaxRegsForOccupancy(kernel["NumThreads"], self.vgprPool.size(), self.sgprPool.size(), \
11629+
self.getLdsSize(kernel), self.agprPool.size(), self.states.doubleVgpr)
11630+
# Set occupancy limit for register pools
11631+
# TODO: Support gfx12
11632+
if kernel["ISA"][0] != 12:
11633+
self.vgprPool.setOccupancyLimit(self.states.regCaps["MaxVgpr"], self.states.regCaps["PhysicalMaxVgpr"] // occupancy)
11634+
self.sgprPool.setOccupancyLimit(self.states.regCaps["MaxSgpr"], self.states.regCaps["PhysicalMaxSgpr"] // occupancy)
11635+
return maxVgprs, occupancy
11636+
11637+
def refineOccupancy(self, kernel, atomic, elements, actPCMaxTempSgpr, \
11638+
edgeI, gwvw, maxVgprs, ss):
1169411639
# Get estimated numVgprAvailable
1169511640
# print("Max vgprs =", maxVgprs, self.vgprPool.size(), self.vgprPool.availableBlock(ss.numVgprsPerElement, ss.align))
1169611641
numVgprAvailable = self.vgprPool.availableBlockMaxVgpr(maxVgprs, ss.numVgprsPerElement, ss.align)
@@ -11711,12 +11656,12 @@ def setOccupancy():
1171111656
% (minElements,ss.numVgprsPerElement))
1171211657
self.vgprPool.growPool(0, minElements, ss.numVgprsPerElement, \
1171311658
"grow-pool for GlobalWrite")
11714-
maxVgprs, occupancy = setOccupancy()
11659+
maxVgprsN, occupancy = self.setOccupancy(kernel)
11660+
if maxVgprs != maxVgprsN:
11661+
#print("refineOccupancy maxVgprs, new", maxVgprsN, "old", maxVgprs)
11662+
return self.refineOccupancy(kernel, atomic, elements, actPCMaxTempSgpr, edgeI, gwvw, maxVgprsN, ss)
1171511663
numVgprAvailable = self.vgprPool.available()
1171611664

11717-
# set atomicW after we potentially resize GWVW
11718-
atomicW = min(gwvw, self.getVectorAtomicWidth(kernel))
11719-
1172011665
# print("NumVgprAvailable", numVgprAvailable)
1172111666
if ss.numVgprsPerElement:
1172211667
numElementsPerBatch = numVgprAvailable // ss.numVgprsPerElement
@@ -11761,6 +11706,7 @@ def setOccupancy():
1176111706

1176211707
# check best numElementsPerBatch to handle a column block
1176311708
# elements of column block must be multiple size of numElementsPerBatch
11709+
nBatchesPerRow = 0
1176411710
if kernel["StoreRemapVectorWidth"]:
1176511711
firstRow = [e for e in elements[edgeI] if e[0]==0 and e[2]==0] # format for element = (tt1, tt0, vc1, vc0)
1176611712
# find the largest factor and smaller than numElementPerBatch
@@ -11785,19 +11731,106 @@ def setOccupancy():
1178511731
totalNeededVgpr = ss.numVgprsPerElement * numElementsPerBatch
1178611732
# print("Available vgprs =", numVgprAvailable, "Needed vgprs =", totalNeededVgpr, "pool size =", self.vgprPool.size())
1178711733
if numVgprAvailable < totalNeededVgpr:
11734+
self.vgprPool.resetOccupancyLimit()
1178811735
print2("info: growing pool += %d * %d for GlobalWrite\n" \
1178911736
% (numBatches,ss.numVgprsPerElement))
1179011737
availableBlock = min(0, self.vgprPool.available() - numVgprAvailable)
1179111738
self.vgprPool.growPool(0, totalNeededVgpr + availableBlock, 1, "grow-pool for GlobalWrite")
11739+
maxVgprsN, occupancy = self.setOccupancy(kernel)
11740+
if maxVgprs != maxVgprsN:
11741+
#print("refineOccupancy maxVgprs, new", maxVgprsN, "old", maxVgprs)
11742+
return self.refineOccupancy(kernel, atomic, elements, actPCMaxTempSgpr, edgeI, gwvw, maxVgprsN, ss)
11743+
1179211744
# # Get true numVgprAvailable
1179311745
# numVgprAvailable = self.vgprPool.availableBlock(ss.numVgprsPerElement, ss.align)
1179411746
# print("Available vgprs =", numVgprAvailable, "pool size =", self.vgprPool.size())
1179511747

1179611748
numSgprs = ss.cfg.numTempSgprPerBatch + ss.cfg.numMaskSgprPerBatch + ss.cfg.numMaskSgprPerElement * numElementsPerBatch
1179711749

11750+
if actPCMaxTempSgpr:
11751+
numSgprs = max(actPCMaxTempSgpr, numSgprs)
11752+
11753+
self.sgprPool.resetOccupancyLimit()
11754+
self.sgprPool.checkIn(self.sgprPool.checkOutAligned(numSgprs, 2, preventOverflow=False))
11755+
maxVgprsN, occupancy = self.setOccupancy(kernel)
11756+
if maxVgprs != maxVgprsN:
11757+
#print("refineOccupancy maxVgprs, new", maxVgprsN, "old", maxVgprs)
11758+
return self.refineOccupancy(kernel, atomic, elements, actPCMaxTempSgpr, edgeI, gwvw, maxVgprsN, ss)
11759+
return numElementsPerBatch, nBatchesPerRow, numBatches, numSgprs
11760+
11761+
11762+
##############################################################################
11763+
# globalWriteElementBatch :
11764+
##############################################################################
11765+
def globalWriteElementBatch(self, kernel, tPA, tPB, activation, \
11766+
applyAlpha, beta, edge, atomic, \
11767+
vectorWidths, elements, activationLabelList, \
11768+
tmpVgpr, cvtVgprStruct, activationSetPCStruct, activationEnumStrList, \
11769+
actPCMaxTempSgpr, isInsertActFunctionCallAddrCalc, toActModuleList, \
11770+
edgeModule, writeLabels, endLabel, \
11771+
edge_mode_pos, currentInstLength, \
11772+
idx0, idx1, idx2, idxMN, vectorDataTypes, factorDims):
11773+
factorDim = factorDims[idx2]
11774+
edgeModule.add(writeLabels[beta][edge][factorDim][idxMN])
11775+
if idx2 == 0:
11776+
edge_mode_pos = len(edgeModule.items())
11777+
11778+
# for storeRemap edge case, non-beta still can enable vector stores
11779+
if kernel["StoreRemapVectorWidth"] and not beta:
11780+
edgeI = False
11781+
else:
11782+
edgeI = edge
11783+
#edgeI = True # set to True to disable vector stores
11784+
gwvw = vectorWidths[edgeI]
11785+
11786+
#print "globalWriteElements: edge=", edge, "beta=", beta, "atomic=", atomic
11787+
11788+
########################################
11789+
# Calculate Vgprs for Write Batching
11790+
########################################
11791+
self.vgprPool.resetOccupancyLimit()
11792+
self.sgprPool.resetOccupancyLimit()
11793+
11794+
# Temporarily grow pool for sgpr
11795+
sgprList = []
11796+
if kernel["_GlobalAccumulation"] == 'MultipleBufferSingleKernel':
11797+
sgprList.append(self.sgprPool.checkOut(1, preventOverflow=False))
11798+
sgprList.append(self.sgprPool.checkOut(1, preventOverflow=False))
11799+
sgprList.append(self.sgprPool.checkOut(1, preventOverflow=False))
11800+
sgprList.append(self.sgprPool.checkOutAligned(2, 2, preventOverflow=False))
11801+
sgprList.append(self.sgprPool.checkOutAligned(2, 2, preventOverflow=False))
11802+
sgprList.append(self.sgprPool.checkOutAligned(4, 4, preventOverflow=False))
11803+
for s in sgprList:
11804+
self.sgprPool.checkIn(s)
11805+
if actPCMaxTempSgpr > 0:
11806+
self.sgprPool.checkIn(self.sgprPool.checkOutAligned(actPCMaxTempSgpr, 2 if actPCMaxTempSgpr > 1 else 1, preventOverflow=False))
11807+
11808+
tmpVgprDynamic = None
11809+
tmpVgprDynamicSize = 0
11810+
tmpVgprDynamicAlign = 0
11811+
if kernel["_GlobalAccumulation"] == 'MultipleBufferSingleKernel':
11812+
GSUTotal = self.getMBSKGSUTotal(kernel)
11813+
vgprMbsk = (GSUTotal-1) * gwvw * max(1, kernel["ProblemType"]["DestDataType"].numRegisters())
11814+
tmpVgprDynamicSize = vgprMbsk
11815+
tmpVgprDynamicAlign = 4
11816+
if tmpVgprDynamicSize > 0:
11817+
tmpVgprDynamic = ContinuousRegister(idx=self.vgprPool.checkOutAligned(tmpVgprDynamicSize, tmpVgprDynamicAlign), size=tmpVgprDynamicSize)
11818+
11819+
maxVgprs, occupancy = self.setOccupancy(kernel)
11820+
11821+
ss = StoreState(self, kernel, gwvw, edge, beta, atomic, elements[edgeI], vectorDataTypes, dim=factorDim)
11822+
11823+
actPCMaxTempSgpr_ = None
1179811824
if activationLabelList and isInsertActFunctionCallAddrCalc:
1179911825
assert activationSetPCStruct, activationEnumStrList and activationLabelList and toActModuleList
11800-
numSgprs = max(actPCMaxTempSgpr, numSgprs)
11826+
actPCMaxTempSgpr_ = actPCMaxTempSgpr
11827+
11828+
numElementsPerBatch, nBatchesPerRow, numBatches, numSgprs = self.refineOccupancy(kernel, atomic, elements, actPCMaxTempSgpr_, edgeI, gwvw, maxVgprs, ss)
11829+
11830+
# set atomicW after we potentially resize GWVW
11831+
atomicW = min(gwvw, self.getVectorAtomicWidth(kernel))
11832+
11833+
if activationLabelList and isInsertActFunctionCallAddrCalc:
1180111834
edgeModule.add(self.insertActFunctionCallAddrCalc(activationSetPCStruct.sgprOffsetActivation, \
1180211835
gwvw, toActModuleList, activationEnumStrList, activationLabelList, \
1180311836
idx0, idx1))

0 commit comments

Comments
 (0)