Skip to content

Commit 6f82eb4

Browse files
committed
Mitigate L1 cache aliasing issue in fp32 GEMV JIT kernel
When MR > 8 (e.g., AVX-512 with MR=16), row strides near multiples of 4096 bytes cause all rows to map to the same L1 cache set, exceeding 8-way associativity and causing thrashing. Add runtime detection of aliasing strides and a two-pass k-loop fallback that processes MR/2 rows per pass to stay within L1 associativity limits.
1 parent 28c54e7 commit 6f82eb4

1 file changed

Lines changed: 108 additions & 0 deletions

File tree

src/jit/amdzen/f32_gemv_generator.cc

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1407,6 +1407,30 @@ jitF32GEMVN1<KType>::generateKernel(utils::gemvN1GeneratorParams& params)
14071407

14081408
// Set the for-loop sequence for k-dimension
14091409
if (params.kloop) {
1410+
// Cache aliasing mitigation: when row stride is near a
1411+
// multiple of 4096 bytes (L1 cache modulus on Zen4),
1412+
// all 16 rows (MR=16, AVX-512) map to the same L1 cache
1413+
// set, exceeding its 8-way associativity and causing
1414+
// thrashing. Detect this at runtime and split the k-loop
1415+
// into two passes of MR/2=8 rows each.
1416+
Xbyak::Label label_alias_path;
1417+
bool enableAliasMitigation = (MR > 8);
1418+
1419+
if (enableAliasMitigation) {
1420+
// Aliasing detection: cache set conflicts occur when
1421+
// rsA is near a multiple of 4096 bytes (L1 cache
1422+
// modulus on Zen4, 64 sets × 64B lines = 4096B).
1423+
// With MR=16 rows, thrashing occurs when >8 rows
1424+
// map to the same L1 cache set (8-way associative).
1425+
mov(regTmp2, regRsA);
1426+
and_(regTmp2, 0xFFF); // rsA mod 4096 (bytes)
1427+
cmp(regTmp2, 64);
1428+
jb(label_alias_path, T_NEAR);
1429+
cmp(regTmp2, 4096 - 64);
1430+
ja(label_alias_path, T_NEAR);
1431+
}
1432+
1433+
// --- Normal path (non-aliasing strides) ---
14101434
mov(regKIter,
14111435
ptr[stackPtr
14121436
+ offsetof(dlp::kernels::gemvN1Params, k_iter)]);
@@ -1430,6 +1454,90 @@ jitF32GEMVN1<KType>::generateKernel(utils::gemvN1GeneratorParams& params)
14301454

14311455
sub(regKIter, 1);
14321456
jnz(label_m_loop_k_loop_start, T_NEAR);
1457+
1458+
if (enableAliasMitigation) {
1459+
jmp(label_m_loop_k_loop_end, T_NEAR);
1460+
1461+
// --- Alias path (two-pass, 8 rows per pass) ---
1462+
L(label_alias_path);
1463+
1464+
// Pass 1: rows 0..MR/2-1
1465+
{
1466+
Xbyak::Label lp1_start, lp1_end;
1467+
mov(regKIter,
1468+
ptr[stackPtr
1469+
+ offsetof(
1470+
dlp::kernels::gemvN1Params, k_iter)]);
1471+
test(regKIter, regKIter);
1472+
jz(lp1_end, T_NEAR);
1473+
L(lp1_start);
1474+
1475+
RETURN_IF_ERROR((loadXValues()));
1476+
RETURN_IF_ERROR((processMRBlock(MR / 2)));
1477+
1478+
mov(regTmp1, simdWidth);
1479+
imul(regTmp1, regCsA);
1480+
add(regTmpYptr, regTmp1);
1481+
mov(regTmpAptr, regTmpYptr);
1482+
add(regXptr, RegBytes);
1483+
1484+
sub(regKIter, 1);
1485+
jnz(lp1_start, T_NEAR);
1486+
L(lp1_end);
1487+
}
1488+
1489+
// Save pass-1 column position in regTmp2 for restore
1490+
// after pass 2
1491+
mov(regTmp2, regTmpYptr);
1492+
1493+
// Setup pass 2: start from row MR/2
1494+
mov(regTmpYptr, regAptr);
1495+
mov(regTmp1, MR / 2);
1496+
imul(regTmp1, regRsA);
1497+
add(regTmpYptr, regTmp1);
1498+
mov(regTmpAptr, regTmpYptr);
1499+
1500+
// Reset X pointer for pass 2
1501+
mov(regXptr,
1502+
ptr[stackPtr
1503+
+ offsetof(dlp::kernels::gemvN1Params, x)]);
1504+
1505+
// Pass 2: rows MR/2..MR-1
1506+
{
1507+
int savedAccumBase = accumBaseIdx;
1508+
accumBaseIdx += MR / 2;
1509+
1510+
Xbyak::Label lp2_start, lp2_end;
1511+
mov(regKIter,
1512+
ptr[stackPtr
1513+
+ offsetof(
1514+
dlp::kernels::gemvN1Params, k_iter)]);
1515+
test(regKIter, regKIter);
1516+
jz(lp2_end, T_NEAR);
1517+
L(lp2_start);
1518+
1519+
RETURN_IF_ERROR((loadXValues()));
1520+
RETURN_IF_ERROR((processMRBlock(MR / 2)));
1521+
1522+
mov(regTmp1, simdWidth);
1523+
imul(regTmp1, regCsA);
1524+
add(regTmpYptr, regTmp1);
1525+
mov(regTmpAptr, regTmpYptr);
1526+
add(regXptr, RegBytes);
1527+
1528+
sub(regKIter, 1);
1529+
jnz(lp2_start, T_NEAR);
1530+
L(lp2_end);
1531+
1532+
accumBaseIdx = savedAccumBase;
1533+
}
1534+
1535+
// Restore A column position for shared k-fringe
1536+
// (regXptr is already correct: both passes iterate
1537+
// k_iter times, ending at X_base + k_iter*RegBytes)
1538+
mov(regTmpYptr, regTmp2);
1539+
mov(regTmpAptr, regTmpYptr);
1540+
}
14331541
}
14341542
L(label_m_loop_k_loop_end);
14351543

0 commit comments

Comments
 (0)