@@ -10253,6 +10253,8 @@ def cleanupGlobalWrite(self, kernel):
       self.vgprPool.checkIn(self.vgprs.storeRemapCoord0)
       self.vgprPool.checkIn(self.vgprs.storeRemapCoord1)
       self.vgprPool.checkIn(self.vgprs.storeRemapOffsetCoord1)
+      for v in self.vgprs.storeRemapAS:
+        self.vgprPool.checkIn(v)
     if kernel["BufferStore"]:
       self.vgprPool.checkIn(self.vgprs.cinRowPtr)
       self.vgprPool.checkIn(self.vgprs.coutRowPtrD)
@@ -10374,10 +10376,7 @@ def storeRemapAddStore(self, kernel, tmpVgpr, tmpS01, edge, StoreRemapLastBatch)
     rpe = self.states.bpeCexternal / self.states.bpr
     rpv = rpe * gwvw
 
-    # num registers to check out
-    storeRegs = []
-    for i in range(0, nElements, gwvw):
-      storeRegs.append(self.vgprPool.checkOutAligned(int(rpv), int(rpv), "store element d"))
+    storeRegs = self.vgprs.storeRemapAS
     src = vgpr(self.vgprs.storeRemapLR)
     for rIdx, i in enumerate(range(0, nElements, gwvw)):
       offset = self.storeRemapLrOffset * bpe * (i//gwvw)
@@ -10482,8 +10481,6 @@ def storeRemapAddStore(self, kernel, tmpVgpr, tmpS01, edge, StoreRemapLastBatch)
 
     module.addSpaceLine()
     self.vgprPool.checkIn(vTmp)
-    for v in storeRegs:
-      self.vgprPool.checkIn(v)
 
     #Data exchange between different waves
     #Make sure LDS reads are finished of all waves
@@ -10628,6 +10625,12 @@ def storeRemapComputeStoreVgprs(self, kernel):
 
     self.vgprPool.checkIn(tmpV0)
 
+    nElements = kernel["MacroTile0"]*kernel["MatrixInstN"]//kernel["MIWaveGroup"][0]//self.states.kernel["WavefrontSize"]
+    rpe = self.states.bpeCexternal / self.states.bpr
+    rpv = rpe * gwvw
+    self.vgprs.storeRemapAS = []
+    for i in range(0, nElements, gwvw):
+      self.vgprs.storeRemapAS.append(self.vgprPool.checkOutAligned(int(rpv), int(rpv), "store element d"))
     return module
 
   ##############################################################################
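Note: the four hunks above change the lifetime of the store-remap staging registers. Instead of being checked out and back in on every storeRemapAddStore call, they are now checked out once in storeRemapComputeStoreVgprs (as self.vgprs.storeRemapAS) and released in cleanupGlobalWrite. The minimal standalone sketch below illustrates that lifetime change; ToyRegisterPool and the example sizes are illustrative assumptions, not Tensile's RegisterPool API.

# Toy sketch of the register-lifetime change above; ToyRegisterPool stands in
# for Tensile's vgprPool, and all sizes are example values.
class ToyRegisterPool:
  def __init__(self, size):
    self._free = set(range(size))
    self._blocks = {}  # block start index -> registers in the block

  def checkOutAligned(self, count, align, tag=""):
    # Lowest aligned start index with 'count' consecutive free registers.
    start = 0
    while self._free and start <= max(self._free):
      block = list(range(start, start + count))
      if all(r in self._free for r in block):
        self._free.difference_update(block)
        self._blocks[start] = block
        return start
      start += align
    raise RuntimeError("pool exhausted: " + tag)

  def checkIn(self, start):
    # Return the whole block that was checked out at 'start'.
    self._free.update(self._blocks.pop(start))

pool = ToyRegisterPool(64)
gwvw, rpv, nElements = 4, 4, 16  # example values only

# After the change: allocate the staging blocks once (storeRemapComputeStoreVgprs)...
storeRemapAS = [pool.checkOutAligned(int(rpv), int(rpv), "store element d")
                for _ in range(0, nElements, gwvw)]

# ...reuse them on every store batch (storeRemapAddStore now reads storeRegs = storeRemapAS)...

# ...and release them once at the end (cleanupGlobalWrite).
for v in storeRemapAS:
  pool.checkIn(v)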
@@ -11620,77 +11623,19 @@ def getMBSKGSUTotal(self, kernel):
     GSUtotal = max(2,GSUtotal)
     return GSUtotal
 
-  ##############################################################################
-  # globalWriteElementBatch :
-  ##############################################################################
-  def globalWriteElementBatch(self, kernel, tPA, tPB, activation, \
-      applyAlpha, beta, edge, atomic, \
-      vectorWidths, elements, activationLabelList, \
-      tmpVgpr, cvtVgprStruct, activationSetPCStruct, activationEnumStrList, \
-      actPCMaxTempSgpr, isInsertActFunctionCallAddrCalc, toActModuleList, \
-      edgeModule, writeLabels, endLabel, \
-      edge_mode_pos, currentInstLength, \
-      idx0, idx1, idx2, idxMN, vectorDataTypes, factorDims):
-    factorDim = factorDims[idx2]
-    edgeModule.add(writeLabels[beta][edge][factorDim][idxMN])
-    if idx2 == 0:
-      edge_mode_pos = len(edgeModule.items())
-
-    # for storeRemap edge case, non-beta still can enable vector stores
-    if kernel["StoreRemapVectorWidth"] and not beta:
-      edgeI = False
-    else:
-      edgeI = edge
-    #edgeI = True # set to True to disable vector stores
-    gwvw = vectorWidths[edgeI]
-
-    #print "globalWriteElements: edge=", edge, "beta=", beta, "atomic=", atomic
-
-    ########################################
-    # Calculate Vgprs for Write Batching
-    ########################################
-    self.vgprPool.resetOccupancyLimit()
-    self.sgprPool.resetOccupancyLimit()
-
-    # Temporarily grow pool for sgpr
-    sgprList = []
-    if kernel["_GlobalAccumulation"] == 'MultipleBufferSingleKernel':
-      sgprList.append(self.sgprPool.checkOut(1, preventOverflow=False))
-      sgprList.append(self.sgprPool.checkOut(1, preventOverflow=False))
-      sgprList.append(self.sgprPool.checkOut(1, preventOverflow=False))
-      sgprList.append(self.sgprPool.checkOutAligned(2, 2, preventOverflow=False))
-      sgprList.append(self.sgprPool.checkOutAligned(2, 2, preventOverflow=False))
-      sgprList.append(self.sgprPool.checkOutAligned(4, 4, preventOverflow=False))
-      for s in sgprList:
-        self.sgprPool.checkIn(s)
-    if actPCMaxTempSgpr > 0:
-      self.sgprPool.checkIn(self.sgprPool.checkOutAligned(actPCMaxTempSgpr, 2 if actPCMaxTempSgpr > 1 else 1, preventOverflow=False))
-
-    tmpVgprDynamic = None
-    tmpVgprDynamicSize = 0
-    tmpVgprDynamicAlign = 0
-    if kernel["_GlobalAccumulation"] == 'MultipleBufferSingleKernel':
-      GSUTotal = self.getMBSKGSUTotal(kernel)
-      vgprMbsk = (GSUTotal-1) * gwvw * max(1, kernel["ProblemType"]["DestDataType"].numRegisters())
-      tmpVgprDynamicSize = vgprMbsk
-      tmpVgprDynamicAlign = 4
-    if tmpVgprDynamicSize > 0:
-      tmpVgprDynamic = ContinuousRegister(idx=self.vgprPool.checkOutAligned(tmpVgprDynamicSize, tmpVgprDynamicAlign), size=tmpVgprDynamicSize)
-
-    ss = StoreState(self, kernel, gwvw, edge, beta, atomic, elements[edgeI], vectorDataTypes, dim=factorDim)
-
-    def setOccupancy():
-      # Use VGPR up to next occupancy threshold:
-      maxVgprs, occupancy = self.getMaxRegsForOccupancy(kernel["NumThreads"], self.vgprPool.size(), self.sgprPool.size(), \
-        self.getLdsSize(kernel), self.agprPool.size(), self.states.doubleVgpr)
-      # Set occupancy limit for register pools
-      # TODO: Support gfx12
-      if kernel["ISA"][0] != 12:
-        self.vgprPool.setOccupancyLimit(self.states.regCaps["MaxVgpr"], self.states.regCaps["PhysicalMaxVgpr"] // occupancy)
-        self.sgprPool.setOccupancyLimit(self.states.regCaps["MaxSgpr"], self.states.regCaps["PhysicalMaxSgpr"] // occupancy)
-      return maxVgprs, occupancy
-
-    maxVgprs, occupancy = setOccupancy()
+  def setOccupancy(self, kernel):
+    # Use VGPR up to next occupancy threshold:
+    maxVgprs, occupancy = self.getMaxRegsForOccupancy(kernel["NumThreads"], self.vgprPool.size(), self.sgprPool.size(), \
+      self.getLdsSize(kernel), self.agprPool.size(), self.states.doubleVgpr)
+    # Set occupancy limit for register pools
+    # TODO: Support gfx12
+    if kernel["ISA"][0] != 12:
+      self.vgprPool.setOccupancyLimit(self.states.regCaps["MaxVgpr"], self.states.regCaps["PhysicalMaxVgpr"] // occupancy)
+      self.sgprPool.setOccupancyLimit(self.states.regCaps["MaxSgpr"], self.states.regCaps["PhysicalMaxSgpr"] // occupancy)
+    return maxVgprs, occupancy
+
+  def refineOccupancy(self, kernel, atomic, elements, actPCMaxTempSgpr, \
+      edgeI, gwvw, maxVgprs, ss):
     # Get estimated numVgprAvailable
     # print("Max vgprs =", maxVgprs, self.vgprPool.size(), self.vgprPool.availableBlock(ss.numVgprsPerElement, ss.align))
     numVgprAvailable = self.vgprPool.availableBlockMaxVgpr(maxVgprs, ss.numVgprsPerElement, ss.align)
@@ -11711,12 +11656,12 @@ def setOccupancy():
          % (minElements,ss.numVgprsPerElement))
       self.vgprPool.growPool(0, minElements, ss.numVgprsPerElement, \
          "grow-pool for GlobalWrite")
-      maxVgprs, occupancy = setOccupancy()
+      maxVgprsN, occupancy = self.setOccupancy(kernel)
+      if maxVgprs != maxVgprsN:
+        #print("refineOccupancy maxVgprs, new", maxVgprsN, "old", maxVgprs)
+        return self.refineOccupancy(kernel, atomic, elements, actPCMaxTempSgpr, edgeI, gwvw, maxVgprsN, ss)
     numVgprAvailable = self.vgprPool.available()
 
-    # set atomicW after we potentially resize GWVW
-    atomicW = min(gwvw, self.getVectorAtomicWidth(kernel))
-
     # print("NumVgprAvailable", numVgprAvailable)
     if ss.numVgprsPerElement:
       numElementsPerBatch = numVgprAvailable // ss.numVgprsPerElement
@@ -11761,6 +11706,7 @@ def setOccupancy():
 
     # check best numElementsPerBatch to handle a column block
    # elements of column block must be multiple size of numElementsPerBatch
+    nBatchesPerRow = 0
     if kernel["StoreRemapVectorWidth"]:
       firstRow = [e for e in elements[edgeI] if e[0]==0 and e[2]==0] # format for element = (tt1, tt0, vc1, vc0)
       # find the largest factor and smaller than numElementPerBatch
@@ -11785,19 +11731,106 @@ def setOccupancy():
     totalNeededVgpr = ss.numVgprsPerElement * numElementsPerBatch
     # print("Available vgprs =", numVgprAvailable, "Needed vgprs =", totalNeededVgpr, "pool size =", self.vgprPool.size())
     if numVgprAvailable < totalNeededVgpr:
+      self.vgprPool.resetOccupancyLimit()
       print2("info: growing pool += %d * %d for GlobalWrite\n" \
          % (numBatches,ss.numVgprsPerElement))
       availableBlock = min(0, self.vgprPool.available() - numVgprAvailable)
       self.vgprPool.growPool(0, totalNeededVgpr + availableBlock, 1, "grow-pool for GlobalWrite")
+      maxVgprsN, occupancy = self.setOccupancy(kernel)
+      if maxVgprs != maxVgprsN:
+        #print("refineOccupancy maxVgprs, new", maxVgprsN, "old", maxVgprs)
+        return self.refineOccupancy(kernel, atomic, elements, actPCMaxTempSgpr, edgeI, gwvw, maxVgprsN, ss)
+
     # # Get true numVgprAvailable
     # numVgprAvailable = self.vgprPool.availableBlock(ss.numVgprsPerElement, ss.align)
     # print("Available vgprs =", numVgprAvailable, "pool size =", self.vgprPool.size())
 
     numSgprs = ss.cfg.numTempSgprPerBatch + ss.cfg.numMaskSgprPerBatch + ss.cfg.numMaskSgprPerElement * numElementsPerBatch
 
+    if actPCMaxTempSgpr:
+      numSgprs = max(actPCMaxTempSgpr, numSgprs)
+
+    self.sgprPool.resetOccupancyLimit()
+    self.sgprPool.checkIn(self.sgprPool.checkOutAligned(numSgprs, 2, preventOverflow=False))
+    maxVgprsN, occupancy = self.setOccupancy(kernel)
+    if maxVgprs != maxVgprsN:
+      #print("refineOccupancy maxVgprs, new", maxVgprsN, "old", maxVgprs)
+      return self.refineOccupancy(kernel, atomic, elements, actPCMaxTempSgpr, edgeI, gwvw, maxVgprsN, ss)
+    return numElementsPerBatch, nBatchesPerRow, numBatches, numSgprs
+
+
+  ##############################################################################
+  # globalWriteElementBatch :
+  ##############################################################################
+  def globalWriteElementBatch(self, kernel, tPA, tPB, activation, \
+      applyAlpha, beta, edge, atomic, \
+      vectorWidths, elements, activationLabelList, \
+      tmpVgpr, cvtVgprStruct, activationSetPCStruct, activationEnumStrList, \
+      actPCMaxTempSgpr, isInsertActFunctionCallAddrCalc, toActModuleList, \
+      edgeModule, writeLabels, endLabel, \
+      edge_mode_pos, currentInstLength, \
+      idx0, idx1, idx2, idxMN, vectorDataTypes, factorDims):
+    factorDim = factorDims[idx2]
+    edgeModule.add(writeLabels[beta][edge][factorDim][idxMN])
+    if idx2 == 0:
+      edge_mode_pos = len(edgeModule.items())
+
+    # for storeRemap edge case, non-beta still can enable vector stores
+    if kernel["StoreRemapVectorWidth"] and not beta:
+      edgeI = False
+    else:
+      edgeI = edge
+    #edgeI = True # set to True to disable vector stores
+    gwvw = vectorWidths[edgeI]
+
+    #print "globalWriteElements: edge=", edge, "beta=", beta, "atomic=", atomic
+
+    ########################################
+    # Calculate Vgprs for Write Batching
+    ########################################
+    self.vgprPool.resetOccupancyLimit()
+    self.sgprPool.resetOccupancyLimit()
+
+    # Temporarily grow pool for sgpr
+    sgprList = []
+    if kernel["_GlobalAccumulation"] == 'MultipleBufferSingleKernel':
+      sgprList.append(self.sgprPool.checkOut(1, preventOverflow=False))
+      sgprList.append(self.sgprPool.checkOut(1, preventOverflow=False))
+      sgprList.append(self.sgprPool.checkOut(1, preventOverflow=False))
+      sgprList.append(self.sgprPool.checkOutAligned(2, 2, preventOverflow=False))
+      sgprList.append(self.sgprPool.checkOutAligned(2, 2, preventOverflow=False))
+      sgprList.append(self.sgprPool.checkOutAligned(4, 4, preventOverflow=False))
+      for s in sgprList:
+        self.sgprPool.checkIn(s)
+    if actPCMaxTempSgpr > 0:
+      self.sgprPool.checkIn(self.sgprPool.checkOutAligned(actPCMaxTempSgpr, 2 if actPCMaxTempSgpr > 1 else 1, preventOverflow=False))
+
+    tmpVgprDynamic = None
+    tmpVgprDynamicSize = 0
+    tmpVgprDynamicAlign = 0
+    if kernel["_GlobalAccumulation"] == 'MultipleBufferSingleKernel':
+      GSUTotal = self.getMBSKGSUTotal(kernel)
+      vgprMbsk = (GSUTotal-1) * gwvw * max(1, kernel["ProblemType"]["DestDataType"].numRegisters())
+      tmpVgprDynamicSize = vgprMbsk
+      tmpVgprDynamicAlign = 4
+    if tmpVgprDynamicSize > 0:
+      tmpVgprDynamic = ContinuousRegister(idx=self.vgprPool.checkOutAligned(tmpVgprDynamicSize, tmpVgprDynamicAlign), size=tmpVgprDynamicSize)
+
+    maxVgprs, occupancy = self.setOccupancy(kernel)
+
+    ss = StoreState(self, kernel, gwvw, edge, beta, atomic, elements[edgeI], vectorDataTypes, dim=factorDim)
+
+    actPCMaxTempSgpr_ = None
     if activationLabelList and isInsertActFunctionCallAddrCalc:
       assert activationSetPCStruct, activationEnumStrList and activationLabelList and toActModuleList
-      numSgprs = max(actPCMaxTempSgpr, numSgprs)
+      actPCMaxTempSgpr_ = actPCMaxTempSgpr
+
+    numElementsPerBatch, nBatchesPerRow, numBatches, numSgprs = self.refineOccupancy(kernel, atomic, elements, actPCMaxTempSgpr_, edgeI, gwvw, maxVgprs, ss)
+
+    # set atomicW after we potentially resize GWVW
+    atomicW = min(gwvw, self.getVectorAtomicWidth(kernel))
+
+    if activationLabelList and isInsertActFunctionCallAddrCalc:
       edgeModule.add(self.insertActFunctionCallAddrCalc(activationSetPCStruct.sgprOffsetActivation, \
          gwvw, toActModuleList, activationEnumStrList, activationLabelList, \
          idx0, idx1))
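Note: the hunks above factor the occupancy bookkeeping out of globalWriteElementBatch into two methods. setOccupancy derives the per-wave VGPR cap from the current pool sizes, and refineOccupancy grows the VGPR/SGPR pools for write batching, recomputing the cap after each growth and recursing whenever the cap changes, until it reaches a fixed point. The standalone sketch below illustrates that fixed-point loop; the occupancy model, register-file size, and helper names are illustrative assumptions, not Tensile's getMaxRegsForOccupancy.

# Toy sketch of the setOccupancy/refineOccupancy fixed point; the occupancy
# model and all numbers are assumptions chosen for the example.
def occupancy_for(vgprs_used):
  # Waves per SIMD limited by a 256-register file (toy model).
  return max(1, min(8, 256 // max(1, vgprs_used)))

def set_occupancy(pool_size):
  occ = occupancy_for(pool_size)
  max_vgprs = 256 // occ  # per-wave cap implied by that occupancy
  return max_vgprs, occ

def refine_occupancy(pool_size, vgprs_per_element, min_elements, max_vgprs):
  available = max_vgprs - pool_size
  needed = min_elements * vgprs_per_element
  if available < needed:
    # Grow the pool, re-derive the cap, and recurse if the cap moved --
    # this mirrors the "maxVgprs != maxVgprsN" checks in refineOccupancy.
    pool_size += needed - available
    new_max, _occ = set_occupancy(pool_size)
    if new_max != max_vgprs:
      return refine_occupancy(pool_size, vgprs_per_element, min_elements, new_max)
  return pool_size, max_vgprs

max_vgprs, occ = set_occupancy(120)
print(refine_occupancy(120, vgprs_per_element=6, min_elements=4, max_vgprs=max_vgprs))
# -> (136, 256): in this toy model, growing past 128 VGPRs drops occupancy to 1,
#    which raises the per-wave cap, so the routine recurses once and then settles.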