Skip to content

Commit 3f6a88d

Browse files
committed
[WIP][MLIR][mlir-link] Add COMDAT resolution to LLVM dialect linker.
1 parent 0d4dab7 commit 3f6a88d

File tree

3 files changed

+108
-31
lines changed

3 files changed

+108
-31
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMLinkerInterface.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ class LLVMSymbolLinkerInterface
1717
static void setVisibility(Operation *op, Visibility visibility);
1818
static bool isComdat(Operation *op);
1919
static std::optional<link::ComdatSelector> getComdatSelector(Operation *op);
20+
static LLVM::comdat::Comdat getComdatSelectionKind(Operation *op);
2021
static bool isDeclaration(Operation *op);
2122
static unsigned getBitWidth(Operation *op);
2223
static UnnamedAddr getUnnamedAddr(Operation *op);
@@ -35,6 +36,9 @@ class LLVMSymbolLinkerInterface
3536
LogicalResult initialize(ModuleOp src) override;
3637
LogicalResult finalize(ModuleOp dst) const override;
3738
Operation *appendGlobals(llvm::StringRef glob, link::LinkState &state);
39+
LogicalResult resolveComdats(ModuleOp srcMod,
40+
SymbolTableCollection &collection);
41+
std::optional<link::ConflictResolution> getComdatResolution(Operation *);
3842

3943
template <typename structor_t>
4044
Operation *appendGlobalStructors(link::LinkState &state) {
@@ -125,6 +129,8 @@ class LLVMSymbolLinkerInterface
125129
private:
126130
DataLayoutSpecInterface dtla = {};
127131
TargetSystemSpecInterface targetSys = {};
132+
llvm::StringMap<std::pair<link::Comdat, link::ConflictResolution>>
133+
comdatResolution;
128134
};
129135

130136
} // namespace LLVM

mlir/include/mlir/Linker/LLVMLinkerMixin.h

Lines changed: 21 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,12 @@ struct ComdatSelector {
144144
ComdatKind kind;
145145
};
146146

