Skip to content

Commit 85dd208

Browse files
authored
[HAL] Add pass pipeline caching to executable translation and configuration (#23643)
Adds a PipelineCache utility (in Utils/PassUtils.h) that caches constructed OpPassManagers keyed by ExecutableTargetAttr. When multiple executable variants share the same target attribute, the pass pipeline only needs to be constructed once rather than being rebuilt from scratch for every variant. The cache is shared across clones of the outer per-ExecutableOp pass via shared_ptr, so MLIR's parallel pass execution across different ExecutableOps all benefit from the same cached pipelines. getOrCreate() returns a deep copy of the cached pipeline rather than a reference because MLIR passes carry mutable state (analysis caches, statistics) that is modified during execution. Since the outer per-ExecutableOp passes run in parallel, two threads processing different executables with the same target attribute would race on a shared OpPassManager. The copy cost is negligible compared to pipeline execution; the savings come from avoiding redundant pipeline construction (registry lookups, dynamic pass creation) for every variant. A quick local benchmark with varying dispatch counts shows substantial savings: ``` ┌────────────────┬──────────┬───────────┬────────┐ │ Input │ Ref (ms) │ Feat (ms) │ Delta │ ├────────────────┼──────────┼───────────┼────────┤ │ 50 dispatches │ 869 │ 524 │ -39.7% │ ├────────────────┼──────────┼───────────┼────────┤ │ 200 dispatches │ 1,084 │ 725 │ -33.1% │ ├────────────────┼──────────┼───────────┼────────┤ │ 500 dispatches │ 2,129 │ 1,684 │ -20.9% │ ├────────────────┼──────────┼───────────┼────────┤ │ TOTAL │ 4,082 │ 2,933 │ -28.1% │ └────────────────┴──────────┴───────────┴────────┘ ``` One last note, in theory we could skip this caching altogether by forcing the target backend to construct the pass pipeline on initialization and return the OpPassManager directly rather than constructing it on the fly, but that's a more substantial refactor.
1 parent 61e8c44 commit 85dd208

3 files changed

Lines changed: 109 additions & 10 deletions

File tree

compiler/src/iree/compiler/Dialect/HAL/Transforms/ConfigureExecutables.cpp

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
#include "iree/compiler/Dialect/HAL/Target/TargetBackend.h"
1313
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
1414
#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
15+
#include "iree/compiler/Utils/PassUtils.h"
1516
#include "iree/compiler/Utils/TracingUtils.h"
1617
#include "llvm/ADT/StringSet.h"
17-
#include "mlir/IR/Attributes.h"
1818
#include "mlir/IR/Builders.h"
1919
#include "mlir/IR/Diagnostics.h"
2020
#include "mlir/Pass/Pass.h"
@@ -39,6 +39,18 @@ class ConfigureTargetExecutableVariantsPass
3939
ConfigureTargetExecutableVariantsPass>::
4040
ConfigureTargetExecutableVariantsPassBase;
4141

42+
public:
43+
// Constructor that also accepts a shared pipeline cache.
44+
ConfigureTargetExecutableVariantsPass(
45+
ConfigureTargetExecutableVariantsPassOptions options,
46+
std::shared_ptr<PipelineCache> cache)
47+
: ConfigureTargetExecutableVariantsPassBase(std::move(options)),
48+
pipelineCache(std::move(cache)) {}
49+
50+
private:
51+
// Shared across clones of this pass for thread-safe pipeline caching.
52+
std::shared_ptr<PipelineCache> pipelineCache;
53+
4254
void getDependentDialects(DialectRegistry &registry) const override {
4355
registry.insert<IREE::HAL::HALDialect>();
4456
auto targetBackend = targetRegistry->getTargetBackend(target);
@@ -59,9 +71,20 @@ class ConfigureTargetExecutableVariantsPass
5971
return signalPassFailure();
6072
}
6173

74+
// Build or retrieve the cached pass pipeline for this target attribute.
75+
// When many executables share the same target, this avoids redundantly
76+
// reconstructing the same pipeline for each one.
77+
IREE::HAL::ExecutableTargetAttr targetAttr = variantOp.getTargetAttr();
6278
OpPassManager passManager(variantOp.getOperationName());
63-
targetBackend->buildConfigurationPassPipeline(variantOp.getTargetAttr(),
64-
passManager);
79+
if (pipelineCache) {
80+
passManager = pipelineCache->getOrCreate(
81+
targetAttr, variantOp.getOperationName(), [&](OpPassManager &pm) {
82+
targetBackend->buildConfigurationPassPipeline(targetAttr, pm);
83+
});
84+
} else {
85+
// Fallback for standalone pass usage (e.g., iree-opt).
86+
targetBackend->buildConfigurationPassPipeline(targetAttr, passManager);
87+
}
6588

6689
// This pipeline is optional, and the default is no passes, in which case
6790
// nothing is needed.
@@ -88,6 +111,13 @@ struct ConfigureExecutablesPass
88111
using IREE::HAL::impl::ConfigureExecutablesPassBase<
89112
ConfigureExecutablesPass>::ConfigureExecutablesPassBase;
90113

114+
// Shared across all clones of this pass for thread-safe pipeline caching.
115+
// When MLIR clones this pass for parallel execution on different
116+
// ExecutableOps, the shared_ptr is copied so all clones share the same
117+
// cache.
118+
std::shared_ptr<PipelineCache> pipelineCache =
119+
std::make_shared<PipelineCache>();
120+
91121
void getDependentDialects(DialectRegistry &registry) const override {
92122
registry.insert<IREE::HAL::HALDialect>();
93123
auto targetBackends = targetRegistry->getTargetBackends(
@@ -102,8 +132,10 @@ struct ConfigureExecutablesPass
102132
OpPassManager passManager(executableOp.getOperationName());
103133
for (const auto &targetName : gatherExecutableTargetNames(executableOp)) {
104134
passManager.addNestedPass<IREE::HAL::ExecutableVariantOp>(
105-
IREE::HAL::createConfigureTargetExecutableVariantsPass(
106-
{targetRegistry, targetName}));
135+
std::make_unique<ConfigureTargetExecutableVariantsPass>(
136+
ConfigureTargetExecutableVariantsPassOptions{targetRegistry,
137+
targetName},
138+
pipelineCache));
107139
}
108140

109141
IREE_COMPILER_TRACE_MESSAGE_DYNAMIC(INFO, executableOp.getSymName().str());

compiler/src/iree/compiler/Dialect/HAL/Transforms/TranslateExecutables.cpp

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212
#include "iree/compiler/Dialect/HAL/Target/TargetBackend.h"
1313
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
1414
#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
15+
#include "iree/compiler/Utils/PassUtils.h"
1516
#include "iree/compiler/Utils/TracingUtils.h"
1617
#include "llvm/ADT/StringSet.h"
1718
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
18-
#include "mlir/IR/Attributes.h"
1919
#include "mlir/IR/Builders.h"
2020
#include "mlir/IR/Diagnostics.h"
2121
#include "mlir/Pass/Pass.h"
@@ -40,6 +40,16 @@ struct TranslateTargetExecutableVariantsPass
4040
TranslateTargetExecutableVariantsPass>::
4141
TranslateTargetExecutableVariantsPassBase;
4242

43+
// Constructor that also accepts a shared pipeline cache.
44+
TranslateTargetExecutableVariantsPass(
45+
TranslateTargetExecutableVariantsPassOptions options,
46+
std::shared_ptr<PipelineCache> cache)
47+
: TranslateTargetExecutableVariantsPassBase(std::move(options)),
48+
pipelineCache(std::move(cache)) {}
49+
50+
// Shared across clones of this pass for thread-safe pipeline caching.
51+
std::shared_ptr<PipelineCache> pipelineCache;
52+
4353
void getDependentDialects(DialectRegistry &registry) const override {
4454
registry.insert<IREE::HAL::HALDialect>();
4555
registry.insert<bufferization::BufferizationDialect>();
@@ -65,9 +75,21 @@ struct TranslateTargetExecutableVariantsPass
6575
return signalPassFailure();
6676
}
6777

78+
// Build or retrieve the cached pass pipeline for this target attribute.
79+
// When many executables share the same target, this avoids redundantly
80+
// reconstructing the same pipeline for each one.
81+
IREE::HAL::ExecutableTargetAttr targetAttr = variantOp.getTargetAttr();
6882
OpPassManager passManager(variantOp.getOperationName());
69-
targetBackend->buildTranslationPassPipeline(variantOp.getTargetAttr(),
70-
passManager);
83+
if (pipelineCache) {
84+
passManager = pipelineCache->getOrCreate(
85+
targetAttr, variantOp.getOperationName(), [&](OpPassManager &pm) {
86+
targetBackend->buildTranslationPassPipeline(targetAttr, pm);
87+
});
88+
} else {
89+
// Fallback for standalone pass usage (e.g., iree-opt).
90+
targetBackend->buildTranslationPassPipeline(targetAttr, passManager);
91+
}
92+
7193
if (failed(runPipeline(passManager, variantOp))) {
7294
emitError(variantOp->getLoc())
7395
<< "failed to run translation of source executable to target "
@@ -88,6 +110,13 @@ struct TranslateAllExecutablesPass
88110
using IREE::HAL::impl::TranslateAllExecutablesPassBase<
89111
TranslateAllExecutablesPass>::TranslateAllExecutablesPassBase;
90112

113+
// Shared across all clones of this pass for thread-safe pipeline caching.
114+
// When MLIR clones this pass for parallel execution on different
115+
// ExecutableOps, the shared_ptr is copied so all clones share the same
116+
// cache.
117+
std::shared_ptr<PipelineCache> pipelineCache =
118+
std::make_shared<PipelineCache>();
119+
91120
void getDependentDialects(DialectRegistry &registry) const override {
92121
registry.insert<IREE::HAL::HALDialect>();
93122
registry.insert<bufferization::BufferizationDialect>();
@@ -103,8 +132,10 @@ struct TranslateAllExecutablesPass
103132
OpPassManager passManager(executableOp.getOperationName());
104133
for (const auto &targetName : gatherExecutableTargetNames(executableOp)) {
105134
passManager.addNestedPass<IREE::HAL::ExecutableVariantOp>(
106-
IREE::HAL::createTranslateTargetExecutableVariantsPass(
107-
{targetRegistry, targetName}));
135+
std::make_unique<TranslateTargetExecutableVariantsPass>(
136+
TranslateTargetExecutableVariantsPassOptions{targetRegistry,
137+
targetName},
138+
pipelineCache));
108139
}
109140

110141
IREE_COMPILER_TRACE_MESSAGE_DYNAMIC(INFO, executableOp.getSymName().str());

compiler/src/iree/compiler/Utils/PassUtils.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,48 @@
88
#define IREE_COMPILER_UTILS_PASSUTILS_H_
99

1010
#include <array>
11+
#include <memory>
12+
#include <mutex>
1113

14+
#include "llvm/ADT/DenseMap.h"
15+
#include "mlir/IR/Attributes.h"
1216
#include "mlir/Pass/Pass.h"
1317
#include "mlir/Pass/PassManager.h"
1418

1519
namespace mlir::iree_compiler {
1620

21+
// Thread-safe cache for compiled pass pipelines keyed by target attribute.
22+
// When multiple executable variants share the same target attribute, the pass
23+
// pipeline only needs to be constructed once. The cache is shared across clones
24+
// of the outer pass that MLIR creates for parallel execution on different
25+
// ExecutableOps via a shared_ptr.
26+
//
27+
// getOrCreate() returns a deep copy of the cached pipeline rather than a
28+
// reference because MLIR passes carry mutable state (analysis caches,
29+
// statistics) that is modified during execution. The outer per-ExecutableOp
30+
// passes run in parallel, so two threads processing different executables with
31+
// the same target attribute would race on a shared OpPassManager. The copy cost
32+
// is negligible compared to pipeline execution; the savings come from avoiding
33+
// redundant pipeline construction (registry lookups, dynamic pass creation) for
34+
// every variant.
35+
struct PipelineCache {
36+
std::mutex mutex;
37+
llvm::DenseMap<Attribute, std::unique_ptr<OpPassManager>> entries;
38+
39+
// Returns a deep copy of the cached pipeline for |targetAttr|, building it
40+
// on first access using |builder|. Thread-safe.
41+
OpPassManager getOrCreate(Attribute targetAttr, StringRef operationName,
42+
llvm::function_ref<void(OpPassManager &)> builder) {
43+
std::lock_guard<std::mutex> lock(mutex);
44+
auto &entry = entries[targetAttr];
45+
if (!entry) {
46+
entry = std::make_unique<OpPassManager>(operationName);
47+
builder(*entry);
48+
}
49+
return OpPassManager(*entry);
50+
}
51+
};
52+
1753
/// Constructs a pipeline of passes across multiple nested op types.
1854
///
1955
/// Usage:

0 commit comments

Comments
 (0)