Skip to content

Commit c74b665

Browse files
chenqiancursoragent
andcommitted
[RISCV] Improve dotprod splitter for Clang MAC loops and nested inner loops
Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 39a5993 commit c74b665

11 files changed

Lines changed: 432 additions & 741 deletions

llvm/lib/Target/RISCV/RISCVDotprodSplitter.cpp

Lines changed: 61 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
//===----------------------------------------------------------------------===//
1414

1515
#include "RISCVDotprodSplitter.h"
16+
#include "llvm/ADT/SmallPtrSet.h"
1617
#include "llvm/ADT/SmallVector.h"
1718
#include "llvm/ADT/Statistic.h"
1819
#include "llvm/Analysis/AssumptionCache.h"
@@ -675,6 +676,46 @@ static bool isValidDotProductMultiply(Instruction &MulInst, Loop *L,
675676
return true;
676677
}
677678

679+
/// Inner loops produced by Clang for stepped image/filter dot products often
680+
/// use \c mul(iv, invariant_step) for GEP indices and a separate integer MAC
681+
/// (\c mul(add(...), ...) -> widen -> add to i64 phi). That shape does not
682+
/// satisfy \c isValidDotProductMultiply (two load operands). For extraction we
683+
/// only need a recognizable MAC plus multiple affine loads in the same loop.
684+
static bool hasMacMulWithAffineLoadsPattern(Loop *L, ScalarEvolution &SE) {
685+
bool HasMacMul = false;
686+
unsigned AffineLoadCount = 0;
687+
for (BasicBlock *BB : L->getBlocks()) {
688+
for (Instruction &I : *BB) {
689+
if (I.getOpcode() == Instruction::Mul && hasAccumulationPattern(I)) {
690+
HasMacMul = true;
691+
LLVM_DEBUG(dbgs() << "Found MAC-style multiply for extraction gate: "
692+
<< I << "\n");
693+
}
694+
auto *Ld = dyn_cast<LoadInst>(&I);
695+
if (!Ld)
696+
continue;
697+
if (!isSimpleForwardAccess(Ld, L, SE))
698+
continue;
699+
++AffineLoadCount;
700+
LLVM_DEBUG(dbgs() << " Affine sequential load: " << *Ld << "\n");
701+
}
702+
}
703+
if (!HasMacMul) {
704+
LLVM_DEBUG(dbgs() << "No mul with accumulation chain in loop\n");
705+
return false;
706+
}
707+
708+
if (AffineLoadCount < 2) {
709+
LLVM_DEBUG(dbgs() << "MAC gate: need >= 2 affine loads, have "
710+
<< AffineLoadCount << "\n");
711+
return false;
712+
}
713+
714+
LLVM_DEBUG(dbgs() << "Matched Clang-style dotprod loop (MAC + "
715+
<< AffineLoadCount << " affine loads)\n");
716+
return true;
717+
}
718+
678719
// Check for multiply-accumulate pattern (refactored for better readability)
679720
static bool hasMultiplyAccumulatePattern(Loop *L, ScalarEvolution &SE) {
680721
LLVM_DEBUG(
@@ -691,7 +732,11 @@ static bool hasMultiplyAccumulatePattern(Loop *L, ScalarEvolution &SE) {
691732
}
692733

693734
LLVM_DEBUG(dbgs() << "No standard pattern found, checking offset pattern\n");
694-
return hasOffsetDotProductPattern(L);
735+
if (hasOffsetDotProductPattern(L))
736+
return true;
737+
738+
LLVM_DEBUG(dbgs() << "Checking Clang-style MAC + affine loads pattern\n");
739+
return hasMacMulWithAffineLoadsPattern(L, SE);
695740
}
696741

697742
// Conditional LoopExtractor implementation
@@ -985,15 +1030,21 @@ static bool hasNestedLoopsWithProcessablePatterns(Function &F) {
9851030

9861031
ScalarEvolution SE(F, TLI, AC, DT, LI);
9871032

988-
// Check nested loops for multiply-accumulate patterns
989-
for (Loop *L : LI.getLoopsInPreorder()) {
990-
if (!L->getSubLoops().empty()) {
991-
// Has nested loops, check if inner loops have multiply-accumulate pattern
992-
for (Loop *InnerL : L->getSubLoops()) {
993-
if (hasMultiplyAccumulatePattern(InnerL, SE)) {
994-
return true;
995-
}
996-
}
1033+
// Walk every inner loop (descendant of a top-level loop) once; deeper nests
1034+
// may hold the per-x dot body.
1035+
SmallPtrSet<Loop *, 16> Seen;
1036+
SmallVector<Loop *, 8> Stack;
1037+
for (Loop *Top : LI) {
1038+
for (Loop *SL : Top->getSubLoops())
1039+
Stack.push_back(SL);
1040+
while (!Stack.empty()) {
1041+
Loop *Cur = Stack.pop_back_val();
1042+
if (!Seen.insert(Cur).second)
1043+
continue;
1044+
if (hasMultiplyAccumulatePattern(Cur, SE))
1045+
return true;
1046+
for (Loop *Child : Cur->getSubLoops())
1047+
Stack.push_back(Child);
9971048
}
9981049
}
9991050

llvm/lib/Target/RISCV/RISCVDotprodSplitter.h

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77
//===----------------------------------------------------------------------===//
88
//
99
// This file declares the RISCVDotprodSplitterPass class.
10-
// This pass identifies a specific pattern often associated with calls to inner
11-
// dot product computation functions, where the result is passed via a pointer
12-
// argument (typically an alloca in the caller). The pattern involves a
13-
// sequence of lifetime start, the call instruction, a load from the result
14-
// pointer, and lifetime end, all within the same basic block.
10+
// This pass identifies calls to inner dot-product helpers that pass the
11+
// running accumulator through a pointer (usually an entry-block alloca). The
12+
// expected caller shape is: optional @llvm.lifetime.start, the call, a single
13+
// reload from that slot, optional @llvm.lifetime.end, in one basic block; if
14+
// the reload was sunk to the unique successor, the pass can hoist it back.
1515
//
1616
// If this unique pattern is found, the pass restructures the control flow
1717
// graph (CFG) to create specialized paths for common constant "step" or
@@ -57,16 +57,24 @@ struct RISCVDotprodSplitterPass
5757

5858
static bool isRequired() { return true; }
5959

60-
/// Check if the function contains patterns that can be processed by this
61-
/// pass.
60+
/// True if the function has a nested loop that looks like a dot-product inner
61+
/// loop (legacy two-load \c mul, offset variant, or Clang-style MAC with
62+
/// multiple affine loads). Used by the conditional loop extractor gate and
63+
/// related tooling.
6264
static bool hasProcessablePattern(Function &F);
6365
};
6466

6567
/// Conditional LoopExtractor Pass that only runs when dotprod patterns exist.
6668
///
67-
/// This pass runs LoopExtractor only on modules that contain processable
68-
/// dot product patterns, avoiding unnecessary loop extraction on modules
69-
/// that won't benefit from the dotprod splitter optimization.
69+
/// This pass runs selective CodeExtractor-based loop extraction only on
70+
/// modules that contain processable dot product patterns (same heuristic as
71+
/// \c RISCVDotprodSplitterPass::hasProcessablePattern), avoiding work on other
72+
/// modules. Enable with the same \c -riscv-dotprod-splitter flag as the
73+
/// splitter pass.
74+
///
75+
/// Pipeline name (RISC-V target): \c riscv-dotprod-conditional-loop-extractor
76+
/// (module pass). Run before \c riscv-dotprod-splitter when extracting inner
77+
/// loops from Clang output is required.
7078
struct RISCVConditionalLoopExtractorPass
7179
: public PassInfoMixin<RISCVConditionalLoopExtractorPass> {
7280

llvm/lib/Target/RISCV/RISCVTargetMachine.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -719,7 +719,15 @@ void RISCVTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
719719
}
720720
return false;
721721
});
722-
722+
PB.registerPipelineParsingCallback(
723+
[](StringRef Name, ModulePassManager &MPM,
724+
ArrayRef<PassBuilder::PipelineElement>) {
725+
if (Name == "riscv-dotprod-conditional-loop-extractor") {
726+
MPM.addPass(RISCVConditionalLoopExtractorPass());
727+
return true;
728+
}
729+
return false;
730+
});
723731
PB.registerOptimizerLastEPCallback([](ModulePassManager &PM,
724732
OptimizationLevel Level,
725733
ThinOrFullLTOPhase Phase) {

0 commit comments

Comments
 (0)