Skip to content

Commit 8b7f763

Browse files
committed
[MLIR][mlir-link] Add COMDAT resolution to LLVM dialect linker.
1 parent 0d4dab7 commit 8b7f763

File tree

14 files changed

+371
-115
lines changed

14 files changed

+371
-115
lines changed

clang/include/clang/CIR/Interfaces/CIRLinkerInterface.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,13 @@ class CIRSymbolLinkerInterface
4545

4646
static bool isComdat(Operation *op);
4747

48-
static std::optional<mlir::link::ComdatSelector>
49-
getComdatSelector(Operation *op);
48+
static bool hasComdat(Operation *op);
49+
50+
static const link::Comdat *getComdatResolution(Operation *op);
51+
52+
static bool selectedByComdat(Operation *op);
53+
54+
static void updateNoDeduplicate(Operation *op);
5055

5156
static Visibility getVisibility(Operation *op);
5257

clang/lib/CIR/Dialect/IR/CIRLinkerInterface.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,12 +99,23 @@ bool CIRSymbolLinkerInterface::isComdat(Operation *op) {
9999
return false;
100100
}
101101

102-
std::optional<link::ComdatSelector>
103-
CIRSymbolLinkerInterface::getComdatSelector(Operation *op) {
104-
// TODO(frabert): Extracting comdat info from CIR is not implemented yet
105-
return std::nullopt;
102+
bool CIRSymbolLinkerInterface::hasComdat(Operation *op) {
103+
// TODO: Extracting comdat info from CIR is not implemented yet
104+
return false;
105+
}
106+
107+
const link::Comdat *
108+
CIRSymbolLinkerInterface::getComdatResolution(Operation *op) {
109+
return nullptr;
106110
}
107111

112+
bool CIRSymbolLinkerInterface::selectedByComdat(Operation *op) {
113+
// TODO: Extracting comdat info from CIR is not implemented yet
114+
llvm_unreachable("comdat resolution not implemented for CIR");
115+
}
116+
117+
void CIRSymbolLinkerInterface::updateNoDeduplicate(Operation *op) {}
118+
108119
Visibility CIRSymbolLinkerInterface::getVisibility(Operation *op) {
109120
if (auto gv = dyn_cast<GlobalOp>(op))
110121
return toLLVMVisibility(gv.getGlobalVisibility());

mlir/include/mlir/Dialect/LLVMIR/LLVMLinkerInterface.h

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ class LLVMSymbolLinkerInterface
1616
static Visibility getVisibility(Operation *op);
1717
static void setVisibility(Operation *op, Visibility visibility);
1818
static bool isComdat(Operation *op);
19-
static std::optional<link::ComdatSelector> getComdatSelector(Operation *op);
19+
static bool hasComdat(Operation *op);
20+
static SymbolRefAttr getComdatSymbol(Operation *op);
21+
static LLVM::comdat::Comdat getComdatSelectionKind(Operation *op);
2022
static bool isDeclaration(Operation *op);
2123
static unsigned getBitWidth(Operation *op);
2224
static UnnamedAddr getUnnamedAddr(Operation *op);
@@ -34,7 +36,19 @@ class LLVMSymbolLinkerInterface
3436
dependencies(Operation *op, SymbolTableCollection &collection) const override;
3537
LogicalResult initialize(ModuleOp src) override;
3638
LogicalResult finalize(ModuleOp dst) const override;
39+
LogicalResult moduleOpSummary(ModuleOp src,
40+
SymbolTableCollection &collection) override;
3741
Operation *appendGlobals(llvm::StringRef glob, link::LinkState &state);
42+
Operation *appendComdatOps(ArrayRef<Operation *> globs, LLVM::ComdatOp comdat,
43+
link::LinkState &state);
44+
link::ComdatResolution
45+
computeComdatResolution(Operation *, SymbolTableCollection &, link::Comdat *);
46+
LogicalResult resolveComdats(ModuleOp srcMod,
47+
SymbolTableCollection &collection);
48+
const link::Comdat *getComdatResolution(Operation *op) const;
49+
bool selectedByComdat(Operation *op) const;
50+
void dropReplacedComdat(Operation *op) const;
51+
static void updateNoDeduplicate(Operation *op);
3852

3953
template <typename structor_t>
4054
Operation *appendGlobalStructors(link::LinkState &state) {
@@ -79,22 +93,27 @@ class LLVMSymbolLinkerInterface
7993
ArrayRef<Attribute> priorities = structor.getPriorities().getValue();
8094
ArrayRef<Attribute> data = structor.getData().getValue();
8195

82-
for (auto [idx, dataAttr] : llvm::enumerate(data)) {
96+
for (auto [idx, structor] : llvm::enumerate(structorList)) {
97+
auto structorSymbol = cast<FlatSymbolRefAttr>(structor);
98+
// Skip constructors not included based on COMDAT
99+
if (!summary.contains(structorSymbol.getValue()))
100+
continue;
101+
102+
auto dataAttr = data[idx];
83103
// data value is either #llvm.zero or symbol ref
84104
// if it is zero, we always have to include the value
85105
// if it is a symbol ref, we have to check if the symbol
86106
// from the same module is being used
87-
//
88107
if (auto globalSymbol = dyn_cast<FlatSymbolRefAttr>(dataAttr)) {
89-
auto globalOp = summary.lookup(globalSymbol.getValue());
108+
Operation *globalOp = summary.lookup(globalSymbol.getValue());
90109
assert(globalOp && "structor referenced global not in summary?");
91110
// globals are definde at module level
92111
if (globalOp->getParentOp() != op->getParentOp())
93112
continue;
94113
}
95114

96115
newData.push_back(dataAttr);
97-
newStructorList.push_back(structorList[idx]);
116+
newStructorList.push_back(structor);
98117
newPriorities.push_back(priorities[idx]);
99118
}
100119
}
@@ -125,6 +144,7 @@ class LLVMSymbolLinkerInterface
125144
private:
126145
DataLayoutSpecInterface dtla = {};
127146
TargetSystemSpecInterface targetSys = {};
147+
llvm::StringMap<link::Comdat> comdatResolution;
128148
};
129149

130150
} // namespace LLVM

mlir/include/mlir/Linker/LLVMLinkerMixin.h

Lines changed: 38 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,17 @@ static UnnamedAddr getMinUnnamedAddr(UnnamedAddr lhs, UnnamedAddr rhs) {
139139

140140
using ComdatKind = LLVM::comdat::Comdat;
141141

142-
struct ComdatSelector {
143-
StringRef name;
142+
struct Comdat {
144143
ComdatKind kind;
144+
Operation *selectorOp;
145+
llvm::SmallPtrSet<Operation *, 2> users;
146+
};
147+
148+
enum class ComdatResolution {
149+
LinkFromSrc,
150+
LinkFromDst,
151+
LinkFromBoth,
152+
Failure,
145153
};
146154

147155
//===----------------------------------------------------------------------===//
@@ -172,6 +180,20 @@ class LLVMLinkerMixin {
172180
if (derived.isComdat(pair.src))
173181
return true;
174182

183+
// Thrown away symbol can affect the visibility
184+
if (pair.dst) {
185+
Visibility srcVisibility = derived.getVisibility(pair.src);
186+
Visibility dstVisibility = derived.getVisibility(pair.dst);
187+
Visibility visibility = getMinVisibility(srcVisibility, dstVisibility);
188+
189+
derived.setVisibility(pair.src, visibility);
190+
derived.setVisibility(pair.dst, visibility);
191+
}
192+
if (derived.hasComdat(pair.src)) {
193+
// operations with COMDAT are selected as a group
194+
return derived.selectedByComdat(pair.src);
195+
}
196+
175197
Linkage srcLinkage = derived.getLinkage(pair.src);
176198

177199
// Always import variables with appending linkage.
@@ -218,6 +240,10 @@ class LLVMLinkerMixin {
218240
return pair.src->emitError(error) << " dst: " << pair.dst->getLoc();
219241
};
220242

243+
if (derived.isComdat(pair.src) != derived.isComdat(pair.dst)) {
244+
return linkError("Linking ComdatOp with non-comdat op");
245+
}
246+
221247
Linkage srcLinkage = derived.getLinkage(pair.src);
222248
Linkage dstLinkage = derived.getLinkage(pair.dst);
223249

@@ -340,6 +366,11 @@ class LLVMLinkerMixin {
340366
if (isWeakForLinker(srcLinkage)) {
341367
assert(!isExternalWeakLinkage(dstLinkage));
342368
assert(!isAvailableExternallyLinkage(dstLinkage));
369+
const Comdat *comdat = derived.getComdatResolution(pair.src);
370+
if (comdat && comdat->kind == ComdatKind::NoDeduplicate) {
371+
derived.updateNoDeduplicate(pair.src);
372+
return ConflictResolution::LinkFromBothAndRenameSrc;
373+
}
343374
if (isLinkOnceLinkage(dstLinkage) && isWeakLinkage(srcLinkage)) {
344375
return ConflictResolution::LinkFromSrc;
345376
}
@@ -349,38 +380,12 @@ class LLVMLinkerMixin {
349380

350381
if (isWeakForLinker(dstLinkage)) {
351382
assert(isExternalLinkage(srcLinkage));
352-
return ConflictResolution::LinkFromSrc;
353-
}
354-
355-
std::optional<ComdatSelector> srcComdatSel =
356-
derived.getComdatSelector(pair.src);
357-
std::optional<ComdatSelector> dstComdatSel =
358-
derived.getComdatSelector(pair.dst);
359-
if (srcComdatSel.has_value() && dstComdatSel.has_value()) {
360-
auto srcComdatName = srcComdatSel->name;
361-
auto dstComdatName = dstComdatSel->name;
362-
auto srcComdat = srcComdatSel->kind;
363-
auto dstComdat = dstComdatSel->kind;
364-
if (srcComdatName != dstComdatName) {
365-
llvm_unreachable("Comdat selector names don't match");
366-
}
367-
if (srcComdat != dstComdat) {
368-
llvm_unreachable("Comdat selector kinds don't match");
369-
}
370-
371-
if (srcComdat == mlir::LLVM::comdat::Comdat::Any) {
372-
return ConflictResolution::LinkFromDst;
373-
}
374-
if (srcComdat == mlir::LLVM::comdat::Comdat::NoDeduplicate) {
375-
return ConflictResolution::Failure;
383+
const Comdat *comdat = derived.getComdatResolution(pair.dst);
384+
if (comdat && comdat->kind == ComdatKind::NoDeduplicate) {
385+
derived.updateNoDeduplicate(pair.dst);
386+
return ConflictResolution::LinkFromBothAndRenameDst;
376387
}
377-
if (srcComdat == mlir::LLVM::comdat::Comdat::ExactMatch) {
378-
return ConflictResolution::LinkFromDst;
379-
}
380-
if (srcComdat == mlir::LLVM::comdat::Comdat::SameSize) {
381-
return ConflictResolution::LinkFromDst;
382-
}
383-
llvm_unreachable("unimplemented comdat kind");
388+
return ConflictResolution::LinkFromSrc;
384389
}
385390

386391
// If we reach here, we have two external definitions that can't be resolved

mlir/include/mlir/Linker/LinkerInterface.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ class LinkState {
6060
return builder.create<Op>(location, std::forward<Args>(args)...);
6161
}
6262

63+
OpBuilder &getBuilder() { return builder; };
64+
6365
private:
6466
// Private constructor used by nest()
6567
LinkState(ModuleOp dst, std::shared_ptr<IRMapping> mapping)
@@ -145,7 +147,8 @@ class SymbolLinkerInterface : public LinkerInterface<SymbolLinkerInterface> {
145147

146148
/// Perform tasks that need to be computed on whole-module basis before actual summary.
147149
/// E.g. Pre-compute COMDAT resolution before actually linking the modules.
148-
virtual LogicalResult moduleOpSummary(ModuleOp module) {
150+
virtual LogicalResult moduleOpSummary(ModuleOp module,
151+
SymbolTableCollection &collection) {
149152
return success();
150153
}
151154

@@ -282,9 +285,10 @@ class SymbolLinkerInterfaces {
282285
return Conflict::noConflict(src);
283286
}
284287

285-
LogicalResult moduleOpSummary(ModuleOp src) {
288+
LogicalResult moduleOpSummary(ModuleOp src,
289+
SymbolTableCollection &collection) {
286290
for (SymbolLinkerInterface *linker : interfaces) {
287-
if (failed(linker->moduleOpSummary(src)))
291+
if (failed(linker->moduleOpSummary(src, collection)))
288292
return failure();
289293
}
290294
return success();

0 commit comments

Comments
 (0)