@@ -379,10 +379,10 @@ function nested_codegen!(
379379
380380 target = DefaultCompilerTarget ()
381381 params = PrimalCompilerParams (mode)
382- job = CompilerJob (funcspec, CompilerConfig (target, params; kernel = false ), world)
382+ job = CompilerJob (funcspec, CompilerConfig (target, params; kernel = false , libraries = true , toplevel = true , optimize = false , cleanup = false , only_entry = false , validate = false ), world)
383383
384384 GPUCompiler. prepare_job! (job)
385- otherMod, meta = GPUCompiler. emit_llvm (job; libraries = true , toplevel = true , optimize = false , cleanup = false , only_entry = false , validate = false )
385+ otherMod, meta = GPUCompiler. emit_llvm (job)
386386
387387 prepare_llvm (otherMod, job, meta)
388388
@@ -618,7 +618,9 @@ function create_recursive_stores(B::LLVM.IRBuilder, @nospecialize(Ty::DataType),
618618 prev = addrspacecast! (B, prev, LLVM. PointerType (LLVMType, Derived))
619619 zero_single_allocation (B, Ty, LLVMType, prev, zeroAll, LLVM. ConstantInt (T_int64, 0 ); atomic= true )
620620 else
621- @assert fieldcount (Ty) != 0
621+ if fieldcount (Ty) == 0
622+ error (" Error handling recursive stores for $Ty which has a fieldcount of 0" )
623+ end
622624
623625 T_jlvalue = LLVM. StructType (LLVM. LLVMType[])
624626 T_prjlvalue = LLVM. PointerType (T_jlvalue, Tracked)
@@ -839,16 +841,16 @@ function zero_single_allocation(builder::LLVM.IRBuilder, @nospecialize(jlType::D
839841 continue
840842 end
841843 if isa (ty, LLVM. ArrayType)
842- subTy = if jlty isa DataType
843- eltype (jlty)
844- elseif ! (jlty isa DataType)
845- if eltype (ty) isa LLVM. PointerType && LLVM. addrspace (eltype (ty)) == 10
846- Any
847- else
848- throw (AssertionError (" jlty=$jlty ty=$ty " ))
849- end
850- end
851844 for i = 1 : length (ty)
845+ subTy = if jlty isa DataType
846+ typed_fieldtype (jlty, i)
847+ elseif ! (jlty isa DataType)
848+ if eltype (ty) isa LLVM. PointerType && LLVM. addrspace (eltype (ty)) == 10
849+ Any
850+ else
851+ throw (AssertionError (" jlty=$jlty ty=$ty " ))
852+ end
853+ end
852854 npath = copy (path)
853855 push! (npath, LLVM. ConstantInt (LLVM. IntType (32 ), i - 1 ))
854856 push! (todo, (npath, eltype (ty), subTy))
@@ -866,7 +868,9 @@ function zero_single_allocation(builder::LLVM.IRBuilder, @nospecialize(jlType::D
866868 end
867869 if isa (ty, LLVM. StructType)
868870 i = 1
869- @assert jlty isa DataType
871+ if ! (jlty isa DataType)
872+ throw (AssertionError (" Could not handle non datatype $jlty in zero_single_allocation $ty " ))
873+ end
870874 for ii = 1 : fieldcount (jlty)
871875 jlet = typed_fieldtype (jlty, ii)
872876 if isghostty (jlet) || Core. Compiler. isconstType (jlet)
@@ -2683,6 +2687,7 @@ function lower_convention(
26832687
26842688 RT = LLVM. return_type (entry_ft)
26852689
2690+
26862691 # generate the wrapper function type & definition
26872692 wrapper_types = LLVM. LLVMType[]
26882693 wrapper_attrs = Vector{LLVM. Attribute}[]
@@ -2699,6 +2704,13 @@ function lower_convention(
26992704 sret = sret != = nothing
27002705 returnRoots = returnRoots != = nothing
27012706
2707+ loweredReturn = RetActivity <: Active && (actualRetType === Any)
2708+ if loweredReturn
2709+ @assert ! sret
2710+ @assert ! returnRoots
2711+ RT = convert (LLVMType, eltype (RetActivity))
2712+ end
2713+
27022714 # TODO removed implications
27032715 retRemoved, parmsRemoved = removed_ret_parms (entry_f)
27042716 swiftself = has_swiftself (entry_f)
@@ -2769,8 +2781,8 @@ function lower_convention(
27692781 end
27702782 end
27712783
2772- if length (loweredArgs) == 0 && length (raisedArgs) == 0 && ! sret && ! sret_union
2773- return entry_f, returnRoots, boxedArgs, loweredArgs
2784+ if length (loweredArgs) == 0 && length (raisedArgs) == 0 && ! sret && ! sret_union && ! loweredReturn
2785+ return entry_f, returnRoots, boxedArgs, loweredArgs, actualRetType
27742786 end
27752787
27762788 wrapper_fn = LLVM. name (entry_f)
@@ -3136,28 +3148,69 @@ function lower_convention(
31363148 ret! (builder)
31373149 else
31383150 ctx = LLVM. context (wrapper_f)
3139- push! (
3140- return_attributes (wrapper_f),
3141- StringAttribute (
3142- " enzyme_type" ,
3143- string (typetree (actualRetType, ctx, dl, seen)),
3144- ),
3145- )
3146- push! (
3147- return_attributes (wrapper_f),
3148- StringAttribute (
3149- " enzymejl_parmtype" ,
3150- string (convert (UInt, unsafe_to_pointer (actualRetType))),
3151- ),
3152- )
3153- push! (
3154- return_attributes (wrapper_f),
3155- StringAttribute (
3156- " enzymejl_parmtype_ref" ,
3157- string (UInt (GPUCompiler. BITS_REF)),
3158- ),
3159- )
3160- ret! (builder, res)
3151+
3152+ if loweredReturn
3153+ push! (
3154+ return_attributes (wrapper_f),
3155+ StringAttribute (
3156+ " enzyme_type" ,
3157+ string (typetree (eltype (RetActivity), ctx, dl, seen)),
3158+ ),
3159+ )
3160+ push! (
3161+ return_attributes (wrapper_f),
3162+ StringAttribute (
3163+ " enzymejl_parmtype" ,
3164+ string (convert (UInt, unsafe_to_pointer (eltype (RetActivity)))),
3165+ ),
3166+ )
3167+ push! (
3168+ return_attributes (wrapper_f),
3169+ StringAttribute (
3170+ " enzymejl_parmtype_ref" ,
3171+ string (UInt (GPUCompiler. BITS_VALUE)),
3172+ ),
3173+ )
3174+ ty = emit_jltypeof! (builder, res)
3175+ cmp = icmp! (builder, LLVM. API. LLVMIntEQ, ty, unsafe_to_llvm (builder, eltype (RetActivity)))
3176+ cmpret = BasicBlock (wrapper_f, " ret" )
3177+ failure = BasicBlock (wrapper_f, " fail" )
3178+ br! (builder, cmp, cmpret, failure)
3179+
3180+ position! (builder, cmpret)
3181+ res = bitcast! (builder, res, LLVM. PointerType (RT, addrspace (value_type (res))))
3182+ res = addrspacecast! (builder, res, LLVM. PointerType (RT, Derived))
3183+ res = load! (builder, RT, res)
3184+ ret! (builder, res)
3185+
3186+ position! (builder, failure)
3187+
3188+ emit_error (builder, nothing , " Expected return type of primal to be " * string (eltype (RetActivity))* " but did not find a value of that type" )
3189+ unreachable! (builder)
3190+ else
3191+ push! (
3192+ return_attributes (wrapper_f),
3193+ StringAttribute (
3194+ " enzyme_type" ,
3195+ string (typetree (actualRetType, ctx, dl, seen)),
3196+ ),
3197+ )
3198+ push! (
3199+ return_attributes (wrapper_f),
3200+ StringAttribute (
3201+ " enzymejl_parmtype" ,
3202+ string (convert (UInt, unsafe_to_pointer (actualRetType))),
3203+ ),
3204+ )
3205+ push! (
3206+ return_attributes (wrapper_f),
3207+ StringAttribute (
3208+ " enzymejl_parmtype_ref" ,
3209+ string (UInt (GPUCompiler. BITS_REF)),
3210+ ),
3211+ )
3212+ ret! (builder, res)
3213+ end
31613214 end
31623215 dispose (builder)
31633216 end
@@ -3366,7 +3419,7 @@ function lower_convention(
33663419 end
33673420 throw (LLVM. LLVMException (msg))
33683421 end
3369- return wrapper_f, returnRoots, boxedArgs, loweredArgs
3422+ return wrapper_f, returnRoots, boxedArgs, loweredArgs, loweredReturn ? eltype (RetActivity) : actualRetType
33703423end
33713424
33723425using Random
@@ -3424,9 +3477,20 @@ function GPUCompiler.codegen(
34243477 if parent_job === nothing
34253478 primal_target = DefaultCompilerTarget ()
34263479 primal_params = PrimalCompilerParams (mode)
3480+ config2 = CompilerConfig (
3481+ primal_target,
3482+ primal_params;
3483+ kernel = false ,
3484+ libraries = true ,
3485+ toplevel = toplevel,
3486+ optimize = false ,
3487+ cleanup = false ,
3488+ only_entry = false ,
3489+ validate = false
3490+ )
34273491 primal_job = CompilerJob (
34283492 primal,
3429- CompilerConfig (primal_target, primal_params; kernel = false ) ,
3493+ config2 ,
34303494 job. world,
34313495 )
34323496 else
@@ -3437,12 +3501,18 @@ function GPUCompiler.codegen(
34373501 parent_job. config. entry_abi,
34383502 parent_job. config. name,
34393503 parent_job. config. always_inline,
3504+ libraries = true ,
3505+ toplevel = toplevel,
3506+ optimize = false ,
3507+ cleanup = false ,
3508+ only_entry = false ,
3509+ validate = false ,
34403510 )
34413511 primal_job = CompilerJob (primal, config2, job. world) # TODO EnzymeInterp params, etc
34423512 end
34433513
34443514 GPUCompiler. prepare_job! (primal_job)
3445- mod, meta = GPUCompiler. emit_llvm (primal_job; libraries = true , toplevel = toplevel, optimize = false , cleanup = false , only_entry = false , validate = false )
3515+ mod, meta = GPUCompiler. emit_llvm (primal_job)
34463516 edges = Any[]
34473517 mod_to_edges[mod] = edges
34483518
@@ -4204,7 +4274,7 @@ end
42044274 primalf, returnRoots = primalf, false
42054275
42064276 if lowerConvention
4207- primalf, returnRoots, boxedArgs, loweredArgs = lower_convention (
4277+ primalf, returnRoots, boxedArgs, loweredArgs, actualRetType = lower_convention (
42084278 source_sig,
42094279 mod,
42104280 primalf,
0 commit comments