Skip to content

Commit d8b6766

Browse files
committed
Fix StreamK+LSU and add new test case
1 parent faf7381 commit d8b6766

File tree

2 files changed: +145 -4 lines changed

tensilelite/Tensile/Components/StreamK.py

+6 -4
Original file line number | Diff line number | Diff line change
@@ -688,7 +688,7 @@ def partialsWriteBatch(self, writer, kernel, ss, batchIdx, applyAlpha, beta, edg
688688
# AccVgpr read
689689
# if kernel.enabledSetPrioSplitLDS:
690690
# kStr += inst("s_setprio", "0", "")
691-
if codeAccVgprRead is not None: # and writer.kernel["LocalSplitU"] == 1
691+
if codeAccVgprRead is not None and kernel["LocalSplitU"] == 1:
692692
regsPerScalar = writer.states.bpeCinternal // writer.states.bpr # register per scalar
693693
# loop over store instructions within one batch
694694
for elementIdx in range(0, len(batchElements)):
@@ -757,7 +757,7 @@ def partialsWriteBatch(self, writer, kernel, ss, batchIdx, applyAlpha, beta, edg
757757
module.add(VMovB32(vgpr(cvtVgprStruct.vgprBF8Min), "0xc7600000", comment="BF8 Min value -57344 as float32" ))
758758

759759
if kernel["EnableMatrixInstruction"]:
760-
WaveNum = kernel["MIWaveGroup"][0] * kernel["MIWaveGroup"][1]
760+
WaveNum = kernel["MIWaveGroup"][0] * kernel["MIWaveGroup"][1] * kernel["WorkGroup"][2]
761761
else:
762762
WaveNum = kernel["NumThreads"] // kernel["WavefrontSize"]
763763

@@ -777,6 +777,7 @@ def partialsWriteBatch(self, writer, kernel, ss, batchIdx, applyAlpha, beta, edg
777777
module.add(SMovB32(dst=sgpr(tmpS01), src=0, comment="Init sgpr offset"))
778778
else:
779779
increment = (kernel["WavefrontSize"] * WaveNum) * storeWidth * writer.states.bpeCinternal
780+
# module.addComment1("WavefrontSize={}, WaveNum={}, storeWidth={}, bpeC={}".format(kernel["WavefrontSize"], WaveNum, storeWidth, writer.states.bpeCinternal))
780781
module.add(SAddU32(dst=sgpr(tmpS01), src0=sgpr(tmpS01), src1=increment, comment="Inc sgpr offset"))
781782

782783
# TODO StreamK need this packing code???
@@ -1131,7 +1132,7 @@ def fixupBatch(self, writer, kernel, ss, batchIdx, edge, gwvw, \
11311132
# self.StoreCUnrollLoadCWaitComment = "waitcnt for LoadC" # this will be used later to identify waitcnt for loadC
11321133

11331134
if kernel["EnableMatrixInstruction"]:
1134-
WaveNum = kernel["MIWaveGroup"][0] * kernel["MIWaveGroup"][1]
1135+
WaveNum = kernel["MIWaveGroup"][0] * kernel["MIWaveGroup"][1] * kernel["WorkGroup"][2]
11351136
else:
11361137
WaveNum = kernel["NumThreads"] // kernel["WavefrontSize"]
11371138

@@ -1156,6 +1157,7 @@ def fixupBatch(self, writer, kernel, ss, batchIdx, edge, gwvw, \
11561157
module.add(SMovB32(dst=sgpr(tmpS01), src=0, comment="Init sgpr offset"))
11571158
else:
11581159
increment = (kernel["WavefrontSize"] * WaveNum) * storeWidth * writer.states.bpeCinternal
1160+
# module.addComment1("WavefrontSize={}, WaveNum={}, storeWidth={}, bpeC={}".format(kernel["WavefrontSize"], WaveNum, storeWidth, writer.states.bpeCinternal))
11591161
module.add(SAddU32(dst=sgpr(tmpS01), src0=sgpr(tmpS01), src1=increment, comment="Inc sgpr offset"))
11601162

11611163
module.add(writer.readInput(kernel, ss, 'WS', kernel["ProblemType"]["ComputeDataType"], addrCalc, vc0, data, gwvw, addrCVgpr, sgpr(tmpS01)))
@@ -1165,7 +1167,7 @@ def fixupBatch(self, writer, kernel, ss, batchIdx, edge, gwvw, \
11651167
# AccVgpr read
11661168
# if kernel.enabledSetPrioSplitLDS:
11671169
# kStr += inst("s_setprio", "0", "")
1168-
if codeAccVgprRead is not None:
1170+
if codeAccVgprRead is not None and kernel["LocalSplitU"] == 1:
11691171
regsPerScalar = writer.states.bpeCinternal // writer.states.bpr # register per scalar
11701172
# loop over store instructions within one batch
11711173
for elementIdx in range(0, len(batchElements)):
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,139 @@
1+
TestParameters:
2+
marks: [skip-gfx900, skip-gfx906, skip-gfx908, skip-gfx1010, skip-gfx1011, skip-gfx1012, skip-gfx1030, skip-gfx1100, skip-gfx1101, skip-gfx1102, skip-gfx1200, skip-gfx1201] # not supported by arch
3+
4+
GlobalParameters:
5+
NumElementsToValidate: -1
6+
BoundsCheck: False
7+
KernelTime: False
8+
DataInitTypeAlpha: 1
9+
DataInitTypeBeta: 1
10+
DataInitTypeA: 12
11+
DataInitTypeB: 13
12+
DataInitTypeC: 12
13+
# DataInitTypeC: 1
14+
# ValidationPrintValids: True
15+
MaxWorkspaceSize: 134217728
16+
# PrintSolutionRejectionReason: True
17+
# ForceGenerateKernel: True
18+
# GenerateSourcesAndExit: True
19+
NumWarmups: 0
20+
EnqueuesPerSync: 1
21+
# NumBenchmarks: 10
22+
SleepPercent: 50
23+
24+
BenchmarkProblems:
25+
26+
- # HGEMM NT
27+
- # ProblemType
28+
OperationType: GEMM
29+
DataType: h
30+
DestDataType: h
31+
ComputeDataType: s
32+
HighPrecisionAccumulate: True
33+
TransposeA: False
34+
TransposeB: True
35+
UseBeta: True
36+
Batched: True
37+
38+
- # HGEMM NT - Single wave
39+
InitialSolutionParameters:
40+
BenchmarkCommonParameters:
41+
- KernelLanguage: ["Assembly"]
42+
- PrefetchLocalRead: [True]
43+
ForkParameters:
44+
- MatrixInstruction:
45+
- [16,16,16,1, 1, 1,1, 1,1]
46+
# - [16,16,16,1, 1, 2,1, 1,1]
47+
# - [16,16,16,1, 1, 4,1, 1,1]
48+
# - [16,16,16,1, 1, 1,2, 1,1]
49+
# - [16,16,16,1, 1, 1,4, 1,1]
50+
# - [16,16,16,1, 1, 2,2, 1,1]
51+
# - [16,16,16,1, 1, 1,1, 2,1]
52+
# - [16,16,16,1, 1, 1,1, 1,2]
53+
- DepthU: [256]
54+
- 1LDSBuffer: [-1]
55+
- ClusterLocalRead: [True]
56+
- ExpandPointerSwap: [0]
57+
# - LocalReadVectorWidth: [4, 8]
58+
# - NumElementsPerBatchStore: [0, 16]
59+
- PrefetchGlobalRead: [2]
60+
- PrefetchLocalRead: [1]
61+
- ScheduleIterAlg: [3]
62+
- SourceSwap: [True]
63+
- StaggerU: [0]
64+
- StreamK: [3]
65+
- WorkGroupMappingXCC: [8]
66+
- TransposeLDS: [-1]
67+
- UseSgprForGRO: [0]
68+
- WorkGroupMapping: [6]
69+
- VectorWidthA: [1]
70+
- VectorWidthB: [1]
71+
72+
BenchmarkForkParameters:
73+
JoinParameters:
74+
BenchmarkJoinParameters:
75+
BenchmarkFinalParameters:
76+
- ProblemSizes:
77+
- Exact: [512, 512, 1, 512]
78+
- Exact: [1024, 1024, 1, 1024]
79+
- Exact: [1031, 1031, 1, 1031]
80+
# - Exact: [4096, 4096, 1, 1024]
81+
# - Exact: [4103, 4096, 1, 1024]
82+
# - Exact: [4096, 4103, 1, 1024]
83+
# - Exact: [4096, 4096, 1, 1031]
84+
- Exact: [4103, 4103, 1, 1031]
85+
86+
- # HGEMM NT - LSU
87+
InitialSolutionParameters:
88+
BenchmarkCommonParameters:
89+
- KernelLanguage: ["Assembly"]
90+
- PrefetchLocalRead: [True]
91+
ForkParameters:
92+
- MatrixInstruction:
93+
- [16,16,16,1, 1, 1,1, 1,1]
94+
# - [16,16,16,1, 1, 1,1, 1,2]
95+
# - [16,16,16,1, 1, 1,1, 2,1]
96+
- [16,16,16,1, 1, 2,1, 1,1]
97+
- [16,16,16,1, 1, 4,1, 1,1]
98+
- [16,16,16,1, 1, 1,2, 1,1]
99+
- [16,16,16,1, 1, 1,4, 1,1]
100+
- [16,16,16,1, 1, 2,2, 1,1]
101+
# - [16,16,16,1, 1, 1,1, 2,1]
102+
# - [16,16,16,1, 1, 1,1, 1,2]
103+
# - [16,16,16,1, 1, 8,8, 1,1]
104+
# - [16,16,16,1, 1, 8,8, 2,1]
105+
# - [16,16,16,1, 1, 8,8, 1,2]
106+
- DepthU: [256]
107+
- WorkGroup:
108+
- [4,4,4]
109+
- 1LDSBuffer: [-1]
110+
- ClusterLocalRead: [True]
111+
- ExpandPointerSwap: [0]
112+
# - LocalReadVectorWidth: [4, 8]
113+
# - NumElementsPerBatchStore: [0, 16]
114+
- PrefetchGlobalRead: [2]
115+
- PrefetchLocalRead: [1]
116+
- ScheduleIterAlg: [3]
117+
- SourceSwap: [True]
118+
- StaggerU: [0]
119+
- StreamK: [3]
120+
- WorkGroupMappingXCC: [8]
121+
- TransposeLDS: [-1]
122+
- UseSgprForGRO: [0]
123+
- WorkGroupMapping: [6]
124+
- VectorWidthA: [1]
125+
- VectorWidthB: [1]
126+
127+
BenchmarkForkParameters:
128+
JoinParameters:
129+
BenchmarkJoinParameters:
130+
BenchmarkFinalParameters:
131+
- ProblemSizes:
132+
- Exact: [512, 512, 1, 512]
133+
- Exact: [1024, 1024, 1, 1024]
134+
- Exact: [1031, 1031, 1, 1031]
135+
# - Exact: [4096, 4096, 1, 1024]
136+
# - Exact: [4103, 4096, 1, 1024]
137+
# - Exact: [4096, 4103, 1, 1024]
138+
# - Exact: [4096, 4096, 1, 1031]
139+
- Exact: [4103, 4103, 1, 1031]

0 commit comments

Comments (0)