@@ -1878,6 +1878,14 @@ bool needsReReturning(llvm::Argument *arg, size_t &sret_idx,
18781878 return true ;
18791879}
18801880
1881+ static bool isOpaque (llvm::Type *T) {
1882+ #if LLVM_VERSION_MAJOR >= 20
1883+ return T->isPointerTy ();
1884+ #else
1885+ return T->isOpaquePointerTy ();
1886+ #endif
1887+ }
1888+
18811889// TODO, for sret/sret_v check if it actually stores the jlvalue_t's into the
18821890// sret If so, confirm that those values are saved elsewhere in a returnroot
18831891void EnzymeFixupJuliaCallingConvention (LLVMValueRef F_C,
@@ -2339,6 +2347,8 @@ void EnzymeFixupJuliaCallingConvention(LLVMValueRef F_C,
23392347 assert (sretCount == Types.size ());
23402348 }
23412349
2350+ auto &DL = F->getParent ()->getDataLayout ();
2351+
23422352 // TODO fix caller side
23432353 for (auto CI : callers) {
23442354 auto Attrs = CI->getAttributes ();
@@ -2400,7 +2410,10 @@ void EnzymeFixupJuliaCallingConvention(LLVMValueRef F_C,
24002410
24012411 bool handled = false ;
24022412 if (auto AI = dyn_cast<AllocaInst>(getBaseObject (val, false ))) {
2403- if (AI->getAllocatedType () == Types[sretCount]) {
2413+ if (AI->getAllocatedType () == Types[sretCount] ||
2414+ (isOpaque (AI->getType ()) &&
2415+ DL.getTypeSizeInBits (AI->getAllocatedType ()) ==
2416+ DL.getTypeSizeInBits (Types[sretCount]))) {
24042417 AI->replaceAllUsesWith (gep);
24052418 AI->eraseFromParent ();
24062419 handled = true ;
@@ -2424,8 +2437,9 @@ void EnzymeFixupJuliaCallingConvention(LLVMValueRef F_C,
24242437 }
24252438
24262439 postCallReplacements.emplace_back (val, gep, Types[sretCount],
2427- sret_jlvalue);
2428- preCallReplacements.emplace_back (val, gep, Types[sretCount]);
2440+ should_sret);
2441+ if (!isWriteOnly (CI, i))
2442+ preCallReplacements.emplace_back (val, gep, Types[sretCount]);
24292443 }
24302444
24312445 if (roots_AT && reroot_enzyme_srets.count (i)) {
0 commit comments