diff --git a/src/compiler.jl b/src/compiler.jl index f97e3c89f1..0cf75eb06f 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1997,6 +1997,17 @@ end include("rules/allocrules.jl") include("rules/llvmrules.jl") +function add_one_in_place(x) + if x isa Base.RefValue + x[] = recursive_add(x[], default_adjoint(eltype(Core.Typeof(x)))) + elseif x isa (Array{T,0} where T) + x[] = recursive_add(x[], default_adjoint(eltype(Core.Typeof(x)))) + else + throw(EnzymeNonScalarReturnException(x, "")) + end + return nothing +end + for (k, v) in ( ("enz_runtime_newtask_fwd", Enzyme.Compiler.runtime_newtask_fwd), ("enz_runtime_newtask_augfwd", Enzyme.Compiler.runtime_newtask_augfwd), @@ -2018,6 +2029,7 @@ for (k, v) in ( ("enz_runtime_jl_setfield_rev", Enzyme.Compiler.rt_jl_setfield_rev), ("enz_runtime_error_if_differentiable", Enzyme.Compiler.error_if_differentiable), ("enz_runtime_error_if_active", Enzyme.Compiler.error_if_active), + ("enz_add_one_in_place", Enzyme.Compiler.add_one_in_place), ) JuliaEnzymeNameMap[k] = v end @@ -5072,7 +5084,7 @@ end if !(primal_target isa GPUCompiler.NativeCompilerTarget) reinsert_gcmarker!(adjointf) augmented_primalf !== nothing && reinsert_gcmarker!(augmented_primalf) - post_optimze!(mod, target_machine, false) #=machine=# + post_optimize!(mod, target_machine, false) #=machine=# end adjointf = functions(mod)[adjointf_name] @@ -5236,17 +5248,6 @@ include("typeutils/recursive_add.jl") end end -function add_one_in_place(x) - if x isa Base.RefValue - x[] = recursive_add(x[], default_adjoint(eltype(Core.Typeof(x)))) - elseif x isa (Array{T,0} where T) - x[] = recursive_add(x[], default_adjoint(eltype(Core.Typeof(x)))) - else - throw(EnzymeNonScalarReturnException(x, "")) - end - return nothing -end - @generated function enzyme_call( ::Val{RawCall}, fptr::PT, @@ -5814,7 +5815,7 @@ function _thunk(job, postopt::Bool = true)::Tuple{LLVM.Module, Vector{Any}, Stri if DumpPrePostOpt[] API.EnzymeDumpModuleRef(mod.ref) end - post_optimze!(mod, JIT.get_tm()) + post_optimize!(mod, JIT.get_tm()) if DumpPostOpt[] API.EnzymeDumpModuleRef(mod.ref) end diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index a8d6cc6d76..bf73b9f955 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -718,7 +718,7 @@ function addJuliaLegalizationPasses!(pm::LLVM.ModulePassManager, tm::LLVM.Target end end -function post_optimze!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool = true) +function post_optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool = true) addr13NoAlias(mod) removeDeadArgs!(mod, tm) for f in collect(functions(mod)) @@ -764,6 +764,14 @@ function post_optimze!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool = LLVM.run!(pm, mod) end end + for f in functions(mod) + if isempty(blocks(f)) + continue + end + if !has_fn_attr(f, StringAttribute("frame-pointer")) + push!(function_attributes(f), StringAttribute("frame-pointer", "all")) + end + end # @safe_show "post_mod", mod # flush(stdout) # flush(stderr) diff --git a/src/compiler/reflection.jl b/src/compiler/reflection.jl index 4f3f2949f4..b70c81f1d1 100644 --- a/src/compiler/reflection.jl +++ b/src/compiler/reflection.jl @@ -74,7 +74,7 @@ function reflect( mod, meta = GPUCompiler.codegen(:llvm, job) #= validate=false =# if second_stage - post_optimze!(mod, JIT.get_tm()) + post_optimize!(mod, JIT.get_tm()) end llvmf = meta.adjointf