1313// ===----------------------------------------------------------------------===//
1414
1515#include " RISCVDotprodSplitter.h"
16+ #include " llvm/ADT/SmallPtrSet.h"
1617#include " llvm/ADT/SmallVector.h"
1718#include " llvm/ADT/Statistic.h"
1819#include " llvm/Analysis/AssumptionCache.h"
@@ -675,6 +676,46 @@ static bool isValidDotProductMultiply(Instruction &MulInst, Loop *L,
675676 return true ;
676677}
677678
679+ // / Inner loops produced by Clang for stepped image/filter dot products often
680+ // / use \c mul(iv, invariant_step) for GEP indices and a separate integer MAC
681+ // / (\c mul(add(...), ...) -> widen -> add to i64 phi). That shape does not
682+ // / satisfy \c isValidDotProductMultiply (two load operands). For extraction we
683+ // / only need a recognizable MAC plus multiple affine loads in the same loop.
684+ static bool hasMacMulWithAffineLoadsPattern (Loop *L, ScalarEvolution &SE) {
685+ bool HasMacMul = false ;
686+ unsigned AffineLoadCount = 0 ;
687+ for (BasicBlock *BB : L->getBlocks ()) {
688+ for (Instruction &I : *BB) {
689+ if (I.getOpcode () == Instruction::Mul && hasAccumulationPattern (I)) {
690+ HasMacMul = true ;
691+ LLVM_DEBUG (dbgs () << " Found MAC-style multiply for extraction gate: "
692+ << I << " \n " );
693+ }
694+ auto *Ld = dyn_cast<LoadInst>(&I);
695+ if (!Ld)
696+ continue ;
697+ if (!isSimpleForwardAccess (Ld, L, SE))
698+ continue ;
699+ ++AffineLoadCount;
700+ LLVM_DEBUG (dbgs () << " Affine sequential load: " << *Ld << " \n " );
701+ }
702+ }
703+ if (!HasMacMul) {
704+ LLVM_DEBUG (dbgs () << " No mul with accumulation chain in loop\n " );
705+ return false ;
706+ }
707+
708+ if (AffineLoadCount < 2 ) {
709+ LLVM_DEBUG (dbgs () << " MAC gate: need >= 2 affine loads, have "
710+ << AffineLoadCount << " \n " );
711+ return false ;
712+ }
713+
714+ LLVM_DEBUG (dbgs () << " Matched Clang-style dotprod loop (MAC + "
715+ << AffineLoadCount << " affine loads)\n " );
716+ return true ;
717+ }
718+
678719// Check for multiply-accumulate pattern (refactored for better readability)
679720static bool hasMultiplyAccumulatePattern (Loop *L, ScalarEvolution &SE) {
680721 LLVM_DEBUG (
@@ -691,7 +732,11 @@ static bool hasMultiplyAccumulatePattern(Loop *L, ScalarEvolution &SE) {
691732 }
692733
693734 LLVM_DEBUG (dbgs () << " No standard pattern found, checking offset pattern\n " );
694- return hasOffsetDotProductPattern (L);
735+ if (hasOffsetDotProductPattern (L))
736+ return true ;
737+
738+ LLVM_DEBUG (dbgs () << " Checking Clang-style MAC + affine loads pattern\n " );
739+ return hasMacMulWithAffineLoadsPattern (L, SE);
695740}
696741
697742// Conditional LoopExtractor implementation
@@ -985,15 +1030,21 @@ static bool hasNestedLoopsWithProcessablePatterns(Function &F) {
9851030
9861031 ScalarEvolution SE (F, TLI, AC, DT, LI);
9871032
988- // Check nested loops for multiply-accumulate patterns
989- for (Loop *L : LI.getLoopsInPreorder ()) {
990- if (!L->getSubLoops ().empty ()) {
991- // Has nested loops, check if inner loops have multiply-accumulate pattern
992- for (Loop *InnerL : L->getSubLoops ()) {
993- if (hasMultiplyAccumulatePattern (InnerL, SE)) {
994- return true ;
995- }
996- }
1033+ // Walk every inner loop (descendant of a top-level loop) once; deeper nests
1034+ // may hold the per-x dot body.
1035+ SmallPtrSet<Loop *, 16 > Seen;
1036+ SmallVector<Loop *, 8 > Stack;
1037+ for (Loop *Top : LI) {
1038+ for (Loop *SL : Top->getSubLoops ())
1039+ Stack.push_back (SL);
1040+ while (!Stack.empty ()) {
1041+ Loop *Cur = Stack.pop_back_val ();
1042+ if (!Seen.insert (Cur).second )
1043+ continue ;
1044+ if (hasMultiplyAccumulatePattern (Cur, SE))
1045+ return true ;
1046+ for (Loop *Child : Cur->getSubLoops ())
1047+ Stack.push_back (Child);
9971048 }
9981049 }
9991050
0 commit comments