Skip to content

Commit ee83e69

Browse files
authored
Callconv fixup (#2645)
* Callconv fixup * fix
1 parent 9aa1bec commit ee83e69

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

enzyme/Enzyme/CApi.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1878,6 +1878,14 @@ bool needsReReturning(llvm::Argument *arg, size_t &sret_idx,
18781878
return true;
18791879
}
18801880

1881+
static bool isOpaque(llvm::Type *T) {
1882+
#if LLVM_VERSION_MAJOR >= 20
1883+
return T->isPointerTy();
1884+
#else
1885+
return T->isOpaquePointerTy();
1886+
#endif
1887+
}
1888+
18811889
// TODO, for sret/sret_v check if it actually stores the jlvalue_t's into the
18821890
// sret If so, confirm that those values are saved elsewhere in a returnroot
18831891
void EnzymeFixupJuliaCallingConvention(LLVMValueRef F_C,
@@ -2339,6 +2347,8 @@ void EnzymeFixupJuliaCallingConvention(LLVMValueRef F_C,
23392347
assert(sretCount == Types.size());
23402348
}
23412349

2350+
auto &DL = F->getParent()->getDataLayout();
2351+
23422352
// TODO fix caller side
23432353
for (auto CI : callers) {
23442354
auto Attrs = CI->getAttributes();
@@ -2400,7 +2410,10 @@ void EnzymeFixupJuliaCallingConvention(LLVMValueRef F_C,
24002410

24012411
bool handled = false;
24022412
if (auto AI = dyn_cast<AllocaInst>(getBaseObject(val, false))) {
2403-
if (AI->getAllocatedType() == Types[sretCount]) {
2413+
if (AI->getAllocatedType() == Types[sretCount] ||
2414+
(isOpaque(AI->getType()) &&
2415+
DL.getTypeSizeInBits(AI->getAllocatedType()) ==
2416+
DL.getTypeSizeInBits(Types[sretCount]))) {
24042417
AI->replaceAllUsesWith(gep);
24052418
AI->eraseFromParent();
24062419
handled = true;
@@ -2424,8 +2437,9 @@ void EnzymeFixupJuliaCallingConvention(LLVMValueRef F_C,
24242437
}
24252438

24262439
postCallReplacements.emplace_back(val, gep, Types[sretCount],
2427-
sret_jlvalue);
2428-
preCallReplacements.emplace_back(val, gep, Types[sretCount]);
2440+
should_sret);
2441+
if (!isWriteOnly(CI, i))
2442+
preCallReplacements.emplace_back(val, gep, Types[sretCount]);
24292443
}
24302444

24312445
if (roots_AT && reroot_enzyme_srets.count(i)) {

0 commit comments

Comments
 (0)