Commit e06b185

GEMM: bump MR to min(16,M) for skinny-N (n<=16) BF16 and F32 shapes (#524)
The default ZMM DE returns mr=6, nr=64, so for n<=16 shapes the JIT only
reaches the lt16-mask kernel (bReg=1) and 6 of the 32 ZMMs hold C
accumulators while the other ~25 sit idle. Overriding mr to min(16, m) lets
each cached B line be consumed by up to 16 rows of A instead of 6, recovering
the otherwise-wasted register file.

Change:

* gemmBF16DEBackend / gemmF32DEBackend (ZMM fast path):
  For n<=16 and m>0, set mr = min(16, m). nr stays at nr_hint so the existing
  NR=64 packed-B layout, N-direction blocking, and rsB-divisor math are reused
  unchanged. F32 is additionally gated on !invokeRD and kc != 1.

Guards added so the bumped-MR path doesn't break the rest of the kernel set:

* New `skinnyN` flag on kernel_frame::kernelInfo (threaded through the ctors,
  copy/move, operator== and
  gemmDEBackendUtils::checkPostOpsAndCreateKernelInfo). Set true only at the
  two ZMM override sites, when the MR bump actually fires; false everywhere
  else.
* jitAmdZenFP32 / jitAmdZenBF16 generateAllKernels honor skinnyN by skipping
  nr>=2. Those wider NR variants (lt32 / 32 / lt48 / 48 / lt64 / 64) are
  unreachable for n<=16 and exceed the 32-ZMM budget at MR=16 (especially with
  post-ops and column-major beta scaling), so generating them only produced
  badKernelInfo aborts that took down the whole kernel set.

The default (n>16) path is untouched: skinnyN stays false and every NR variant
continues to be generated and dispatched as before.

[ AMD-Internal - SWLCSG-4250 ]
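
The register arithmetic quoted in the message can be checked with a short sketch. It assumes the budget model described in this commit for the ZMM path (cReg = MR * bReg, with the remainder after C and B registers available for A); `zmmLeftForA` and `kZmmRegs` are hypothetical names used only for illustration, not library symbols.

```cpp
// Hypothetical helper (not part of the library): ZMM registers left for A
// after reserving MR * bReg C accumulators and bReg B registers, following
// the budget model described in the commit message.
constexpr int kZmmRegs = 32;

constexpr int zmmLeftForA(int mr, int bReg)
{
    return kZmmRegs - mr * bReg - bReg;
}

// Default MR=6 with one B register: 6 accumulators, 25 registers left over.
static_assert(zmmLeftForA(6, 1) == 25, "default skinny-N tile");
// Bumped MR=16 with one B register: 16 accumulators, 15 left for A.
static_assert(zmmLeftForA(16, 1) == 15, "bumped skinny-N tile");
// MR=16 with two B registers (NR=32) needs 32 accumulators: over budget.
static_assert(zmmLeftForA(16, 2) < 1, "wider NR does not fit at MR=16");
```

The last assertion is the reason the wider NR variants have to be skipped once MR is bumped to 16.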
1 parent 74536a8 commit e06b185

4 files changed

Lines changed: 187 additions & 13 deletions

src/include/decision_engine/de_backend.hh

Lines changed: 71 additions & 6 deletions
@@ -57,7 +57,8 @@ static const kernel_frame::kernelInfo INVALID_KERNEL_INFO{
     0,
     false,
     kernel_frame::kernelInstrPreference::none,
-    0
+    0,
+    false
 };

 class iDEBackend
@@ -387,11 +388,43 @@ class gemmF32DEBackend : public iDEBackend
             k_unroll = 4;
         }

+        // Increasing MR helps when n<=16 (mirrors the BF16 fix below).
+        //
+        // The default DE returns mr=6, nr=64. For shapes with n<=16 the
+        // dispatcher only ever reaches the NR=16 family of kernels (the
+        // lt16-mask kernel for n<16 and the NR=16 full kernel for n==16),
+        // so most of the 32 ZMMs sit idle as C accumulators. Bumping mr
+        // lets each cached B line be consumed by more rows of A.
+        //
+        // nr stays at nr_hint (=64) so the existing row-major NR=64
+        // packed-B layout and the framework's N-direction blocking are
+        // reused unchanged. The JIT generator below will see the bumped
+        // MR and skip the wider NR variants whose register budget would
+        // overflow at MR=16 (NR>=32 needs cReg>=32); those slots are
+        // never reached at runtime for n<=16 anyway.
+        //
+        // Override only when:
+        // - n <= 16 (skinny-N: only NR=16 family reached)
+        // - !invokeRD (RD path has its own internal MR)
+        // - kInstPref is ZMM (32-ZMM budget is what the math relies on;
+        // AVX2 path has 16 regs, MR=16 won't fit)
+        // - kc != 1 (the k=1 fused path is sized differently)
+        //
+        // For M < 16 we cap mr at m so the kernel uses an MR-partial
+        // kernel sized exactly to the input row count.
+        bool skinnyN = false;
+        if (!invokeRD && n <= 16 && m > 0 && kc != 1
+            && kInstPref
+                   == kernel_frame::kernelInstrPreference::avx512_zmm_favour) {
+            mr = (m < 16) ? m : 16;
+            skinnyN = true;
+        }
+
         return gemmDEBackendUtils::checkPostOpsAndCreateKernelInfo(
             mr, nr, 0, k_unroll, kc, prefetch_c_dist, alphaScalingType,
             betaScalingType, mtag_a, mtag_b, allLtFringeKernels, invokeRD,
-            anyKOpsOrder, kInstPref, c_downscale, k_dtype, rs_c, cs_c,
-            metadata);
+            anyKOpsOrder, kInstPref, c_downscale, k_dtype, rs_c, cs_c, metadata,
+            skinnyN);
     }

     DLP_ALWAYS_INLINE
@@ -567,8 +600,40 @@ class gemmBF16DEBackend : public iDEBackend
         std::tie(alphaScalingType, betaScalingType) =
             gemmDEBackendUtils::getScalingTypes<float>(alpha, beta, k, kc_hint);

-        md_t mr = mr_hint;
-        md_t nr = nr_hint;
+        md_t mr = mr_hint;
+        md_t nr = nr_hint;
+
+        // Increasing MR helps when n<=16.
+        //
+        // The default DE returns mr=6, nr=64. For shapes with n<=16 the
+        // dispatcher only ever reaches the NR=16 family of kernels (the
+        // lt16-mask kernel for n<16 and the NR=16 full kernel for n==16),
+        // so only 6 of the 32 ZMM registers are used as C accumulators --
+        // the other ~25 ZMMs sit idle.
+        //
+        // We bump mr to min(16, M) so each cached B line is now consumed
+        // by up to 16 rows of A instead of 6, raising B reuse and cutting
+        // the M-iteration count from ceil(M/6) to ceil(M/16). With
+        // bReg=1 (the only NR variant the skinny-N dispatch reaches),
+        // cReg=16 and aReg=15: well inside the 32-ZMM budget.
+        //
+        // nr stays at nr_hint (=64) so the existing row-major NR=64
+        // packed-B layout and the framework's N-direction blocking are
+        // reused unchanged. The JIT generator below skips the wider NR
+        // variants whose register budget would overflow at MR=16
+        // (NR>=32 needs cReg>=32); those slots are never reached at
+        // runtime for n<=16 anyway.
+        //
+        // For M < 16 we cap mr at m so the kernel uses an MR-partial
+        // kernel sized exactly to the input row count (single full
+        // panel, no fringe). This avoids leaving C ZMMs idle for tiny-M
+        // shapes.
+        bool skinnyN = false;
+        if (n <= 16 && m > 0) {
+            mr = (m < 16) ? m : 16;
+            skinnyN = true;
+        }
+
         md_t k_unroll = 1;
         md_t kc = kc_hint;
         md_t prefetch_c_dist = getPrefetchDistance();
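
The BF16 override in the hunk above can be restated as a small standalone sketch. It is an illustration only: `TileChoice`, `chooseMr` and `mPanels` are hypothetical names, `std::int64_t` stands in for the library's `md_t`, and the F32 path's extra conditions (`!invokeRD`, `kc != 1`, ZMM preference) are deliberately left out.

```cpp
#include <algorithm>
#include <cstdint>

// Hypothetical stand-in for illustration; the real backend returns a full
// kernelInfo and uses md_t rather than std::int64_t.
struct TileChoice {
    std::int64_t mr;
    bool skinnyN;
};

// Mirrors the BF16 override: for n <= 16 and m > 0, use mr = min(16, m) and
// raise the skinnyN flag; otherwise keep the hinted MR.
inline TileChoice chooseMr(std::int64_t m, std::int64_t n, std::int64_t mrHint)
{
    if (n <= 16 && m > 0) {
        return { std::min<std::int64_t>(16, m), true };
    }
    return { mrHint, false };
}

// Number of M-direction panels, ceil(m / mr), for m >= 0 and mr > 0.
inline std::int64_t mPanels(std::int64_t m, std::int64_t mr)
{
    return (m + mr - 1) / mr;
}
```

For m = 100 and n = 8, the M-iteration count under this model drops from `mPanels(100, 6)` to `mPanels(100, 16)`, which is the ceil(M/6) to ceil(M/16) reduction the comment describes.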
@@ -582,7 +647,7 @@ class gemmBF16DEBackend : public iDEBackend
         return gemmDEBackendUtils::checkPostOpsAndCreateKernelInfo(
             mr, nr, 0, k_unroll, kc, prefetch_c_dist, alphaScalingType,
             betaScalingType, mtag_a, mtag_b, false, false, anyKOpsOrder,
-            kInstPref, c_downscale, k_dtype, rs_c, cs_c, metadata);
+            kInstPref, c_downscale, k_dtype, rs_c, cs_c, metadata, skinnyN);
     }
 };

src/include/decision_engine/de_backend_utils.hh

Lines changed: 6 additions & 3 deletions
@@ -156,7 +156,8 @@ class gemmDEBackendUtils
         dlp::kernel_frame::kernelDatatype k_dtype,
         [[maybe_unused]] md_t rs_c,
         [[maybe_unused]] md_t cs_c,
-        dlp_gemm_post_op* metadata)
+        dlp_gemm_post_op* metadata,
+        bool skinnyN = false)
     {
         // Iterate over the post_ops list to get the number of post-ops.
         md_t numPostOps = 0;
@@ -184,7 +185,8 @@ class gemmDEBackendUtils
                 0,
                 anyKOpsOrder,
                 kInstPref,
-                c_downscale };
+                c_downscale,
+                skinnyN };
         } else {
             kernel_frame::kernelInfo kI{ mr,
                                          nr,
@@ -202,7 +204,8 @@ class gemmDEBackendUtils
                                          0,
                                          anyKOpsOrder,
                                          kInstPref,
-                                         c_downscale };
+                                         c_downscale,
+                                         skinnyN };
             kI.kOpsArrSize = numPostOps;
             kI.kOpsArr =
                 kernel_frame::kernelInfo::allocateKernelOpsArray(numPostOps);

src/include/kernel_frame/kernel_frame_base.hh

Lines changed: 17 additions & 2 deletions
@@ -304,6 +304,12 @@ struct kernelInfo
     bool anyKOpsOrder;
     kernelInstrPreference kInstPref;
     md_t c_downscale;
+    // True when the DE has applied the skinny-N (n<=16) override that
+    // bumps MR above the default. Signals to the JIT generator that
+    // only the lt16 (nr=0) and full-16 (nr=1) N-direction variants will
+    // ever be invoked at runtime, so the wider lt32/32/lt48/48/lt64/64
+    // slots can be skipped during generation.
+    bool skinnyN;

     // Empty constructor to create dummy kernelInfo.
     kernelInfo()
@@ -323,6 +329,7 @@ struct kernelInfo
         , anyKOpsOrder(false)
         , kInstPref(kernel_frame::kernelInstrPreference::none)
         , c_downscale(0)
+        , skinnyN(false)
     {
     }

@@ -342,7 +349,8 @@ struct kernelInfo
                std::size_t kOpsArrSize,
                bool anyKOpsOrder,
                kernelInstrPreference instPref,
-               md_t c_downscale)
+               md_t c_downscale,
+               bool _skinnyN = false)
         : mr(mr)
         , nr(nr)
         , term_fringe_nr(_term_fringe_nr)
@@ -363,6 +371,7 @@ struct kernelInfo
         , anyKOpsOrder(anyKOpsOrder)
         , kInstPref(instPref)
         , c_downscale(c_downscale)
+        , skinnyN(_skinnyN)
     {
     }

@@ -386,6 +395,7 @@ struct kernelInfo
         , anyKOpsOrder(other.anyKOpsOrder)
         , kInstPref(other.kInstPref)
         , c_downscale(other.c_downscale)
+        , skinnyN(other.skinnyN)
     {
         if ((other.kOpsArr != nullptr) && (other.kOpsArrSize > 0)) {
             this->kOpsArr =
@@ -419,6 +429,7 @@ struct kernelInfo
         , anyKOpsOrder(other->anyKOpsOrder)
         , kInstPref(other->kInstPref)
         , c_downscale(other->c_downscale)
+        , skinnyN(other->skinnyN)
     {
         if ((other->kOpsArr != nullptr) && (other->kOpsArrSize > 0)) {
             other->kOpsArr = nullptr;
@@ -448,6 +459,7 @@ struct kernelInfo
         , anyKOpsOrder(other.anyKOpsOrder)
         , kInstPref(other.kInstPref)
         , c_downscale(other.c_downscale)
+        , skinnyN(other.skinnyN)
     {
         if ((other.kOpsArr != nullptr) && (other.kOpsArrSize > 0)) {
             other.kOpsArr = nullptr;
@@ -485,6 +497,7 @@ struct kernelInfo
             this->anyKOpsOrder = other.anyKOpsOrder;
             this->kInstPref = other.kInstPref;
             this->c_downscale = other.c_downscale;
+            this->skinnyN = other.skinnyN;
         }
         return *this;
     }
@@ -517,6 +530,7 @@ struct kernelInfo
             this->anyKOpsOrder = other.anyKOpsOrder;
             this->kInstPref = other.kInstPref;
             this->c_downscale = other.c_downscale;
+            this->skinnyN = other.skinnyN;
         }
         return *this;
     }
@@ -552,7 +566,8 @@ struct kernelInfo
                 && (this->kOpsArrSize == rhs.kOpsArrSize) && isKOpsArrEqual
                 && (this->anyKOpsOrder == rhs.anyKOpsOrder)
                 && (this->kInstPref == rhs.kInstPref)
-                && (this->c_downscale == rhs.c_downscale));
+                && (this->c_downscale == rhs.c_downscale)
+                && (this->skinnyN == rhs.skinnyN));
     }

     // TODO: Need to implement a subset function for kernelInfo
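
The way the flag is threaded through `kernelInfo` can be illustrated with a stripped-down stand-in. `MiniKernelInfo` is hypothetical and keeps only three fields; it shows two properties visible in the hunks above: the new constructor parameter is defaulted, so older call sites keep compiling, and the flag participates in `operator==`.

```cpp
// Minimal stand-in for kernelInfo (hypothetical; the real struct has many
// more fields). The trailing parameter is defaulted, so call sites written
// before the flag existed still compile and get skinnyN == false.
struct MiniKernelInfo {
    int mr;
    int nr;
    bool skinnyN;

    MiniKernelInfo(int mr_, int nr_, bool skinnyN_ = false)
        : mr(mr_)
        , nr(nr_)
        , skinnyN(skinnyN_)
    {
    }

    // The flag takes part in equality, so two infos that differ only in
    // skinnyN do not compare equal.
    bool operator==(const MiniKernelInfo& rhs) const
    {
        return mr == rhs.mr && nr == rhs.nr && skinnyN == rhs.skinnyN;
    }
};
```

Including the flag in the comparison means a skinny-N info is never treated as identical to a default one with the same MR and NR, wherever such values are compared.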

src/jit/amdzen/amdzen_generator.cc

Lines changed: 93 additions & 2 deletions
@@ -602,7 +602,22 @@ jitAmdZenFP32::generateAllKernels(const dlp::jit::jitGeneratorContext& jI)
         params.kernelOps.push_back((jI.kI).kOpsArr[ii]);
     }

-    // Generate all kernels for the given MR and NR
+    // Generate all kernels for the given MR and NR. Any per-variant
+    // generator failure (after passing the feasibility filter below)
+    // is fatal (goto cleanup): we want the existing fail-fast contract
+    // to hold, so we never call generateKernel() for an (mr, nr) pair
+    // we already know cannot fit.
+    //
+    // Feasibility filter: when the DE bumps MR (e.g. MR=16 for the
+    // skinny-N n<=16 override), the wider-NR variants (NR>=32) would
+    // exceed the register budget (cReg=MR*bReg, aReg = numRegs -
+    // cReg - bReg - maskVecReg, must have aReg >= 1 -- mirrors
+    // jitGEMMF32::allocateReg()). Skip those slots up front instead
+    // of relying on the per-variant generator to return badKernelInfo.
+    // The dispatcher only reaches the lt-mask kernel and the NR=16
+    // full kernel for n<=16, both of which always pass the filter.
+    const int kNumRegs =
+        (kType == utils::kernelInstrType::avx2_ymm_16_reg) ? 16 : 32;
     for (iter_t mr = 0; mr < numMRVariants; mr++) {
         for (iter_t nr = 0; nr < numNRVariants; nr++) {
             params.MR = mr == 0 ? MR : mr;
@@ -611,6 +626,42 @@ jitAmdZenFP32::generateAllKernels(const dlp::jit::jitGeneratorContext& jI)
             int correspondingMainFringe = 0;
             deriveGEMMNRAndMaskUse(nr, params, correspondingMainFringe);

+            // Skinny-N override: when the DE has bumped MR via the
+            // n<=16 override (jI.kI.skinnyN), only the lt-numElems
+            // (nr=0) and the full numElems (nr=1) variants are ever
+            // dispatched at runtime. The wider NR slots (nr>=2:
+            // lt2x/2x/lt3x/3x/lt4x/4x) are unreachable AND would
+            // exceed the register budget at bumped MR -- skip them
+            // entirely so we don't waste codegen on dead kernels
+            // (and don't trigger badKernelInfo on infeasible ones).
+            if (jI.kI.skinnyN && nr >= 2) {
+                continue;
+            }
+
+            // Pre-filter register-infeasible (MR, NR) variants. This
+            // mirrors jitGEMMF32<KType>::allocateReg(): bFullReg =
+            // NR / numElemsPerReg, bMaskReg = useMask ? numMaskRegs
+            // : 0, bReg = bFullReg + bMaskReg, cReg = MR * bReg.
+            // For AVX2 ymm, the mask consumes vector registers
+            // (maskVecReg = numMaskRegs); for AVX-512 the mask is in
+            // Opmask regs and does not draw from the vector budget.
+            {
+                int bFullReg = params.NR / numElemsPerReg;
+                int bMaskReg = params.useMask ? params.numMaskRegs : 0;
+                int bReg = bFullReg + bMaskReg;
+                int cReg = params.MR * bReg;
+                int maskVecReg =
+                    (kType == utils::kernelInstrType::avx2_ymm_16_reg)
+                        ? bMaskReg
+                        : 0;
+                if (kNumRegs - cReg - bReg - maskVecReg < 1) {
+                    // Slot stays nullptr (zero-initialized by
+                    // resize). The dispatcher never reaches it for
+                    // any DE-blessed shape that bumped MR.
+                    continue;
+                }
+            }
+
             std::unique_ptr<Xbyak::CodeGenerator> gen;
             switch (kType) {
                 case utils::kernelInstrType::avx512_zmm_32_reg: {
@@ -1414,14 +1465,54 @@ jitAmdZenBF16::generateAllKernels(const dlp::jit::jitGeneratorContext& jI)
         params.kernelOps.push_back((jI.kI).kOpsArr[ii]);
     }

-    // Generate all kernels for the given MR and NR
+    // Generate all kernels for the given MR and NR. Any per-variant
+    // generator failure (after passing the feasibility filter below)
+    // is fatal (goto cleanup): we want the existing fail-fast contract
+    // to hold, so we never call generateKernel() for an (mr, nr) pair
+    // we already know cannot fit.
+    //
+    // Feasibility filter: when the DE bumps MR (e.g. MR=16 for the
+    // skinny-N n<=16 override), the wider-NR variants (NR>=32) would
+    // exceed the 32-ZMM budget (cReg=MR*bReg, must have aReg = 32 -
+    // cReg - bReg >= aRegMin -- mirrors jitGEMMBF16::allocateReg()).
+    // Skip those slots up front instead of relying on the per-variant
+    // generator to return badKernelInfo. The dispatcher only reaches
+    // the lt16-mask kernel and the NR=16 full kernel for n<=16, both
+    // of which always pass the filter.
+    constexpr int kZmmRegs = 32;
+    const int aRegMin = ((jI.kI).c_downscale < DLP_F32) ? 2 : 1;
     for (iter_t mr = 0; mr < numMRVariants; mr++) {
         for (iter_t nr = 0; nr < numNRVariants; nr++) {
             params.MR = (mr == 0) ? MR : mr;
             params.mLoop = (mr == 0);
             params.NR = (nr * numElemsPerReg);
             params.useMask = (nr == 0);
             params.numMaskRegs = (params.useMask) ? 1 : 0;
+
+            // Skinny-N override: when the DE has bumped MR via the
+            // n<=16 override (jI.kI.skinnyN), only the lt16 (nr=0)
+            // and full-16 (nr=1) variants are ever dispatched. The
+            // wider NR slots (nr>=2: lt32/32/lt48/48/lt64/64) are
+            // unreachable AND would exceed the 32-ZMM budget at
+            // bumped MR -- skip them entirely so we don't waste
+            // codegen on dead kernels (and don't trigger
+            // badKernelInfo on infeasible ones).
+            if (jI.kI.skinnyN && nr >= 2) {
+                continue;
+            }
+
+            // For BF16 ZMM: bFullReg = (2*NR)/nBF16ElemsPerReg
+            // = NR/16 = nr (with numElemsPerReg=16). bMaskReg=1 for
+            // useMask, else 0. So bReg = max(1, nr).
+            int bReg = (nr == 0) ? 1 : static_cast<int>(nr);
+            int cReg = params.MR * bReg;
+            if (kZmmRegs - cReg - bReg < aRegMin) {
+                // Slot stays nullptr (zero-initialized by resize).
+                // The dispatcher never reaches it for any DE-blessed
+                // shape that bumped MR.
+                continue;
+            }
+
             auto gen = std::make_unique<GEMMcodeGenerator::jitGEMMBF16<
                 utils::kernelInstrType::avx512_zmm_32_reg>>(
                 utils::JIT_KERNEL_SIZE);
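
The two BF16 per-slot checks in the hunk above can be folded into one predicate for experimentation. This is a sketch under the assumptions stated in the diff comments (bReg = max(1, nr), cReg = MR * bReg, 32 ZMM registers); `bf16SlotIsGenerated` and `kZmmBudget` are hypothetical names, and `aRegMin` is passed in rather than derived from `c_downscale`.

```cpp
// Hypothetical restatement of the BF16 per-slot checks from the diff.
// `mr` is the effective MR for the slot, `nr` is the NR variant index, and
// `aRegMin` is the minimum number of A registers (1 or 2 in the diff).
constexpr int kZmmBudget = 32;

inline bool bf16SlotIsGenerated(int mr, int nr, bool skinnyN, int aRegMin)
{
    // Skinny-N override: variants with nr >= 2 are skipped outright.
    if (skinnyN && nr >= 2) {
        return false;
    }
    // Budget filter: bReg = max(1, nr), cReg = mr * bReg; the slot is kept
    // only if at least aRegMin registers remain for A.
    const int bReg = (nr == 0) ? 1 : nr;
    const int cReg = mr * bReg;
    return kZmmBudget - cReg - bReg >= aRegMin;
}
```

With MR=16, the nr=0 and nr=1 slots pass under this model, while nr=2 is rejected by either check, which matches the claim that both reachable kernels pass the filter.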
