@@ -10253,6 +10253,8 @@ def cleanupGlobalWrite(self, kernel):
10253
10253
self.vgprPool.checkIn(self.vgprs.storeRemapCoord0)
10254
10254
self.vgprPool.checkIn(self.vgprs.storeRemapCoord1)
10255
10255
self.vgprPool.checkIn(self.vgprs.storeRemapOffsetCoord1)
10256
+ for v in self.vgprs.storeRemapAS:
10257
+ self.vgprPool.checkIn(v)
10256
10258
if kernel["BufferStore"]:
10257
10259
self.vgprPool.checkIn(self.vgprs.cinRowPtr)
10258
10260
self.vgprPool.checkIn(self.vgprs.coutRowPtrD)
@@ -10374,10 +10376,7 @@ def storeRemapAddStore(self, kernel, tmpVgpr, tmpS01, edge, StoreRemapLastBatch)
10374
10376
rpe = self.states.bpeCexternal / self.states.bpr
10375
10377
rpv = rpe * gwvw
10376
10378
10377
- # num registers to check out
10378
- storeRegs = []
10379
- for i in range(0, nElements, gwvw):
10380
- storeRegs.append(self.vgprPool.checkOutAligned(int(rpv), int(rpv), "store element d"))
10379
+ storeRegs = self.vgprs.storeRemapAS
10381
10380
src = vgpr(self.vgprs.storeRemapLR)
10382
10381
for rIdx, i in enumerate(range(0, nElements, gwvw)):
10383
10382
offset = self.storeRemapLrOffset * bpe * (i//gwvw)
@@ -10482,8 +10481,6 @@ def storeRemapAddStore(self, kernel, tmpVgpr, tmpS01, edge, StoreRemapLastBatch)
10482
10481
10483
10482
module.addSpaceLine()
10484
10483
self.vgprPool.checkIn(vTmp)
10485
- for v in storeRegs:
10486
- self.vgprPool.checkIn(v)
10487
10484
10488
10485
#Data exchange between different waves
10489
10486
#Make sure LDS reads of all waves are finished
@@ -10628,6 +10625,12 @@ def storeRemapComputeStoreVgprs(self, kernel):
10628
10625
10629
10626
self.vgprPool.checkIn(tmpV0)
10630
10627
10628
+ nElements = kernel["MacroTile0"]*kernel["MatrixInstN"]//kernel["MIWaveGroup"][0]//self.states.kernel["WavefrontSize"]
10629
+ rpe = self.states.bpeCexternal / self.states.bpr
10630
+ rpv = rpe * gwvw
10631
+ self.vgprs.storeRemapAS = []
10632
+ for i in range(0, nElements, gwvw):
10633
+ self.vgprs.storeRemapAS.append(self.vgprPool.checkOutAligned(int(rpv), int(rpv), "store element d"))
10631
10634
return module
10632
10635
10633
10636
##############################################################################
@@ -11620,77 +11623,19 @@ def getMBSKGSUTotal(self, kernel):
11620
11623
GSUtotal = max(2,GSUtotal)
11621
11624
return GSUtotal
11622
11625
11623
- ##############################################################################
11624
- # globalWriteElementBatch :
11625
- ##############################################################################
11626
- def globalWriteElementBatch(self, kernel, tPA, tPB, activation, \
11627
- applyAlpha, beta, edge, atomic, \
11628
- vectorWidths, elements, activationLabelList, \
11629
- tmpVgpr, cvtVgprStruct, activationSetPCStruct, activationEnumStrList, \
11630
- actPCMaxTempSgpr, isInsertActFunctionCallAddrCalc, toActModuleList, \
11631
- edgeModule, writeLabels, endLabel, \
11632
- edge_mode_pos, currentInstLength, \
11633
- idx0, idx1, idx2, idxMN, vectorDataTypes, factorDims):
11634
- factorDim = factorDims[idx2]
11635
- edgeModule.add(writeLabels[beta][edge][factorDim][idxMN])
11636
- if idx2 == 0:
11637
- edge_mode_pos = len(edgeModule.items())
11638
-
11639
- # for storeRemap edge case, non-beta still can enable vector stores
11640
- if kernel["StoreRemapVectorWidth"] and not beta:
11641
- edgeI = False
11642
- else:
11643
- edgeI = edge
11644
- #edgeI = True # set to True to disable vector stores
11645
- gwvw = vectorWidths[edgeI]
11646
-
11647
- #print "globalWriteElements: edge=", edge, "beta=", beta, "atomic=", atomic
11648
-
11649
- ########################################
11650
- # Calculate Vgprs for Write Batching
11651
- ########################################
11652
- self.vgprPool.resetOccupancyLimit()
11653
- self.sgprPool.resetOccupancyLimit()
11654
-
11655
- # Temporarily grow pool for sgpr
11656
- sgprList = []
11657
- if kernel["_GlobalAccumulation"] == 'MultipleBufferSingleKernel':
11658
- sgprList.append(self.sgprPool.checkOut(1, preventOverflow=False))
11659
- sgprList.append(self.sgprPool.checkOut(1, preventOverflow=False))
11660
- sgprList.append(self.sgprPool.checkOut(1, preventOverflow=False))
11661
- sgprList.append(self.sgprPool.checkOutAligned(2, 2, preventOverflow=False))
11662
- sgprList.append(self.sgprPool.checkOutAligned(2, 2, preventOverflow=False))
11663
- sgprList.append(self.sgprPool.checkOutAligned(4, 4, preventOverflow=False))
11664
- for s in sgprList:
11665
- self.sgprPool.checkIn(s)
11666
- if actPCMaxTempSgpr > 0:
11667
- self.sgprPool.checkIn(self.sgprPool.checkOutAligned(actPCMaxTempSgpr, 2 if actPCMaxTempSgpr > 1 else 1, preventOverflow=False))
11668
-
11669
- tmpVgprDynamic = None
11670
- tmpVgprDynamicSize = 0
11671
- tmpVgprDynamicAlign = 0
11672
- if kernel["_GlobalAccumulation"] == 'MultipleBufferSingleKernel':
11673
- GSUTotal = self.getMBSKGSUTotal(kernel)
11674
- vgprMbsk = (GSUTotal-1) * gwvw * max(1, kernel["ProblemType"]["DestDataType"].numRegisters())
11675
- tmpVgprDynamicSize = vgprMbsk
11676
- tmpVgprDynamicAlign = 4
11677
- if tmpVgprDynamicSize > 0:
11678
- tmpVgprDynamic = ContinuousRegister(idx=self.vgprPool.checkOutAligned(tmpVgprDynamicSize, tmpVgprDynamicAlign), size=tmpVgprDynamicSize)
11679
-
11680
- ss = StoreState(self, kernel, gwvw, edge, beta, atomic, elements[edgeI], vectorDataTypes, dim=factorDim)
11681
-
11682
- def setOccupancy():
11683
- # Use VGPR up to next occupancy threshold:
11684
- maxVgprs, occupancy = self.getMaxRegsForOccupancy(kernel["NumThreads"], self.vgprPool.size(), self.sgprPool.size(), \
11685
- self.getLdsSize(kernel), self.agprPool.size(), self.states.doubleVgpr)
11686
- # Set occupancy limit for register pools
11687
- # TODO: Support gfx12
11688
- if kernel["ISA"][0] != 12:
11689
- self.vgprPool.setOccupancyLimit(self.states.regCaps["MaxVgpr"], self.states.regCaps["PhysicalMaxVgpr"] // occupancy)
11690
- self.sgprPool.setOccupancyLimit(self.states.regCaps["MaxSgpr"], self.states.regCaps["PhysicalMaxSgpr"] // occupancy)
11691
- return maxVgprs, occupancy
11692
-
11693
- maxVgprs, occupancy = setOccupancy()
11626
def setOccupancy(self, kernel):
    """Query the register budget for the current pool sizes and apply it.

    Calls ``getMaxRegsForOccupancy`` with ``kernel["NumThreads"]``, the
    current VGPR and SGPR pool sizes, the LDS size, the AGPR pool size and
    ``states.doubleVgpr``.  Unless ``kernel["ISA"][0]`` is 12, the VGPR and
    SGPR pools then receive ``setOccupancyLimit(max, physicalMax // occupancy)``
    using the matching ``states.regCaps`` entries.

    Args:
        kernel: mapping read for ``"NumThreads"`` and ``"ISA"`` and forwarded
            to ``getLdsSize``.

    Returns:
        The ``(maxVgprs, occupancy)`` pair reported by
        ``getMaxRegsForOccupancy``.
    """
    # Use VGPR up to next occupancy threshold.
    regBudget = self.getMaxRegsForOccupancy(
        kernel["NumThreads"],
        self.vgprPool.size(),
        self.sgprPool.size(),
        self.getLdsSize(kernel),
        self.agprPool.size(),
        self.states.doubleVgpr,
    )
    maxVgprs, occupancy = regBudget

    # Set occupancy limit for register pools.
    # TODO: Support gfx12
    # NOTE(review): the source view has no indentation; both limit calls are
    # assumed to sit under the ISA guard and the return outside it - confirm.
    if kernel["ISA"][0] != 12:
        caps = self.states.regCaps
        poolLimits = (
            (self.vgprPool, "MaxVgpr", "PhysicalMaxVgpr"),
            (self.sgprPool, "MaxSgpr", "PhysicalMaxSgpr"),
        )
        for pool, maxKey, physicalKey in poolLimits:
            pool.setOccupancyLimit(caps[maxKey], caps[physicalKey] // occupancy)

    return maxVgprs, occupancy
11636
+
11637
+ def refineOccupancy(self, kernel, atomic, elements, actPCMaxTempSgpr, \
11638
+ edgeI, gwvw, maxVgprs, ss):
11694
11639
# Get estimated numVgprAvailable
11695
11640
# print("Max vgprs =", maxVgprs, self.vgprPool.size(), self.vgprPool.availableBlock(ss.numVgprsPerElement, ss.align))
11696
11641
numVgprAvailable = self.vgprPool.availableBlockMaxVgpr(maxVgprs, ss.numVgprsPerElement, ss.align)
@@ -11711,12 +11656,12 @@ def setOccupancy():
11711
11656
% (minElements,ss.numVgprsPerElement))
11712
11657
self.vgprPool.growPool(0, minElements, ss.numVgprsPerElement, \
11713
11658
"grow-pool for GlobalWrite")
11714
- maxVgprs, occupancy = setOccupancy()
11659
+ maxVgprsN, occupancy = self.setOccupancy(kernel)
11660
+ if maxVgprs != maxVgprsN:
11661
+ #print("refineOccupancy maxVgprs, new", maxVgprsN, "old", maxVgprs)
11662
+ return self.refineOccupancy(kernel, atomic, elements, actPCMaxTempSgpr, edgeI, gwvw, maxVgprsN, ss)
11715
11663
numVgprAvailable = self.vgprPool.available()
11716
11664
11717
- # set atomicW after we potentially resize GWVW
11718
- atomicW = min(gwvw, self.getVectorAtomicWidth(kernel))
11719
-
11720
11665
# print("NumVgprAvailable", numVgprAvailable)
11721
11666
if ss.numVgprsPerElement:
11722
11667
numElementsPerBatch = numVgprAvailable // ss.numVgprsPerElement
@@ -11761,6 +11706,7 @@ def setOccupancy():
11761
11706
11762
11707
# check best numElementsPerBatch to handle a column block
11763
11708
# the number of elements in a column block must be a multiple of numElementsPerBatch
11709
+ nBatchesPerRow = 0
11764
11710
if kernel["StoreRemapVectorWidth"]:
11765
11711
firstRow = [e for e in elements[edgeI] if e[0]==0 and e[2]==0] # format for element = (tt1, tt0, vc1, vc0)
11766
11712
# find the largest factor and smaller than numElementPerBatch
@@ -11785,19 +11731,106 @@ def setOccupancy():
11785
11731
totalNeededVgpr = ss.numVgprsPerElement * numElementsPerBatch
11786
11732
# print("Available vgprs =", numVgprAvailable, "Needed vgprs =", totalNeededVgpr, "pool size =", self.vgprPool.size())
11787
11733
if numVgprAvailable < totalNeededVgpr:
11734
+ self.vgprPool.resetOccupancyLimit()
11788
11735
print2("info: growing pool += %d * %d for GlobalWrite\n" \
11789
11736
% (numBatches,ss.numVgprsPerElement))
11790
11737
availableBlock = min(0, self.vgprPool.available() - numVgprAvailable)
11791
11738
self.vgprPool.growPool(0, totalNeededVgpr + availableBlock, 1, "grow-pool for GlobalWrite")
11739
+ maxVgprsN, occupancy = self.setOccupancy(kernel)
11740
+ if maxVgprs != maxVgprsN:
11741
+ #print("refineOccupancy maxVgprs, new", maxVgprsN, "old", maxVgprs)
11742
+ return self.refineOccupancy(kernel, atomic, elements, actPCMaxTempSgpr, edgeI, gwvw, maxVgprsN, ss)
11743
+
11792
11744
# # Get true numVgprAvailable
11793
11745
# numVgprAvailable = self.vgprPool.availableBlock(ss.numVgprsPerElement, ss.align)
11794
11746
# print("Available vgprs =", numVgprAvailable, "pool size =", self.vgprPool.size())
11795
11747
11796
11748
numSgprs = ss.cfg.numTempSgprPerBatch + ss.cfg.numMaskSgprPerBatch + ss.cfg.numMaskSgprPerElement * numElementsPerBatch
11797
11749
11750
+ if actPCMaxTempSgpr:
11751
+ numSgprs = max(actPCMaxTempSgpr, numSgprs)
11752
+
11753
+ self.sgprPool.resetOccupancyLimit()
11754
+ self.sgprPool.checkIn(self.sgprPool.checkOutAligned(numSgprs, 2, preventOverflow=False))
11755
+ maxVgprsN, occupancy = self.setOccupancy(kernel)
11756
+ if maxVgprs != maxVgprsN:
11757
+ #print("refineOccupancy maxVgprs, new", maxVgprsN, "old", maxVgprs)
11758
+ return self.refineOccupancy(kernel, atomic, elements, actPCMaxTempSgpr, edgeI, gwvw, maxVgprsN, ss)
11759
+ return numElementsPerBatch, nBatchesPerRow, numBatches, numSgprs
11760
+
11761
+
11762
+ ##############################################################################
11763
+ # globalWriteElementBatch :
11764
+ ##############################################################################
11765
+ def globalWriteElementBatch(self, kernel, tPA, tPB, activation, \
11766
+ applyAlpha, beta, edge, atomic, \
11767
+ vectorWidths, elements, activationLabelList, \
11768
+ tmpVgpr, cvtVgprStruct, activationSetPCStruct, activationEnumStrList, \
11769
+ actPCMaxTempSgpr, isInsertActFunctionCallAddrCalc, toActModuleList, \
11770
+ edgeModule, writeLabels, endLabel, \
11771
+ edge_mode_pos, currentInstLength, \
11772
+ idx0, idx1, idx2, idxMN, vectorDataTypes, factorDims):
11773
+ factorDim = factorDims[idx2]
11774
+ edgeModule.add(writeLabels[beta][edge][factorDim][idxMN])
11775
+ if idx2 == 0:
11776
+ edge_mode_pos = len(edgeModule.items())
11777
+
11778
+ # for storeRemap edge case, non-beta still can enable vector stores
11779
+ if kernel["StoreRemapVectorWidth"] and not beta:
11780
+ edgeI = False
11781
+ else:
11782
+ edgeI = edge
11783
+ #edgeI = True # set to True to disable vector stores
11784
+ gwvw = vectorWidths[edgeI]
11785
+
11786
+ #print "globalWriteElements: edge=", edge, "beta=", beta, "atomic=", atomic
11787
+
11788
+ ########################################
11789
+ # Calculate Vgprs for Write Batching
11790
+ ########################################
11791
+ self.vgprPool.resetOccupancyLimit()
11792
+ self.sgprPool.resetOccupancyLimit()
11793
+
11794
+ # Temporarily grow pool for sgpr
11795
+ sgprList = []
11796
+ if kernel["_GlobalAccumulation"] == 'MultipleBufferSingleKernel':
11797
+ sgprList.append(self.sgprPool.checkOut(1, preventOverflow=False))
11798
+ sgprList.append(self.sgprPool.checkOut(1, preventOverflow=False))
11799
+ sgprList.append(self.sgprPool.checkOut(1, preventOverflow=False))
11800
+ sgprList.append(self.sgprPool.checkOutAligned(2, 2, preventOverflow=False))
11801
+ sgprList.append(self.sgprPool.checkOutAligned(2, 2, preventOverflow=False))
11802
+ sgprList.append(self.sgprPool.checkOutAligned(4, 4, preventOverflow=False))
11803
+ for s in sgprList:
11804
+ self.sgprPool.checkIn(s)
11805
+ if actPCMaxTempSgpr > 0:
11806
+ self.sgprPool.checkIn(self.sgprPool.checkOutAligned(actPCMaxTempSgpr, 2 if actPCMaxTempSgpr > 1 else 1, preventOverflow=False))
11807
+
11808
+ tmpVgprDynamic = None
11809
+ tmpVgprDynamicSize = 0
11810
+ tmpVgprDynamicAlign = 0
11811
+ if kernel["_GlobalAccumulation"] == 'MultipleBufferSingleKernel':
11812
+ GSUTotal = self.getMBSKGSUTotal(kernel)
11813
+ vgprMbsk = (GSUTotal-1) * gwvw * max(1, kernel["ProblemType"]["DestDataType"].numRegisters())
11814
+ tmpVgprDynamicSize = vgprMbsk
11815
+ tmpVgprDynamicAlign = 4
11816
+ if tmpVgprDynamicSize > 0:
11817
+ tmpVgprDynamic = ContinuousRegister(idx=self.vgprPool.checkOutAligned(tmpVgprDynamicSize, tmpVgprDynamicAlign), size=tmpVgprDynamicSize)
11818
+
11819
+ maxVgprs, occupancy = self.setOccupancy(kernel)
11820
+
11821
+ ss = StoreState(self, kernel, gwvw, edge, beta, atomic, elements[edgeI], vectorDataTypes, dim=factorDim)
11822
+
11823
+ actPCMaxTempSgpr_ = None
11798
11824
if activationLabelList and isInsertActFunctionCallAddrCalc:
11799
11825
assert activationSetPCStruct, activationEnumStrList and activationLabelList and toActModuleList
11800
- numSgprs = max(actPCMaxTempSgpr, numSgprs)
11826
+ actPCMaxTempSgpr_ = actPCMaxTempSgpr
11827
+
11828
+ numElementsPerBatch, nBatchesPerRow, numBatches, numSgprs = self.refineOccupancy(kernel, atomic, elements, actPCMaxTempSgpr_, edgeI, gwvw, maxVgprs, ss)
11829
+
11830
+ # set atomicW after we potentially resize GWVW
11831
+ atomicW = min(gwvw, self.getVectorAtomicWidth(kernel))
11832
+
11833
+ if activationLabelList and isInsertActFunctionCallAddrCalc:
11801
11834
edgeModule.add(self.insertActFunctionCallAddrCalc(activationSetPCStruct.sgprOffsetActivation, \
11802
11835
gwvw, toActModuleList, activationEnumStrList, activationLabelList, \
11803
11836
idx0, idx1))
0 commit comments