147+
struct Comdat {
148+
ComdatKind kind;
149+
Operation *selectorOp;
150+
llvm::SmallPtrSet< Operation *, 2> users;
151+
};
152+
147153
//===----------------------------------------------------------------------===//
148154
// LLVMLinkerMixin
149155
//===----------------------------------------------------------------------===//
@@ -172,6 +178,12 @@ class LLVMLinkerMixin {
172178
if (derived.isComdat(pair.src))
173179
return true;
174180

181+
if (std::optional<link::ConflictResolution> res =
182+
derived.getComdatResolution(pair.src)) {
183+
// Comdats are either used or dropped as a group
184+
return res.value() == ConflictResolution::LinkFromSrc;
185+
}
186+
175187
Linkage srcLinkage = derived.getLinkage(pair.src);
176188

177189
// Always import variables with appending linkage.
@@ -218,6 +230,10 @@ class LLVMLinkerMixin {
218230
return pair.src->emitError(error) << " dst: " << pair.dst->getLoc();
219231
};
220232

233+
if (derived.isComdat(pair.src) != derived.isComdat(pair.dst)) {
234+
return linkError("Linking ComdatOp with non-comdat op");
235+
}
236+
221237
Linkage srcLinkage = derived.getLinkage(pair.src);
222238
Linkage dstLinkage = derived.getLinkage(pair.dst);
223239

@@ -259,6 +275,11 @@ class LLVMLinkerMixin {
259275
assert(derived.canBeLinked(pair.src) && "expected linkable operation");
260276
assert(derived.canBeLinked(pair.dst) && "expected linkable operation");
261277

278+
// We insert the computed comdat information into the dst module comdat op
279+
// Make sure it is linked in as we want
280+
if (derived.isComdat(pair.src))
281+
return ConflictResolution::LinkFromDst;
282+
262283
Linkage srcLinkage = derived.getLinkage(pair.src);
263284
Linkage dstLinkage = derived.getLinkage(pair.dst);
264285

@@ -352,37 +373,6 @@ class LLVMLinkerMixin {
352373
return ConflictResolution::LinkFromSrc;
353374
}
354375

355-
std::optional<ComdatSelector> srcComdatSel =
356-
derived.getComdatSelector(pair.src);
357-
std::optional<ComdatSelector> dstComdatSel =
358-
derived.getComdatSelector(pair.dst);
359-
if (srcComdatSel.has_value() && dstComdatSel.has_value()) {
360-
auto srcComdatName = srcComdatSel->name;
361-
auto dstComdatName = dstComdatSel->name;
362-
auto srcComdat = srcComdatSel->kind;
363-
auto dstComdat = dstComdatSel->kind;
364-
if (srcComdatName != dstComdatName) {
365-
llvm_unreachable("Comdat selector names don't match");
366-
}
367-
if (srcComdat != dstComdat) {
368-
llvm_unreachable("Comdat selector kinds don't match");
369-
}
370-
371-
if (srcComdat == mlir::LLVM::comdat::Comdat::Any) {
372-
return ConflictResolution::LinkFromDst;
373-
}
374-
if (srcComdat == mlir::LLVM::comdat::Comdat::NoDeduplicate) {
375-
return ConflictResolution::Failure;
376-
}
377-
if (srcComdat == mlir::LLVM::comdat::Comdat::ExactMatch) {
378-
return ConflictResolution::LinkFromDst;
379-
}
380-
if (srcComdat == mlir::LLVM::comdat::Comdat::SameSize) {
381-
return ConflictResolution::LinkFromDst;
382-
}
383-
llvm_unreachable("unimplemented comdat kind");
384-
}
385-
386376
// If we reach here, we have two external definitions that can't be resolved
387377
// This is typically an error case in LLVM linking
388378
if (isExternalLinkage(srcLinkage) && isExternalLinkage(dstLinkage) &&

mlir/lib/Dialect/LLVMIR/IR/LLVMLinkerInterface.cpp

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,13 @@ LLVM::LLVMSymbolLinkerInterface::getComdatSelector(Operation *op) {
108108
return {{comdatSelector.getSymName(), comdatSelector.getComdat()}};
109109
}
110110

111+
LLVM::comdat::Comdat
112+
LLVM::LLVMSymbolLinkerInterface::getComdatSelectionKind(Operation *op) {
113+
if (auto selector = dyn_cast<LLVM::ComdatSelectorOp>(op))
114+
return selector.getComdat();
115+
llvm_unreachable("expected selector op");
116+
}
117+
111118
// Return true if the primary definition of this global value is outside of
112119
// the current translation unit.
113120
bool LLVM::LLVMSymbolLinkerInterface::isDeclaration(Operation *op) {
@@ -547,6 +554,80 @@ Operation *LLVM::LLVMSymbolLinkerInterface::appendGlobals(llvm::StringRef glob,
547554
llvm_unreachable("unexpected operation");
548555
}
549556

557+
static ConflictResolution computeComdatResolution() {}
558+
559+
LogicalResult LLVM::LLVMSymbolLinkerInterface::resolveComdats(
560+
ModuleOp srcMod, SymbolTableCollection &collection) {
561+
LLVM::ComdatOp srcComdatOp;
562+
LLVM::ComdatOp dstComdatOp;
563+
for (auto &op : srcMod) {
564+
if (auto comdatOp = dyn_cast<LLVM::ComdatOp>(op)) {
565+
srcComdatOp = comdatOp;
566+
break;
567+
}
568+
}
569+
570+
SymbolUserMap srcSymbolUsers(collection,
571+
srcComdatOp->getParentOfType<ModuleOp>());
572+
// Get current resolved ComdatOp or insert srcComdatOp into summary
573+
// TODO: use comdat summary to find conflict
574+
if (auto it = summary.find(getSymbol(srcComdatOp)); it != summary.end()) {
575+
dstComdatOp = cast<LLVM::ComdatOp>(*it);
576+
} else {
577+
summary[getSymbol(srcComdatOp)] = srcComdatOp;
578+
for (Operation &op : srcComdatOp.getBody().front()) {
579+
ArrayRef<Operation *> users = srcSymbolUsers.getUsers(&op);
580+
comdatResolution.try_emplace(
581+
getSymbol(&op),
582+
std::make_pair(link::Comdat{getComdatSelectionKind(&op),
583+
&op,
584+
{users.begin(), users.end()}},
585+
ConflictResolution::LinkFromSrc));
586+
}
587+
return success();
588+
}
589+
590+
for (Operation &op : srcComdatOp.getBody().front()) {
591+
auto srcSelector = cast<LLVM::ComdatSelectorOp>(op);
592+
// TODO: use custom enum for comdat?
593+
// If no conflict choose src
594+
auto res = ConflictResolution::LinkFromSrc;
595+
if (auto dstComdatIt = comdatResolution.find(getSymbol(&op));
596+
dstComdatIt != comdatResolution.end()) {
597+
res = computeComdatResolution(/*TODO*/);
598+
// remove dst ops from summary if src selected
599+
}
600+
switch (res) {
601+
case ConflictResolution::LinkFromSrc: {
602+
ArrayRef<Operation *> users = srcSymbolUsers.getUsers(&op);
603+
comdatResolution.try_emplace(
604+
getSymbol(srcSelector),
605+
std::make_pair(link::Comdat{getComdatSelectionKind(srcSelector),
606+
srcSelector,
607+
{users.begin(), users.end()}},
608+
ConflictResolution::LinkFromSrc));
609+
break;
610+
}
611+
case ConflictResolution::LinkFromDst:
612+
case ConflictResolution::LinkFromBothAndRenameDst:
613+
case ConflictResolution::LinkFromBothAndRenameSrc:
614+
case ConflictResolution::Failure:
615+
return failure();
616+
}
617+
}
618+
return success();
619+
}
620+
621+
std::optional<link::ConflictResolution>
622+
LLVM::LLVMSymbolLinkerInterface::getComdatResolution(Operation *op) {
623+
if (hasComdat(op)) {
624+
if (auto resIt = comdatResolution.find(getSymbol(op));
625+
resIt != comdatResolution.end())
626+
return resIt->second.second;
627+
}
628+
return {};
629+
}
630+
550631
//===----------------------------------------------------------------------===//
551632
// registerLinkerInterface
552633
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)