@@ -1997,6 +1997,17 @@ end
19971997include (" rules/allocrules.jl" )
19981998include (" rules/llvmrules.jl" )
19991999
2000+ function add_one_in_place (x)
2001+ if x isa Base. RefValue
2002+ x[] = recursive_add (x[], default_adjoint (eltype (Core. Typeof (x))))
2003+ elseif x isa (Array{T,0 } where T)
2004+ x[] = recursive_add (x[], default_adjoint (eltype (Core. Typeof (x))))
2005+ else
2006+ throw (EnzymeNonScalarReturnException (x, " " ))
2007+ end
2008+ return nothing
2009+ end
2010+
20002011for (k, v) in (
20012012 (" enz_runtime_newtask_fwd" , Enzyme. Compiler. runtime_newtask_fwd),
20022013 (" enz_runtime_newtask_augfwd" , Enzyme. Compiler. runtime_newtask_augfwd),
@@ -2018,6 +2029,7 @@ for (k, v) in (
20182029 (" enz_runtime_jl_setfield_rev" , Enzyme. Compiler. rt_jl_setfield_rev),
20192030 (" enz_runtime_error_if_differentiable" , Enzyme. Compiler. error_if_differentiable),
20202031 (" enz_runtime_error_if_active" , Enzyme. Compiler. error_if_active),
2032+ (" enz_add_one_in_place" , Enzyme. Compiler. add_one_in_place),
20212033)
20222034 JuliaEnzymeNameMap[k] = v
20232035end
@@ -5072,7 +5084,7 @@ end
50725084 if ! (primal_target isa GPUCompiler. NativeCompilerTarget)
50735085 reinsert_gcmarker! (adjointf)
50745086 augmented_primalf != = nothing && reinsert_gcmarker! (augmented_primalf)
5075- post_optimze ! (mod, target_machine, false ) #= machine=#
5087+ post_optimize ! (mod, target_machine, false ) #= machine=#
50765088 end
50775089
50785090 adjointf = functions (mod)[adjointf_name]
@@ -5236,17 +5248,6 @@ include("typeutils/recursive_add.jl")
52365248 end
52375249end
52385250
5239- function add_one_in_place (x)
5240- if x isa Base. RefValue
5241- x[] = recursive_add (x[], default_adjoint (eltype (Core. Typeof (x))))
5242- elseif x isa (Array{T,0 } where T)
5243- x[] = recursive_add (x[], default_adjoint (eltype (Core. Typeof (x))))
5244- else
5245- throw (EnzymeNonScalarReturnException (x, " " ))
5246- end
5247- return nothing
5248- end
5249-
52505251@generated function enzyme_call (
52515252 :: Val{RawCall} ,
52525253 fptr:: PT ,
@@ -5814,7 +5815,7 @@ function _thunk(job, postopt::Bool = true)::Tuple{LLVM.Module, Vector{Any}, Stri
58145815 if DumpPrePostOpt[]
58155816 API. EnzymeDumpModuleRef (mod. ref)
58165817 end
5817- post_optimze ! (mod, JIT. get_tm ())
5818+ post_optimize ! (mod, JIT. get_tm ())
58185819 if DumpPostOpt[]
58195820 API. EnzymeDumpModuleRef (mod. ref)
58205821 end
0 commit comments