Skip to content
Merged
41 changes: 27 additions & 14 deletions ext/EnzymeChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,13 @@ function Enzyme._import_rrule(fn, tys...)
# ) : nothing
# end)

ptys = []
for (i, ty) in enumerate(tys)
push!(nothings, :(nothing))
val = Symbol("arg_$i")
TA = Symbol("AN_$i")
e = :($val::$TA)
push!(ptys, :(::$(esc(ty))))
push!(anns, :($TA <: Annotation{<:$(esc(ty))}))
push!(vals, val)
push!(exprs, e)
Expand All @@ -184,46 +186,57 @@ function Enzyme._import_rrule(fn, tys...)
end)
end


quote
EnzymeRules.has_easy_rule(::$(esc(fn)), $(ptys...)) = true

function EnzymeRules.augmented_primal(config, fn::FA, ::Type{RetAnnotation}, $(exprs...); kwargs...) where {RetAnnotation, FA<:Annotation{<:$(esc(fn))}, $(anns...)}
$(valtys...)

res, pullback = if RetAnnotation <: Const
(fn.val($(primals...); kwargs...), nothing)
else
$ChainRulesCore.rrule(fn.val, $(primals...); kwargs...)
end

@assert !(RetAnnotation <: Const)
res, pullback = $ChainRulesCore.rrule(fn.val, $(primals...); kwargs...)

primal = if EnzymeRules.needs_primal(config)
res
else
nothing
end

shadow = if !EnzymeRules.needs_shadow(config)
nothing
else
if EnzymeRules.width(config) == 1
shadow, byref = if !EnzymeRules.needs_shadow(config)
nothing, Val(false)
elseif !Enzyme.Compiler.guaranteed_nonactive(Core.Typeof(res))
(if EnzymeRules.width(config) == 1
Ref(Enzyme.make_zero(res))
else
ntuple(Val(EnzymeRules.width(config))) do j
Base.@_inline_meta
Ref(Enzyme.make_zero(res))
end
end, Val(true))
else
(if EnzymeRules.width(config) == 1
Enzyme.make_zero(res)
else
ntuple(Val(EnzymeRules.width(config))) do j
Base.@_inline_meta
Enzyme.make_zero(res)
end
end
end, Val(false))
end

return EnzymeRules.AugmentedReturn(primal, shadow, (shadow, pullback))
cache = (shadow, pullback, byref)
return EnzymeRules.augmented_rule_return_type(config, RetAnnotation){typeof(cache)}(primal, shadow, cache)
end

function EnzymeRules.reverse(config, fn::FA, ::Type{RetAnnotation}, tape::TapeTy, $(exprs...); kwargs...) where {RetAnnotation, TapeTy, FA<:Annotation{<:$(esc(fn))}, $(anns...)}
if !(RetAnnotation <: Const)
shadow, pullback = tape
shadow, pullback, byref = tape

tcomb = ntuple(Val(EnzymeRules.width(config))) do batch_i
Base.@_inline_meta
shad = EnzymeRules.width(config) == 1 ? shadow : shadow[batch_i]
if byref === Val(true)
shad = shad[]
end
res = pullback(shad)

for (cr, en) in zip(res, (fn, $(vals...),))
Expand Down
2 changes: 1 addition & 1 deletion lib/EnzymeCore/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "EnzymeCore"
uuid = "f151be2c-9106-41f4-ab19-57ee4f262869"
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>"]
version = "0.8.15"
version = "0.8.16"

[compat]
Adapt = "3, 4"
Expand Down
22 changes: 15 additions & 7 deletions lib/EnzymeCore/src/easyrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,11 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, input_names
end

if !seen
push!(gensetup, Expr(:(=), outexpr, nothing))
ST = $(esc(:RT))
if ST <: Tuple
ST = ST.parameters[o]
end
push!(gensetup, Expr(:(=), outexpr, ST))
seen = true
end

Expand Down Expand Up @@ -473,21 +477,25 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, input_names
end
push!(gensetup, Expr(:(=), :cache, Expr(:tuple, caches...)))

PT = EnzymeRules.primal_type(config, ($(esc(:RTA))).parameters[1])
ST = EnzymeRules.shadow_type(config, ($(esc(:RTA))).parameters[1])
AugmentedReturnType = :(EnzymeRules.AugmentedReturn{$PT,$ST,typeof(cache)})

genres = if needs_primal(config)
if needs_shadow(config)
if width(config) == 1
Expr(:call, EnzymeRules.AugmentedReturn, :Ω, :dΩ, :cache)
Expr(:call, AugmentedReturnType, :Ω, :dΩ, :cache)
else
Expr(:call, EnzymeRules.AugmentedReturn, :Ω, :dΩ, :cache)
Expr(:call, AugmentedReturnType, :Ω, :dΩ, :cache)
end
else
Expr(:call, EnzymeRules.AugmentedReturn, :Ω, nothing, :cache)
Expr(:call, AugmentedReturnType, :Ω, nothing, :cache)
end
else
if needs_shadow(config)
Expr(:call, EnzymeRules.AugmentedReturn, nothing, :dΩ, :cache)
Expr(:call, AugmentedReturnType, nothing, :dΩ, :cache)
else
Expr(:call, EnzymeRules.AugmentedReturn, nothing, nothing, :cache)
Expr(:call, AugmentedReturnType, nothing, nothing, :cache)
end
end

Expand Down Expand Up @@ -613,7 +621,7 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, input_names

if !seen
if inp_types[inum] <: Active
push!(gensetup, Expr(:(=), inexpr, nothing))
push!(gensetup, Expr(:(=), inexpr, eltype(inp_types[inum])))
else
dexpr = Expr(:call, getfield, Symbol(inp_names[inum]), 2)
if W != 1
Expand Down
Loading
Loading