@@ -2685,6 +2685,7 @@ function lower_convention(
26852685
26862686 RT = LLVM. return_type (entry_ft)
26872687
2688+
26882689 # generate the wrapper function type & definition
26892690 wrapper_types = LLVM. LLVMType[]
26902691 wrapper_attrs = Vector{LLVM. Attribute}[]
@@ -2701,6 +2702,13 @@ function lower_convention(
27012702 sret = sret != = nothing
27022703 returnRoots = returnRoots != = nothing
27032704
2705+ loweredReturn = RetActivity <: Active && (actualRetType === Any)
2706+ if loweredReturn
2707+ @assert ! sret
2708+ @assert ! returnRoots
2709+ RT = convert (LLVMType, eltype (RetActivity))
2710+ end
2711+
27042712 # TODO removed implications
27052713 retRemoved, parmsRemoved = removed_ret_parms (entry_f)
27062714 swiftself = has_swiftself (entry_f)
@@ -2771,8 +2779,8 @@ function lower_convention(
27712779 end
27722780 end
27732781
2774- if length (loweredArgs) == 0 && length (raisedArgs) == 0 && ! sret && ! sret_union
2775- return entry_f, returnRoots, boxedArgs, loweredArgs
2782+ if length (loweredArgs) == 0 && length (raisedArgs) == 0 && ! sret && ! sret_union && ! loweredReturn
2783+ return entry_f, returnRoots, boxedArgs, loweredArgs, actualRetType
27762784 end
27772785
27782786 wrapper_fn = LLVM. name (entry_f)
@@ -3138,28 +3146,69 @@ function lower_convention(
31383146 ret! (builder)
31393147 else
31403148 ctx = LLVM. context (wrapper_f)
3141- push! (
3142- return_attributes (wrapper_f),
3143- StringAttribute (
3144- " enzyme_type" ,
3145- string (typetree (actualRetType, ctx, dl, seen)),
3146- ),
3147- )
3148- push! (
3149- return_attributes (wrapper_f),
3150- StringAttribute (
3151- " enzymejl_parmtype" ,
3152- string (convert (UInt, unsafe_to_pointer (actualRetType))),
3153- ),
3154- )
3155- push! (
3156- return_attributes (wrapper_f),
3157- StringAttribute (
3158- " enzymejl_parmtype_ref" ,
3159- string (UInt (GPUCompiler. BITS_REF)),
3160- ),
3161- )
3162- ret! (builder, res)
3149+
3150+ if loweredReturn
3151+ push! (
3152+ return_attributes (wrapper_f),
3153+ StringAttribute (
3154+ " enzyme_type" ,
3155+ string (typetree (eltype (RetActivity), ctx, dl, seen)),
3156+ ),
3157+ )
3158+ push! (
3159+ return_attributes (wrapper_f),
3160+ StringAttribute (
3161+ " enzymejl_parmtype" ,
3162+ string (convert (UInt, unsafe_to_pointer (eltype (RetActivity)))),
3163+ ),
3164+ )
3165+ push! (
3166+ return_attributes (wrapper_f),
3167+ StringAttribute (
3168+ " enzymejl_parmtype_ref" ,
3169+ string (UInt (GPUCompiler. BITS_VALUE)),
3170+ ),
3171+ )
3172+ ty = emit_jltypeof! (builder, res)
3173+ cmp = icmp! (builder, LLVM. API. LLVMIntEQ, ty, unsafe_to_llvm (builder, eltype (RetActivity)))
3174+ cmpret = BasicBlock (wrapper_f, " ret" )
3175+ failure = BasicBlock (wrapper_f, " fail" )
3176+ br! (builder, cmp, cmpret, failure)
3177+
3178+ position! (builder, cmpret)
3179+ res = bitcast! (builder, res, LLVM. PointerType (RT, addrspace (value_type (res))))
3180+ res = addrspacecast! (builder, res, LLVM. PointerType (RT, Derived))
3181+ res = load! (builder, RT, res)
3182+ ret! (builder, res)
3183+
3184+ position! (builder, failure)
3185+
3186+ emit_error (builder, nothing , " Expected return type of primal to be " * string (eltype (RetActivity))* " but did not find a value of that type" )
3187+ unreachable! (builder)
3188+ else
3189+ push! (
3190+ return_attributes (wrapper_f),
3191+ StringAttribute (
3192+ " enzyme_type" ,
3193+ string (typetree (actualRetType, ctx, dl, seen)),
3194+ ),
3195+ )
3196+ push! (
3197+ return_attributes (wrapper_f),
3198+ StringAttribute (
3199+ " enzymejl_parmtype" ,
3200+ string (convert (UInt, unsafe_to_pointer (actualRetType))),
3201+ ),
3202+ )
3203+ push! (
3204+ return_attributes (wrapper_f),
3205+ StringAttribute (
3206+ " enzymejl_parmtype_ref" ,
3207+ string (UInt (GPUCompiler. BITS_REF)),
3208+ ),
3209+ )
3210+ ret! (builder, res)
3211+ end
31633212 end
31643213 dispose (builder)
31653214 end
@@ -3368,7 +3417,7 @@ function lower_convention(
33683417 end
33693418 throw (LLVM. LLVMException (msg))
33703419 end
3371- return wrapper_f, returnRoots, boxedArgs, loweredArgs
3420+ return wrapper_f, returnRoots, boxedArgs, loweredArgs, loweredReturn ? eltype (RetActivity) : actualRetType
33723421end
33733422
33743423using Random
@@ -4206,7 +4255,7 @@ end
42064255 primalf, returnRoots = primalf, false
42074256
42084257 if lowerConvention
4209- primalf, returnRoots, boxedArgs, loweredArgs = lower_convention (
4258+ primalf, returnRoots, boxedArgs, loweredArgs, actualRetType = lower_convention (
42104259 source_sig,
42114260 mod,
42124261 primalf,
0 commit comments