@@ -10218,6 +10218,8 @@ def cleanupGlobalWrite(self, kernel):
10218
10218
self.vgprPool.checkIn(self.vgprs.storeRemapCoord0)
10219
10219
self.vgprPool.checkIn(self.vgprs.storeRemapCoord1)
10220
10220
self.vgprPool.checkIn(self.vgprs.storeRemapOffsetCoord1)
10221
+ for v in self.vgprs.storeRemapAS:
10222
+ self.vgprPool.checkIn(v)
10221
10223
if kernel["BufferStore"]:
10222
10224
self.vgprPool.checkIn(self.vgprs.cinRowPtr)
10223
10225
self.vgprPool.checkIn(self.vgprs.coutRowPtrD)
@@ -10339,10 +10341,7 @@ def storeRemapAddStore(self, kernel, tmpVgpr, tmpS01, edge, StoreRemapLastBatch)
10339
10341
rpe = self.states.bpeCexternal / self.states.bpr
10340
10342
rpv = rpe * gwvw
10341
10343
10342
- # num registers to check out
10343
- storeRegs = []
10344
- for i in range(0, nElements, gwvw):
10345
- storeRegs.append(self.vgprPool.checkOutAligned(int(rpv), int(rpv), "store element d"))
10344
+ storeRegs = self.vgprs.storeRemapAS
10346
10345
src = vgpr(self.vgprs.storeRemapLR)
10347
10346
for rIdx, i in enumerate(range(0, nElements, gwvw)):
10348
10347
offset = self.storeRemapLrOffset * bpe * (i//gwvw)
@@ -10447,8 +10446,6 @@ def storeRemapAddStore(self, kernel, tmpVgpr, tmpS01, edge, StoreRemapLastBatch)
10447
10446
10448
10447
module.addSpaceLine()
10449
10448
self.vgprPool.checkIn(vTmp)
10450
- for v in storeRegs:
10451
- self.vgprPool.checkIn(v)
10452
10449
10453
10450
#Data exchange between different waves
10454
10451
#Make sure LDS reads are finished of all waves
@@ -10593,6 +10590,12 @@ def storeRemapComputeStoreVgprs(self, kernel):
10593
10590
10594
10591
self.vgprPool.checkIn(tmpV0)
10595
10592
10593
+ nElements = kernel["MacroTile0"]*kernel["MatrixInstN"]//kernel["MIWaveGroup"][0]//self.states.kernel["WavefrontSize"]
10594
+ rpe = self.states.bpeCexternal / self.states.bpr
10595
+ rpv = rpe * gwvw
10596
+ self.vgprs.storeRemapAS = []
10597
+ for i in range(0, nElements, gwvw):
10598
+ self.vgprs.storeRemapAS.append(self.vgprPool.checkOutAligned(int(rpv), int(rpv), "store element d"))
10596
10599
return module
10597
10600
10598
10601
##############################################################################
@@ -11585,77 +11588,19 @@ def getMBSKGSUTotal(self, kernel):
11585
11588
GSUtotal = max(2,GSUtotal)
11586
11589
return GSUtotal
11587
11590
11588
- ##############################################################################
11589
- # globalWriteElementBatch :
11590
- ##############################################################################
11591
- def globalWriteElementBatch(self, kernel, tPA, tPB, activation, \
11592
- applyAlpha, beta, edge, atomic, \
11593
- vectorWidths, elements, activationLabelList, \
11594
- tmpVgpr, cvtVgprStruct, activationSetPCStruct, activationEnumStrList, \
11595
- actPCMaxTempSgpr, isInsertActFunctionCallAddrCalc, toActModuleList, \
11596
- edgeModule, writeLabels, endLabel, \
11597
- edge_mode_pos, currentInstLength, \
11598
- idx0, idx1, idx2, idxMN, vectorDataTypes, factorDims):
11599
- factorDim = factorDims[idx2]
11600
- edgeModule.add(writeLabels[beta][edge][factorDim][idxMN])
11601
- if idx2 == 0:
11602
- edge_mode_pos = len(edgeModule.items())
11603
-
11604
- # for storeRemap edge case, non-beta still can enable vector stores
11605
- if kernel["StoreRemapVectorWidth"] and not beta:
11606
- edgeI = False
11607
- else:
11608
- edgeI = edge
11609
- #edgeI = True # set to True to disable vector stores
11610
- gwvw = vectorWidths[edgeI]
11611
-
11612
- #print "globalWriteElements: edge=", edge, "beta=", beta, "atomic=", atomic
11613
-
11614
- ########################################
11615
- # Calculate Vgprs for Write Batching
11616
- ########################################
11617
- self.vgprPool.resetOccupancyLimit()
11618
- self.sgprPool.resetOccupancyLimit()
11619
-
11620
- # Temporarily grow pool for sgpr
11621
- sgprList = []
11622
- if kernel["_GlobalAccumulation"] == 'MultipleBufferSingleKernel':
11623
- sgprList.append(self.sgprPool.checkOut(1, preventOverflow=False))
11624
- sgprList.append(self.sgprPool.checkOut(1, preventOverflow=False))
11625
- sgprList.append(self.sgprPool.checkOut(1, preventOverflow=False))
11626
- sgprList.append(self.sgprPool.checkOutAligned(2, 2, preventOverflow=False))
11627
- sgprList.append(self.sgprPool.checkOutAligned(2, 2, preventOverflow=False))
11628
- sgprList.append(self.sgprPool.checkOutAligned(4, 4, preventOverflow=False))
11629
- for s in sgprList:
11630
- self.sgprPool.checkIn(s)
11631
- if actPCMaxTempSgpr > 0:
11632
- self.sgprPool.checkIn(self.sgprPool.checkOutAligned(actPCMaxTempSgpr, 2 if actPCMaxTempSgpr > 1 else 1, preventOverflow=False))
11633
-
11634
- tmpVgprDynamic = None
11635
- tmpVgprDynamicSize = 0
11636
- tmpVgprDynamicAlign = 0
11637
- if kernel["_GlobalAccumulation"] == 'MultipleBufferSingleKernel':
11638
- GSUTotal = self.getMBSKGSUTotal(kernel)
11639
- vgprMbsk = (GSUTotal-1) * gwvw * max(1, kernel["ProblemType"]["DestDataType"].numRegisters())
11640
- tmpVgprDynamicSize = vgprMbsk
11641
- tmpVgprDynamicAlign = 4
11642
- if tmpVgprDynamicSize > 0:
11643
- tmpVgprDynamic = ContinuousRegister(idx=self.vgprPool.checkOutAligned(tmpVgprDynamicSize, tmpVgprDynamicAlign), size=tmpVgprDynamicSize)
11644
-
11645
- ss = StoreState(self, kernel, gwvw, edge, beta, atomic, elements[edgeI], vectorDataTypes, dim=factorDim)
11646
-
11647
- def setOccupancy():
11648
- # Use VGPR up to next occupancy threshold:
11649
- maxVgprs, occupancy = self.getMaxRegsForOccupancy(kernel["NumThreads"], self.vgprPool.size(), self.sgprPool.size(), \
11650
- self.getLdsSize(kernel), self.agprPool.size(), self.states.doubleVgpr)
11651
- # Set occupancy limit for register pools
11652
- # TODO: Support gfx12
11653
- if kernel["ISA"][0] != 12:
11654
- self.vgprPool.setOccupancyLimit(self.states.regCaps["MaxVgpr"], self.states.regCaps["PhysicalMaxVgpr"] // occupancy)
11655
- self.sgprPool.setOccupancyLimit(self.states.regCaps["MaxSgpr"], self.states.regCaps["PhysicalMaxSgpr"] // occupancy)
11656
- return maxVgprs, occupancy
11657
-
11658
- maxVgprs, occupancy = setOccupancy()
11591
def setOccupancy(self, kernel):
    """Re-evaluate occupancy from the current pool sizes and cap the pools.

    Asks ``getMaxRegsForOccupancy`` for the VGPR budget and occupancy that
    correspond to the kernel's thread count, the current VGPR/SGPR/AGPR pool
    sizes, the LDS size reported by ``getLdsSize`` and the ``doubleVgpr``
    state flag. Unless the major ISA number is 12, the VGPR and SGPR pools
    are then given an occupancy limit derived from the register caps.

    Args:
        kernel: kernel configuration mapping; reads ``"NumThreads"`` and
            ``"ISA"`` here and is forwarded to ``getLdsSize``.

    Returns:
        The ``(maxVgprs, occupancy)`` pair produced by
        ``getMaxRegsForOccupancy``.
    """
    # Use VGPR up to next occupancy threshold:
    numThreads = kernel["NumThreads"]
    vgprCount = self.vgprPool.size()
    sgprCount = self.sgprPool.size()
    ldsBytes = self.getLdsSize(kernel)
    agprCount = self.agprPool.size()
    maxVgprs, occupancy = self.getMaxRegsForOccupancy(
        numThreads, vgprCount, sgprCount, ldsBytes, agprCount,
        self.states.doubleVgpr)

    # Set occupancy limit for register pools
    # TODO: Support gfx12
    isaMajor = kernel["ISA"][0]
    if isaMajor == 12:
        return maxVgprs, occupancy

    vgprHardCap = self.states.regCaps["MaxVgpr"]
    vgprShare = self.states.regCaps["PhysicalMaxVgpr"] // occupancy
    self.vgprPool.setOccupancyLimit(vgprHardCap, vgprShare)

    sgprHardCap = self.states.regCaps["MaxSgpr"]
    sgprShare = self.states.regCaps["PhysicalMaxSgpr"] // occupancy
    self.sgprPool.setOccupancyLimit(sgprHardCap, sgprShare)
    return maxVgprs, occupancy
11601
+
11602
+ def refineOccupancy(self, kernel, atomic, elements, actPCMaxTempSgpr, \
11603
+ edgeI, gwvw, maxVgprs, ss):
11659
11604
# Get estimated numVgprAvailable
11660
11605
# print("Max vgprs =", maxVgprs, self.vgprPool.size(), self.vgprPool.availableBlock(ss.numVgprsPerElement, ss.align))
11661
11606
numVgprAvailable = self.vgprPool.availableBlockMaxVgpr(maxVgprs, ss.numVgprsPerElement, ss.align)
@@ -11676,12 +11621,12 @@ def setOccupancy():
11676
11621
% (minElements,ss.numVgprsPerElement))
11677
11622
self.vgprPool.growPool(0, minElements, ss.numVgprsPerElement, \
11678
11623
"grow-pool for GlobalWrite")
11679
- maxVgprs, occupancy = setOccupancy()
11624
+ maxVgprsN, occupancy = self.setOccupancy(kernel)
11625
+ if maxVgprs != maxVgprsN:
11626
+ #print("refineOccupancy maxVgprs, new", maxVgprsN, "old", maxVgprs)
11627
+ return self.refineOccupancy(kernel, atomic, elements, actPCMaxTempSgpr, edgeI, gwvw, maxVgprsN, ss)
11680
11628
numVgprAvailable = self.vgprPool.available()
11681
11629
11682
- # set atomicW after we potentially resize GWVW
11683
- atomicW = min(gwvw, self.getVectorAtomicWidth(kernel))
11684
-
11685
11630
# print("NumVgprAvailable", numVgprAvailable)
11686
11631
if ss.numVgprsPerElement:
11687
11632
numElementsPerBatch = numVgprAvailable // ss.numVgprsPerElement
@@ -11726,6 +11671,7 @@ def setOccupancy():
11726
11671
11727
11672
# check best numElementsPerBatch to handle a column block
11728
11673
# elements of column block must be multiple size of numElementsPerBatch
11674
+ nBatchesPerRow = 0
11729
11675
if kernel["StoreRemapVectorWidth"]:
11730
11676
firstRow = [e for e in elements[edgeI] if e[0]==0 and e[2]==0] # format for element = (tt1, tt0, vc1, vc0)
11731
11677
# find the largest factor and smaller than numElementPerBatch
@@ -11750,19 +11696,106 @@ def setOccupancy():
11750
11696
totalNeededVgpr = ss.numVgprsPerElement * numElementsPerBatch
11751
11697
# print("Available vgprs =", numVgprAvailable, "Needed vgprs =", totalNeededVgpr, "pool size =", self.vgprPool.size())
11752
11698
if numVgprAvailable < totalNeededVgpr:
11699
+ self.vgprPool.resetOccupancyLimit()
11753
11700
print2("info: growing pool += %d * %d for GlobalWrite\n" \
11754
11701
% (numBatches,ss.numVgprsPerElement))
11755
11702
availableBlock = min(0, self.vgprPool.available() - numVgprAvailable)
11756
11703
self.vgprPool.growPool(0, totalNeededVgpr + availableBlock, 1, "grow-pool for GlobalWrite")
11704
+ maxVgprsN, occupancy = self.setOccupancy(kernel)
11705
+ if maxVgprs != maxVgprsN:
11706
+ #print("refineOccupancy maxVgprs, new", maxVgprsN, "old", maxVgprs)
11707
+ return self.refineOccupancy(kernel, atomic, elements, actPCMaxTempSgpr, edgeI, gwvw, maxVgprsN, ss)
11708
+
11757
11709
# # Get true numVgprAvailable
11758
11710
# numVgprAvailable = self.vgprPool.availableBlock(ss.numVgprsPerElement, ss.align)
11759
11711
# print("Available vgprs =", numVgprAvailable, "pool size =", self.vgprPool.size())
11760
11712
11761
11713
numSgprs = ss.cfg.numTempSgprPerBatch + ss.cfg.numMaskSgprPerBatch + ss.cfg.numMaskSgprPerElement * numElementsPerBatch
11762
11714
11715
+ if actPCMaxTempSgpr:
11716
+ numSgprs = max(actPCMaxTempSgpr, numSgprs)
11717
+
11718
+ self.sgprPool.resetOccupancyLimit()
11719
+ self.sgprPool.checkIn(self.sgprPool.checkOutAligned(numSgprs, 2, preventOverflow=False))
11720
+ maxVgprsN, occupancy = self.setOccupancy(kernel)
11721
+ if maxVgprs != maxVgprsN:
11722
+ #print("refineOccupancy maxVgprs, new", maxVgprsN, "old", maxVgprs)
11723
+ return self.refineOccupancy(kernel, atomic, elements, actPCMaxTempSgpr, edgeI, gwvw, maxVgprsN, ss)
11724
+ return numElementsPerBatch, nBatchesPerRow, numBatches, numSgprs
11725
+
11726
+
11727
+ ##############################################################################
11728
+ # globalWriteElementBatch :
11729
+ ##############################################################################
11730
+ def globalWriteElementBatch(self, kernel, tPA, tPB, activation, \
11731
+ applyAlpha, beta, edge, atomic, \
11732
+ vectorWidths, elements, activationLabelList, \
11733
+ tmpVgpr, cvtVgprStruct, activationSetPCStruct, activationEnumStrList, \
11734
+ actPCMaxTempSgpr, isInsertActFunctionCallAddrCalc, toActModuleList, \
11735
+ edgeModule, writeLabels, endLabel, \
11736
+ edge_mode_pos, currentInstLength, \
11737
+ idx0, idx1, idx2, idxMN, vectorDataTypes, factorDims):
11738
+ factorDim = factorDims[idx2]
11739
+ edgeModule.add(writeLabels[beta][edge][factorDim][idxMN])
11740
+ if idx2 == 0:
11741
+ edge_mode_pos = len(edgeModule.items())
11742
+
11743
+ # for storeRemap edge case, non-beta still can enable vector stores
11744
+ if kernel["StoreRemapVectorWidth"] and not beta:
11745
+ edgeI = False
11746
+ else:
11747
+ edgeI = edge
11748
+ #edgeI = True # set to True to disable vector stores
11749
+ gwvw = vectorWidths[edgeI]
11750
+
11751
+ #print "globalWriteElements: edge=", edge, "beta=", beta, "atomic=", atomic
11752
+
11753
+ ########################################
11754
+ # Calculate Vgprs for Write Batching
11755
+ ########################################
11756
+ self.vgprPool.resetOccupancyLimit()
11757
+ self.sgprPool.resetOccupancyLimit()
11758
+
11759
+ # Temporarily grow pool for sgpr
11760
+ sgprList = []
11761
+ if kernel["_GlobalAccumulation"] == 'MultipleBufferSingleKernel':
11762
+ sgprList.append(self.sgprPool.checkOut(1, preventOverflow=False))
11763
+ sgprList.append(self.sgprPool.checkOut(1, preventOverflow=False))
11764
+ sgprList.append(self.sgprPool.checkOut(1, preventOverflow=False))
11765
+ sgprList.append(self.sgprPool.checkOutAligned(2, 2, preventOverflow=False))
11766
+ sgprList.append(self.sgprPool.checkOutAligned(2, 2, preventOverflow=False))
11767
+ sgprList.append(self.sgprPool.checkOutAligned(4, 4, preventOverflow=False))
11768
+ for s in sgprList:
11769
+ self.sgprPool.checkIn(s)
11770
+ if actPCMaxTempSgpr > 0:
11771
+ self.sgprPool.checkIn(self.sgprPool.checkOutAligned(actPCMaxTempSgpr, 2 if actPCMaxTempSgpr > 1 else 1, preventOverflow=False))
11772
+
11773
+ tmpVgprDynamic = None
11774
+ tmpVgprDynamicSize = 0
11775
+ tmpVgprDynamicAlign = 0
11776
+ if kernel["_GlobalAccumulation"] == 'MultipleBufferSingleKernel':
11777
+ GSUTotal = self.getMBSKGSUTotal(kernel)
11778
+ vgprMbsk = (GSUTotal-1) * gwvw * max(1, kernel["ProblemType"]["DestDataType"].numRegisters())
11779
+ tmpVgprDynamicSize = vgprMbsk
11780
+ tmpVgprDynamicAlign = 4
11781
+ if tmpVgprDynamicSize > 0:
11782
+ tmpVgprDynamic = ContinuousRegister(idx=self.vgprPool.checkOutAligned(tmpVgprDynamicSize, tmpVgprDynamicAlign), size=tmpVgprDynamicSize)
11783
+
11784
+ maxVgprs, occupancy = self.setOccupancy(kernel)
11785
+
11786
+ ss = StoreState(self, kernel, gwvw, edge, beta, atomic, elements[edgeI], vectorDataTypes, dim=factorDim)
11787
+
11788
+ actPCMaxTempSgpr_ = None
11763
11789
if activationLabelList and isInsertActFunctionCallAddrCalc:
11764
11790
assert activationSetPCStruct, activationEnumStrList and activationLabelList and toActModuleList
11765
- numSgprs = max(actPCMaxTempSgpr, numSgprs)
11791
+ actPCMaxTempSgpr_ = actPCMaxTempSgpr
11792
+
11793
+ numElementsPerBatch, nBatchesPerRow, numBatches, numSgprs = self.refineOccupancy(kernel, atomic, elements, actPCMaxTempSgpr_, edgeI, gwvw, maxVgprs, ss)
11794
+
11795
+ # set atomicW after we potentially resize GWVW
11796
+ atomicW = min(gwvw, self.getVectorAtomicWidth(kernel))
11797
+
11798
+ if activationLabelList and isInsertActFunctionCallAddrCalc:
11766
11799
edgeModule.add(self.insertActFunctionCallAddrCalc(activationSetPCStruct.sgprOffsetActivation, \
11767
11800
gwvw, toActModuleList, activationEnumStrList, activationLabelList, \
11768
11801
idx0, idx1))
0 commit comments