Skip to content

Commit 94935bb

Browse files
committed
Fix MBSK: use WavefrontSize to calculate the required synchronizer size.
1 parent 4160381 commit 94935bb

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

tensilelite/Tensile/Contractions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,7 @@ def CompoundPredicates(cls, state, problemType):
469469
valuepredicates.append(int((state["NumElementsPerThread"])/state["NumElementsPerBatchStore"]))
470470
else:
471471
valuepredicates.append(1)
472-
valuepredicates.append(state["NumThreads"])
472+
valuepredicates.append(ceil(state["NumThreads"] / state["WavefrontSize"]))
473473
rv += [cls('SynchronizerSizeCheck', index=0, value=valuepredicates)]
474474

475475
if state["InternalSupportParams"]["KernArgsVersion"] >= 1 and \
@@ -605,8 +605,8 @@ def FromOriginalState(cls, d):
605605
globalAccum = 4
606606
pgr = int(d['PrefetchGlobalRead'])
607607
synchronizerSizePerWG = ceil((d['MIWaveTile'][0]*d['MIWaveTile'][1] if d['EnableMatrixInstruction'] else d['ThreadTile0']*d['ThreadTile1'] \
608-
* ceil((d['NumElementsPerThread'])/d['NumElementsPerBatchStore']) if d['NumElementsPerBatchStore'] != 0 else 1 \
609-
* ceil(d["NumThreads"] / 64)))
608+
* ceil((d['NumElementsPerThread'])/d['NumElementsPerBatchStore']) if d['NumElementsPerBatchStore'] != 0 else 1 \
609+
* ceil(d["NumThreads"] / d["WavefrontSize"])))
610610

611611
return cls(waveNum = d['NumThreads'] // d['WavefrontSize'],
612612
workGroup = d['WorkGroup'],

tensilelite/Tensile/Source/lib/include/Tensile/ContractionProblemPredicates.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,10 +230,10 @@ namespace TensileLite
230230
virtual bool operator()(ContractionProblemGemm const& problem) const override
231231
{
232232
// WorkGroup numbers x number of global write instruction x Wave numbers
233-
// M/MT0 x N/MT1 x NumElementsPerThread/StoreVectorWidth x x Wavenumbers
233+
// M/MT0 x N/MT1 x NumElementsPerThread/StoreVectorWidth x x Wavenumbers x batch
234234
bool ret = (std::ceil(static_cast<float>(problem.freeSizeA(0)) / value[0])
235235
* std::ceil(static_cast<float>(problem.freeSizeB(0)) / value[1]))
236-
* (value[2]) * (value[4] / 64) * value[3]
236+
* value[2] * value[4] * value[3] * problem.d().sizes()[2]
237237
<= 409600;
238238
if(problem.groupedGemm())
239239
ret = ret && (problem.groupedGemmCount() <= 16);

0 commit comments

Comments
 (0)