Skip to content

Commit a5024cd

Browse files
[MLIR][LLVM] More on CG Profile: support null function entries (#137269)
1 parent b509f7c commit a5024cd

File tree

6 files changed

+43
-21
lines changed

6 files changed

+43
-21
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td

+5-3
Original file line numberDiff line numberDiff line change
@@ -1370,9 +1370,11 @@ def ModuleFlagCGProfileEntryAttr
13701370
]>]
13711371
```
13721372
}];
1373-
let parameters = (ins "FlatSymbolRefAttr":$from,
1374-
"FlatSymbolRefAttr":$to,
1375-
"uint64_t":$count);
1373+
let parameters = (
1374+
ins OptionalParameter<"FlatSymbolRefAttr">:$from,
1375+
OptionalParameter<"FlatSymbolRefAttr">:$to,
1376+
"uint64_t":$count);
1377+
13761378
let assemblyFormat = "`<` struct(params) `>`";
13771379
}
13781380

mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp

+12-5
Original file line numberDiff line numberDiff line change
@@ -281,12 +281,19 @@ convertModuleFlagValue(StringRef key, ArrayAttr arrayAttr,
281281

282282
if (key == LLVMDialect::getModuleFlagKeyCGProfileName()) {
283283
for (auto entry : arrayAttr.getAsRange<ModuleFlagCGProfileEntryAttr>()) {
284-
llvm::Function *fromFn =
285-
moduleTranslation.lookupFunction(entry.getFrom().getValue());
286-
llvm::Function *toFn =
287-
moduleTranslation.lookupFunction(entry.getTo().getValue());
284+
llvm::Metadata *fromMetadata =
285+
entry.getFrom()
286+
? llvm::ValueAsMetadata::get(moduleTranslation.lookupFunction(
287+
entry.getFrom().getValue()))
288+
: nullptr;
289+
llvm::Metadata *toMetadata =
290+
entry.getTo()
291+
? llvm::ValueAsMetadata::get(
292+
moduleTranslation.lookupFunction(entry.getTo().getValue()))
293+
: nullptr;
294+
288295
llvm::Metadata *vals[] = {
289-
llvm::ValueAsMetadata::get(fromFn), llvm::ValueAsMetadata::get(toFn),
296+
fromMetadata, toMetadata,
290297
mdb.createConstant(llvm::ConstantInt::get(
291298
llvm::Type::getInt64Ty(context), entry.getCount()))};
292299
nodes.push_back(llvm::MDNode::get(context, vals));

mlir/lib/Target/LLVMIR/ModuleImport.cpp

+19-6
Original file line numberDiff line numberDiff line change
@@ -521,10 +521,14 @@ void ModuleImport::addDebugIntrinsic(llvm::CallInst *intrinsic) {
521521

522522
static Attribute convertCGProfileModuleFlagValue(ModuleOp mlirModule,
523523
llvm::MDTuple *mdTuple) {
524-
auto getFunctionSymbol = [&](const llvm::MDOperand &funcMDO) {
525-
auto *f = cast<llvm::ValueAsMetadata>(funcMDO);
524+
auto getLLVMFunction =
525+
[&](const llvm::MDOperand &funcMDO) -> llvm::Function * {
526+
auto *f = cast_or_null<llvm::ValueAsMetadata>(funcMDO);
527+
// nullptr is a valid value for the function pointer.
528+
if (!f)
529+
return nullptr;
526530
auto *llvmFn = cast<llvm::Function>(f->getValue()->stripPointerCasts());
527-
return FlatSymbolRefAttr::get(mlirModule->getContext(), llvmFn->getName());
531+
return llvmFn;
528532
};
529533

530534
// Each tuple element becomes one ModuleFlagCGProfileEntryAttr.
@@ -535,9 +539,17 @@ static Attribute convertCGProfileModuleFlagValue(ModuleOp mlirModule,
535539
llvm::Constant *llvmConstant =
536540
cast<llvm::ConstantAsMetadata>(cgEntry->getOperand(2))->getValue();
537541
uint64_t count = cast<llvm::ConstantInt>(llvmConstant)->getZExtValue();
542+
auto *fromFn = getLLVMFunction(cgEntry->getOperand(0));
543+
auto *toFn = getLLVMFunction(cgEntry->getOperand(1));
544+
// FlatSymbolRefAttr::get(mlirModule->getContext(), llvmFn->getName());
538545
cgProfile.push_back(ModuleFlagCGProfileEntryAttr::get(
539-
mlirModule->getContext(), getFunctionSymbol(cgEntry->getOperand(0)),
540-
getFunctionSymbol(cgEntry->getOperand(1)), count));
546+
mlirModule->getContext(),
547+
fromFn ? FlatSymbolRefAttr::get(mlirModule->getContext(),
548+
fromFn->getName())
549+
: nullptr,
550+
toFn ? FlatSymbolRefAttr::get(mlirModule->getContext(), toFn->getName())
551+
: nullptr,
552+
count));
541553
}
542554
return ArrayAttr::get(mlirModule->getContext(), cgProfile);
543555
}
@@ -570,7 +582,8 @@ LogicalResult ModuleImport::convertModuleFlagsMetadata() {
570582

571583
if (!valAttr) {
572584
emitWarning(mlirModule.getLoc())
573-
<< "unsupported module flag value: " << diagMD(val, llvmModule.get());
585+
<< "unsupported module flag value for key '" << key->getString()
586+
<< "' : " << diagMD(val, llvmModule.get());
574587
continue;
575588
}
576589

mlir/test/Dialect/LLVMIR/module-roundtrip.mlir

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ module {
99
#llvm.mlir.module_flag<override, "probe-stack", "inline-asm">,
1010
#llvm.mlir.module_flag<append, "CG Profile", [
1111
#llvm.cgprofile_entry<from = @from, to = @to, count = 222>,
12-
#llvm.cgprofile_entry<from = @from, to = @from, count = 222>,
12+
#llvm.cgprofile_entry<from = @from, count = 222>,
1313
#llvm.cgprofile_entry<from = @to, to = @from, count = 222>
1414
]>]
1515
}
@@ -23,6 +23,6 @@ module {
2323
// CHECK-SAME: #llvm.mlir.module_flag<override, "probe-stack", "inline-asm">,
2424
// CHECK-SAME: #llvm.mlir.module_flag<append, "CG Profile", [
2525
// CHECK-SAME: #llvm.cgprofile_entry<from = @from, to = @to, count = 222>,
26-
// CHECK-SAME: #llvm.cgprofile_entry<from = @from, to = @from, count = 222>,
26+
// CHECK-SAME: #llvm.cgprofile_entry<from = @from, count = 222>,
2727
// CHECK-SAME: #llvm.cgprofile_entry<from = @to, to = @from, count = 222>
2828
// CHECK-SAME: ]>]

mlir/test/Target/LLVMIR/Import/module-flags.ll

+3-3
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
; CHECK-SAME: #llvm.mlir.module_flag<override, "probe-stack", "inline-asm">]
2020

2121
; // -----
22-
; expected-warning@-2 {{unsupported module flag value: !4 = !{!"foo", i32 1}}}
22+
; expected-warning@-2 {{unsupported module flag value for key 'qux' : !4 = !{!"foo", i32 1}}}
2323
!10 = !{ i32 1, !"foo", i32 1 }
2424
!11 = !{ i32 4, !"bar", i32 37 }
2525
!12 = !{ i32 2, !"qux", i32 42 }
@@ -36,11 +36,11 @@ declare void @to()
3636
!20 = !{i32 5, !"CG Profile", !21}
3737
!21 = distinct !{!22, !23, !24}
3838
!22 = !{ptr @from, ptr @to, i64 222}
39-
!23 = !{ptr @from, ptr @from, i64 222}
39+
!23 = !{ptr @from, null, i64 222}
4040
!24 = !{ptr @to, ptr @from, i64 222}
4141

4242
; CHECK: llvm.module_flags [#llvm.mlir.module_flag<append, "CG Profile", [
4343
; CHECK-SAME: #llvm.cgprofile_entry<from = @from, to = @to, count = 222>,
44-
; CHECK-SAME: #llvm.cgprofile_entry<from = @from, to = @from, count = 222>,
44+
; CHECK-SAME: #llvm.cgprofile_entry<from = @from, count = 222>,
4545
; CHECK-SAME: #llvm.cgprofile_entry<from = @to, to = @from, count = 222>
4646
; CHECK-SAME: ]>]

mlir/test/Target/LLVMIR/llvmir.mlir

+2-2
Original file line numberDiff line numberDiff line change
@@ -2840,7 +2840,7 @@ module {
28402840

28412841
llvm.module_flags [#llvm.mlir.module_flag<append, "CG Profile", [
28422842
#llvm.cgprofile_entry<from = @from, to = @to, count = 222>,
2843-
#llvm.cgprofile_entry<from = @from, to = @from, count = 222>,
2843+
#llvm.cgprofile_entry<from = @from, count = 222>,
28442844
#llvm.cgprofile_entry<from = @to, to = @from, count = 222>
28452845
]>]
28462846
llvm.func @from(i32)
@@ -2851,7 +2851,7 @@ llvm.func @to()
28512851
// CHECK: ![[#CGPROF]] = !{i32 5, !"CG Profile", ![[#LIST:]]}
28522852
// CHECK: ![[#LIST]] = distinct !{![[#ENTRY_A:]], ![[#ENTRY_B:]], ![[#ENTRY_C:]]}
28532853
// CHECK: ![[#ENTRY_A]] = !{ptr @from, ptr @to, i64 222}
2854-
// CHECK: ![[#ENTRY_B]] = !{ptr @from, ptr @from, i64 222}
2854+
// CHECK: ![[#ENTRY_B]] = !{ptr @from, null, i64 222}
28552855
// CHECK: ![[#ENTRY_C]] = !{ptr @to, ptr @from, i64 222}
28562856
// CHECK: ![[#DBG]] = !{i32 2, !"Debug Info Version", i32 3}
28572857

0 commit comments

Comments
 (0)