Skip to content

Commit f518970

Browse files
authored
Improve dup kwarg error and add utility (#2761)
* Improve dup kwarg error and add utility * Fully tested * mixed error fix * forward hints * aug err * Reverse rule error * fix * fix tests * fix * fix * fix * fix * fixup * more improvements * more docs * fix * more fix * fix * print fixup * even better mut error message
1 parent 8bbf178 commit f518970

File tree

15 files changed

+1335
-333
lines changed

15 files changed

+1335
-333
lines changed

ext/EnzymeChainRulesCoreExt.jl

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -159,11 +159,13 @@ function Enzyme._import_rrule(fn, tys...)
159159
# ) : nothing
160160
# end)
161161

162+
ptys = []
162163
for (i, ty) in enumerate(tys)
163164
push!(nothings, :(nothing))
164165
val = Symbol("arg_$i")
165166
TA = Symbol("AN_$i")
166167
e = :($val::$TA)
168+
push!(ptys, :(::$(esc(ty))))
167169
push!(anns, :($TA <: Annotation{<:$(esc(ty))}))
168170
push!(vals, val)
169171
push!(exprs, e)
@@ -184,46 +186,57 @@ function Enzyme._import_rrule(fn, tys...)
184186
end)
185187
end
186188

187-
188189
quote
190+
EnzymeRules.has_easy_rule(::$(esc(fn)), $(ptys...)) = true
191+
189192
function EnzymeRules.augmented_primal(config, fn::FA, ::Type{RetAnnotation}, $(exprs...); kwargs...) where {RetAnnotation, FA<:Annotation{<:$(esc(fn))}, $(anns...)}
190193
$(valtys...)
191-
192-
res, pullback = if RetAnnotation <: Const
193-
(fn.val($(primals...); kwargs...), nothing)
194-
else
195-
$ChainRulesCore.rrule(fn.val, $(primals...); kwargs...)
196-
end
194+
195+
@assert !(RetAnnotation <: Const)
196+
res, pullback = $ChainRulesCore.rrule(fn.val, $(primals...); kwargs...)
197197

198198
primal = if EnzymeRules.needs_primal(config)
199199
res
200200
else
201201
nothing
202202
end
203203

204-
shadow = if !EnzymeRules.needs_shadow(config)
205-
nothing
206-
else
207-
if EnzymeRules.width(config) == 1
204+
shadow, byref = if !EnzymeRules.needs_shadow(config)
205+
nothing, Val(false)
206+
elseif !Enzyme.Compiler.guaranteed_nonactive(Core.Typeof(res))
207+
(if EnzymeRules.width(config) == 1
208+
Ref(Enzyme.make_zero(res))
209+
else
210+
ntuple(Val(EnzymeRules.width(config))) do j
211+
Base.@_inline_meta
212+
Ref(Enzyme.make_zero(res))
213+
end
214+
end, Val(true))
215+
else
216+
(if EnzymeRules.width(config) == 1
208217
Enzyme.make_zero(res)
209218
else
210219
ntuple(Val(EnzymeRules.width(config))) do j
211220
Base.@_inline_meta
212221
Enzyme.make_zero(res)
213222
end
214-
end
223+
end, Val(false))
215224
end
216225

217-
return EnzymeRules.AugmentedReturn(primal, shadow, (shadow, pullback))
226+
cache = (shadow, pullback, byref)
227+
return EnzymeRules.augmented_rule_return_type(config, RetAnnotation){typeof(cache)}(primal, shadow, cache)
218228
end
219229

220230
function EnzymeRules.reverse(config, fn::FA, ::Type{RetAnnotation}, tape::TapeTy, $(exprs...); kwargs...) where {RetAnnotation, TapeTy, FA<:Annotation{<:$(esc(fn))}, $(anns...)}
221231
if !(RetAnnotation <: Const)
222-
shadow, pullback = tape
232+
shadow, pullback, byref = tape
223233

224234
tcomb = ntuple(Val(EnzymeRules.width(config))) do batch_i
225235
Base.@_inline_meta
226236
shad = EnzymeRules.width(config) == 1 ? shadow : shadow[batch_i]
237+
if byref === Val(true)
238+
shad = shad[]
239+
end
227240
res = pullback(shad)
228241

229242
for (cr, en) in zip(res, (fn, $(vals...),))

lib/EnzymeCore/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "EnzymeCore"
22
uuid = "f151be2c-9106-41f4-ab19-57ee4f262869"
33
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>"]
4-
version = "0.8.15"
4+
version = "0.8.16"
55

66
[compat]
77
Adapt = "3, 4"

lib/EnzymeCore/src/easyrules.jl

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,11 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, input_names
236236
end
237237

238238
if !seen
239-
push!(gensetup, Expr(:(=), outexpr, nothing))
239+
ST = $(esc(:RT))
240+
if ST <: Tuple
241+
ST = ST.parameters[o]
242+
end
243+
push!(gensetup, Expr(:(=), outexpr, ST))
240244
seen = true
241245
end
242246

@@ -473,21 +477,25 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, input_names
473477
end
474478
push!(gensetup, Expr(:(=), :cache, Expr(:tuple, caches...)))
475479

480+
PT = EnzymeRules.primal_type(config, ($(esc(:RTA))).parameters[1])
481+
ST = EnzymeRules.shadow_type(config, ($(esc(:RTA))).parameters[1])
482+
AugmentedReturnType = :(EnzymeRules.AugmentedReturn{$PT,$ST,typeof(cache)})
483+
476484
genres = if needs_primal(config)
477485
if needs_shadow(config)
478486
if width(config) == 1
479-
Expr(:call, EnzymeRules.AugmentedReturn, , :dΩ, :cache)
487+
Expr(:call, AugmentedReturnType, , :dΩ, :cache)
480488
else
481-
Expr(:call, EnzymeRules.AugmentedReturn, , :dΩ, :cache)
489+
Expr(:call, AugmentedReturnType, , :dΩ, :cache)
482490
end
483491
else
484-
Expr(:call, EnzymeRules.AugmentedReturn, , nothing, :cache)
492+
Expr(:call, AugmentedReturnType, , nothing, :cache)
485493
end
486494
else
487495
if needs_shadow(config)
488-
Expr(:call, EnzymeRules.AugmentedReturn, nothing, :dΩ, :cache)
496+
Expr(:call, AugmentedReturnType, nothing, :dΩ, :cache)
489497
else
490-
Expr(:call, EnzymeRules.AugmentedReturn, nothing, nothing, :cache)
498+
Expr(:call, AugmentedReturnType, nothing, nothing, :cache)
491499
end
492500
end
493501

@@ -613,7 +621,7 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, input_names
613621

614622
if !seen
615623
if inp_types[inum] <: Active
616-
push!(gensetup, Expr(:(=), inexpr, nothing))
624+
push!(gensetup, Expr(:(=), inexpr, eltype(inp_types[inum])))
617625
else
618626
dexpr = Expr(:call, getfield, Symbol(inp_names[inum]), 2)
619627
if W != 1

0 commit comments

Comments
 (0)