Skip to content

Commit d3b6fac

Browse files
authored
Support lowered return (#2323)
1 parent bcf2868 commit d3b6fac

File tree

1 file changed

+75
-26
lines changed

1 file changed

+75
-26
lines changed

src/compiler.jl

Lines changed: 75 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2685,6 +2685,7 @@ function lower_convention(
26852685

26862686
RT = LLVM.return_type(entry_ft)
26872687

2688+
26882689
# generate the wrapper function type & definition
26892690
wrapper_types = LLVM.LLVMType[]
26902691
wrapper_attrs = Vector{LLVM.Attribute}[]
@@ -2701,6 +2702,13 @@ function lower_convention(
27012702
sret = sret !== nothing
27022703
returnRoots = returnRoots !== nothing
27032704

2705+
loweredReturn = RetActivity <: Active && (actualRetType === Any)
2706+
if loweredReturn
2707+
@assert !sret
2708+
@assert !returnRoots
2709+
RT = convert(LLVMType, eltype(RetActivity))
2710+
end
2711+
27042712
# TODO removed implications
27052713
retRemoved, parmsRemoved = removed_ret_parms(entry_f)
27062714
swiftself = has_swiftself(entry_f)
@@ -2771,8 +2779,8 @@ function lower_convention(
27712779
end
27722780
end
27732781

2774-
if length(loweredArgs) == 0 && length(raisedArgs) == 0 && !sret && !sret_union
2775-
return entry_f, returnRoots, boxedArgs, loweredArgs
2782+
if length(loweredArgs) == 0 && length(raisedArgs) == 0 && !sret && !sret_union && !loweredReturn
2783+
return entry_f, returnRoots, boxedArgs, loweredArgs, actualRetType
27762784
end
27772785

27782786
wrapper_fn = LLVM.name(entry_f)
@@ -3138,28 +3146,69 @@ function lower_convention(
31383146
ret!(builder)
31393147
else
31403148
ctx = LLVM.context(wrapper_f)
3141-
push!(
3142-
return_attributes(wrapper_f),
3143-
StringAttribute(
3144-
"enzyme_type",
3145-
string(typetree(actualRetType, ctx, dl, seen)),
3146-
),
3147-
)
3148-
push!(
3149-
return_attributes(wrapper_f),
3150-
StringAttribute(
3151-
"enzymejl_parmtype",
3152-
string(convert(UInt, unsafe_to_pointer(actualRetType))),
3153-
),
3154-
)
3155-
push!(
3156-
return_attributes(wrapper_f),
3157-
StringAttribute(
3158-
"enzymejl_parmtype_ref",
3159-
string(UInt(GPUCompiler.BITS_REF)),
3160-
),
3161-
)
3162-
ret!(builder, res)
3149+
3150+
if loweredReturn
3151+
push!(
3152+
return_attributes(wrapper_f),
3153+
StringAttribute(
3154+
"enzyme_type",
3155+
string(typetree(eltype(RetActivity), ctx, dl, seen)),
3156+
),
3157+
)
3158+
push!(
3159+
return_attributes(wrapper_f),
3160+
StringAttribute(
3161+
"enzymejl_parmtype",
3162+
string(convert(UInt, unsafe_to_pointer(eltype(RetActivity)))),
3163+
),
3164+
)
3165+
push!(
3166+
return_attributes(wrapper_f),
3167+
StringAttribute(
3168+
"enzymejl_parmtype_ref",
3169+
string(UInt(GPUCompiler.BITS_VALUE)),
3170+
),
3171+
)
3172+
ty = emit_jltypeof!(builder, res)
3173+
cmp = icmp!(builder, LLVM.API.LLVMIntEQ, ty, unsafe_to_llvm(builder, eltype(RetActivity)))
3174+
cmpret = BasicBlock(wrapper_f, "ret")
3175+
failure = BasicBlock(wrapper_f, "fail")
3176+
br!(builder, cmp, cmpret, failure)
3177+
3178+
position!(builder, cmpret)
3179+
res = bitcast!(builder, res, LLVM.PointerType(RT, addrspace(value_type(res))))
3180+
res = addrspacecast!(builder, res, LLVM.PointerType(RT, Derived))
3181+
res = load!(builder, RT, res)
3182+
ret!(builder, res)
3183+
3184+
position!(builder, failure)
3185+
3186+
emit_error(builder, nothing, "Expected return type of primal to be "*string(eltype(RetActivity))*" but did not find a value of that type")
3187+
unreachable!(builder)
3188+
else
3189+
push!(
3190+
return_attributes(wrapper_f),
3191+
StringAttribute(
3192+
"enzyme_type",
3193+
string(typetree(actualRetType, ctx, dl, seen)),
3194+
),
3195+
)
3196+
push!(
3197+
return_attributes(wrapper_f),
3198+
StringAttribute(
3199+
"enzymejl_parmtype",
3200+
string(convert(UInt, unsafe_to_pointer(actualRetType))),
3201+
),
3202+
)
3203+
push!(
3204+
return_attributes(wrapper_f),
3205+
StringAttribute(
3206+
"enzymejl_parmtype_ref",
3207+
string(UInt(GPUCompiler.BITS_REF)),
3208+
),
3209+
)
3210+
ret!(builder, res)
3211+
end
31633212
end
31643213
dispose(builder)
31653214
end
@@ -3368,7 +3417,7 @@ function lower_convention(
33683417
end
33693418
throw(LLVM.LLVMException(msg))
33703419
end
3371-
return wrapper_f, returnRoots, boxedArgs, loweredArgs
3420+
return wrapper_f, returnRoots, boxedArgs, loweredArgs, loweredReturn ? eltype(RetActivity) : actualRetType
33723421
end
33733422

33743423
using Random
@@ -4206,7 +4255,7 @@ end
42064255
primalf, returnRoots = primalf, false
42074256

42084257
if lowerConvention
4209-
primalf, returnRoots, boxedArgs, loweredArgs = lower_convention(
4258+
primalf, returnRoots, boxedArgs, loweredArgs, actualRetType = lower_convention(
42104259
source_sig,
42114260
mod,
42124261
primalf,

0 commit comments

Comments
 (0)