Skip to content

Commit e08816f

Browse files
committed
[MLIR][mlir-link] Make linker interfaces multithreaded
1 parent 0002fa4 commit e08816f

File tree

8 files changed

+146
-64
lines changed

8 files changed

+146
-64
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMLinkerInterface.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class LLVMSymbolLinkerInterface
2929
static llvm::StringRef getSection(Operation *op);
3030
static uint32_t getAddressSpace(Operation *op);
3131
StringRef getSymbol(Operation *op) const override;
32-
Operation *materialize(Operation *src, link::LinkState &state) const override;
32+
Operation *materialize(Operation *src, link::LinkState &state) override;
3333
SmallVector<Operation *>
3434
dependencies(Operation *op, SymbolTableCollection &collection) const override;
3535
LogicalResult initialize(ModuleOp src) override;

mlir/include/mlir/Linker/LLVMLinkerMixin.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,10 @@ class LLVMLinkerMixin {
155155
return static_cast<const DerivedLinkerInterface &>(*this);
156156
}
157157

158+
DerivedLinkerInterface &getDerived() {
159+
return static_cast<DerivedLinkerInterface &>(*this);
160+
}
161+
158162
public:
159163
bool isDeclarationForLinker(Operation *op) const {
160164
const DerivedLinkerInterface &derived = getDerived();

mlir/include/mlir/Linker/Linker.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/IR/BuiltinOps.h"
1313

1414
#include "mlir/Linker/LinkerInterface.h"
15+
#include <mutex>
1516
namespace mlir::link {
1617

1718
/// These are gathered alphabetically sorted linker options
@@ -101,6 +102,9 @@ class Linker {
101102

102103
/// Modules registry used if `keepModulesAlive` is true
103104
std::vector<OwningOpRef<ModuleOp>> modules;
105+
106+
/// Mutex to protect modules vector and initialization during parallel addModule calls
107+
std::mutex linkerMutex;
104108
};
105109

106110
} // namespace mlir::link

mlir/include/mlir/Linker/LinkerInterface.h

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "llvm/ADT/DenseMap.h"
2222
#include "llvm/Support/Error.h"
2323
#include <memory>
24+
#include <mutex>
2425

2526
namespace mlir::link {
2627

@@ -37,8 +38,9 @@ enum LinkerFlags {
3738
class LinkState {
3839
public:
3940
LinkState(ModuleOp dst, mlir::SymbolTableCollection &symbolTableCollection)
40-
: mapping(std::make_shared<IRMapping>()), builder(dst.getBodyRegion()),
41-
symbolTableCollection(symbolTableCollection), moduleMaps() {}
41+
: mapping(std::make_shared<IRMapping>()), mutex(std::make_shared<std::mutex>()),
42+
builder(dst.getBodyRegion()), symbolTableCollection(symbolTableCollection),
43+
moduleMaps() {}
4244

4345
Operation *clone(Operation *src);
4446
Operation *cloneWithoutRegions(Operation *src);
@@ -49,7 +51,7 @@ class LinkState {
4951

5052
LinkState nest(ModuleOp submod) const;
5153

52-
IRMapping &getMapping();
54+
std::pair<IRMapping &, std::mutex &> getMapping();
5355
SymbolTableCollection &getSymbolTableCollection() {
5456
return symbolTableCollection;
5557
}
@@ -63,11 +65,14 @@ class LinkState {
6365
private:
6466
// Private constructor used by nest()
6567
LinkState(ModuleOp dst, std::shared_ptr<IRMapping> mapping,
68+
std::shared_ptr<std::mutex> mutex,
6669
SymbolTableCollection &symbolTableCollection)
67-
: mapping(std::move(mapping)), builder(dst.getBodyRegion()),
70+
: mapping(std::move(mapping)), mutex(std::move(mutex)),
71+
builder(dst.getBodyRegion()),
6872
symbolTableCollection(symbolTableCollection), moduleMaps() {}
6973

7074
std::shared_ptr<IRMapping> mapping;
75+
std::shared_ptr<std::mutex> mutex;
7176
OpBuilder builder;
7277
SymbolTableCollection &symbolTableCollection;
7378
DenseMap<ModuleOp, SymbolUserMap> moduleMaps;
@@ -95,7 +100,7 @@ class LinkerInterface : public DialectInterface::Base<ConcreteType> {
95100
virtual LogicalResult finalize(ModuleOp dst) const { return success(); }
96101

97102
/// Link operations from current summary using state builder
98-
virtual LogicalResult link(LinkState &state) const = 0;
103+
virtual LogicalResult link(LinkState &state) = 0;
99104
};
100105

101106
//===----------------------------------------------------------------------===//
@@ -140,7 +145,7 @@ class SymbolLinkerInterface : public LinkerInterface<SymbolLinkerInterface> {
140145
virtual void registerForLink(Operation *op, SymbolTableCollection &collection) = 0;
141146

142147
/// Materialize new operation for the given conflict src operation.
143-
virtual Operation *materialize(Operation *src, LinkState &state) const {
148+
virtual Operation *materialize(Operation *src, LinkState &state) {
144149
return state.clone(src);
145150
}
146151

@@ -178,7 +183,7 @@ class SymbolAttrLinkerInterface : public SymbolLinkerInterface {
178183
using SymbolLinkerInterface::SymbolLinkerInterface;
179184

180185
/// Link operations from current summary using state builder
181-
LogicalResult link(LinkState &state) const override;
186+
LogicalResult link(LinkState &state) override;
182187

183188
/// Returns the symbol for the given operation.
184189
StringRef getSymbol(Operation *op) const override;
@@ -211,6 +216,9 @@ class SymbolAttrLinkerInterface : public SymbolLinkerInterface {
211216

212217
// Operations that are to be linked with unique names.
213218
SetVector<Operation *> uniqued;
219+
220+
// Mutex to protect summary and uniqued during parallel summarization.
221+
mutable std::mutex summaryMutex;
214222
};
215223

216224
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/LLVMIR/IR/LLVMLinkerInterface.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -246,8 +246,8 @@ StringRef LLVM::LLVMSymbolLinkerInterface::getSymbol(Operation *op) const {
246246

247247
Operation *
248248
LLVM::LLVMSymbolLinkerInterface::materialize(Operation *src,
249-
LinkState &state) const {
250-
auto derived = LinkerMixin::getDerived();
249+
LinkState &state) {
250+
auto &derived = LinkerMixin::getDerived();
251251
// empty append means that we either have single module or that something went
252252
// wrong
253253
if (isAppendingLinkage(derived.getLinkage(src)) && !append.empty()) {
@@ -427,7 +427,7 @@ getAppendedOpWithInitRegion(llvm::ArrayRef<mlir::Operation *> globs,
427427
auto elemType = originalType.getElementType();
428428
size_t elemCount = 0;
429429

430-
IRMapping &mapping = state.getMapping();
430+
auto [mapping, mutex] = state.getMapping();
431431
auto builder = OpBuilder(targetRegion);
432432
std::vector<Value> values;
433433
std::vector<std::vector<int64_t>> positions;
@@ -442,8 +442,12 @@ getAppendedOpWithInitRegion(llvm::ArrayRef<mlir::Operation *> globs,
442442
if (isa<LLVM::UndefOp, LLVM::ReturnOp, LLVM::InsertValueOp>(op))
443443
continue;
444444

445-
Operation *cloned = builder.clone(op, mapping);
446-
mapping.map(&op, cloned);
445+
Operation *cloned;
446+
{
447+
std::lock_guard<std::mutex> lock(mutex);
448+
cloned = builder.clone(op, mapping);
449+
mapping.map(&op, cloned);
450+
}
447451
// LLVM dialect does not have multiple result operations
448452
// zero result operation should not appear in this context
449453
// unless its a return which we skip

mlir/lib/IR/BuiltinLinkerInterface.cpp

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "mlir/IR/BuiltinLinkerInterface.h"
1414
#include "mlir/IR/BuiltinDialect.h"
15+
#include "mlir/IR/Threading.h"
1516
#include "mlir/Linker/LinkerInterface.h"
1617

1718
using namespace mlir;
@@ -36,16 +37,18 @@ class BuiltinLinkerInterface : public ModuleLinkerInterface {
3637

3738
LogicalResult summarize(ModuleOp src, unsigned flags,
3839
SymbolTableCollection &collection) override {
39-
WalkResult result = src.walk([&](Operation *op) {
40-
if (op == src)
41-
return WalkResult::advance();
42-
43-
if (summarize(op, flags, /*forDependency=*/false, collection).failed())
44-
return WalkResult::interrupt();
45-
return WalkResult::advance();
40+
// Collect all operations to process in parallel
41+
SmallVector<Operation *> ops;
42+
src.walk([&](Operation *op) {
43+
if (op != src)
44+
ops.push_back(op);
4645
});
4746

48-
return failure(result.wasInterrupted());
47+
// Process operations in parallel
48+
return failableParallelForEach(
49+
src.getContext(), ops, [&](Operation *op) {
50+
return summarize(op, flags, /*forDependency=*/false, collection);
51+
});
4952
}
5053

5154
LogicalResult summarize(Operation *op, unsigned flags, bool forDependency,
@@ -70,15 +73,15 @@ class BuiltinLinkerInterface : public ModuleLinkerInterface {
7073
linker->registerForLink(op, collection);
7174
}
7275

73-
for (Operation *dep : linker->dependencies(op, symbolTableCollection)) {
74-
if (summarize(dep, flags, /*forDependency=*/true, collection).failed())
75-
return failure();
76-
}
76+
SmallVector<Operation *> deps = linker->dependencies(op, symbolTableCollection);
77+
auto res = failableParallelForEach(getContext(), deps, [&](Operation *dep) {
78+
return summarize(dep, flags, /*forDependency=*/true, collection);
79+
});
7780

78-
return success();
81+
return res;
7982
}
8083

81-
LogicalResult link(LinkState &state) const override {
84+
LogicalResult link(LinkState &state) override {
8285
return symbolLinkers.link(state);
8386
}
8487

mlir/lib/Linker/Linker.cpp

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,21 +53,26 @@ LogicalResult Linker::addModule(OwningOpRef<ModuleOp> src, bool onlyNeeded) {
5353
}
5454

5555
LogicalResult Linker::addModule(OwningOpRef<ModuleOp> src, unsigned flags) {
56-
ModuleOp mod = [&] {
56+
ModuleOp mod;
57+
58+
{
59+
std::lock_guard<std::mutex> lock(linkerMutex);
60+
5761
if (options.shouldKeepModulesAlive()) {
5862
modules.push_back(std::move(src));
59-
return modules.back().get();
63+
mod = modules.back().get();
64+
} else {
65+
mod = src.get();
6066
}
61-
return src.get();
62-
}();
6367

64-
// If this is the first module, setup the linker based on it
65-
if (!composite) {
66-
if (failed(initializeLinker(mod)))
67-
return failure();
68+
// If this is the first module, setup the linker based on it
69+
if (!composite) {
70+
if (failed(initializeLinker(mod)))
71+
return failure();
6872

69-
// We always override from source for the first module.
70-
flags &= LinkerFlags::OverrideFromSrc;
73+
// We always override from source for the first module.
74+
flags &= LinkerFlags::OverrideFromSrc;
75+
}
7176
}
7277

7378
return summarize(mod, flags);

0 commit comments

Comments
 (0)