33
44function has_easy_rule end
55
6- function has_easy_rule_from_sig (@nospecialize (TT);
7- world:: UInt = Base. get_world_counter (),
8- method_table:: Union{Nothing,Core.Compiler.MethodTableView} = nothing ,
9- caller:: Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult} = nothing )
6+ function has_easy_rule_from_sig (
7+ @nospecialize (TT);
8+ world:: UInt = Base. get_world_counter (),
9+ method_table:: Union{Nothing, Core.Compiler.MethodTableView} = nothing ,
10+ caller:: Union{Nothing, Core.MethodInstance, Core.Compiler.MethodLookupResult} = nothing
11+ )
1012 return isapplicable (has_easy_rule, TT; world, method_table, caller)
1113end
1214
@@ -157,7 +159,7 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, input_names
157159
158160 arg_names = Symbol[]
159161 for sname in input_names
160- rname = Symbol (String (sname)[length (" ann_" )+ 1 : end ])
162+ rname = Symbol (String (sname)[( length (" ann_" ) + 1 ) : end ])
161163 push! (arg_names, rname)
162164 push! (exprs, Expr (:(= ), rname, :($ sname. val)))
163165 end
@@ -172,7 +174,7 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, input_names
172174 if Meta. isexpr (p, :macrocall ) && p. args[1 ] == Symbol (" @Constant" )
173175 continue
174176 end
175- push! (tosum, (i , sname, p))
177+ push! (tosum, (i, sname, p))
176178 end
177179 end
178180
@@ -186,7 +188,7 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, input_names
186188 return @strip_linenos quote
187189 # _ is the input derivative w.r.t. function internals. since we do not
188190 # allow closures/functors with @easy_rule, it is always ignored
189- @generated function EnzymeCore. EnzymeRules. forward ($ (esc (:config )), $ (esc (:fn )):: Const{<:$(Core.Typeof)($f)} , :: Type{<:Annotation{$(esc(:RT))}} , $ (inputs... )) where $ (esc (:RT ))
191+ @generated function EnzymeCore. EnzymeRules. forward ($ (esc (:config )), $ (esc (:fn )):: Const{<:$(Core.Typeof)($f)} , :: Type{<:Annotation{$(esc(:RT))}} , $ (inputs... )) where { $ (esc (:RT ))}
190192 genexprs = Expr[$ (exprs... ,). .. ]
191193 gensetup = Expr[$ (setup_stmts... ,). .. ]
192194
@@ -222,7 +224,7 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, input_names
222224 dval = Expr (:call , getfield, dval, w)
223225 end
224226
225- pname = Symbol (" partial_" , string (o), " _" , string (i), " _" , sname)
227+ pname = Symbol (" partial_" , string (o), " _" , string (i), " _" , sname)
226228 if ! visited[o, i]
227229
228230 # Descend through the rule to see if any users require the original result, Ω
@@ -328,16 +330,21 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, input_names
328330 ann_names = Symbol[]
329331 arg_names = Symbol[]
330332 for (i, sname) in enumerate (input_names)
331- rname = Symbol (String (sname)[length (" ann_" )+ 1 : end ])
333+ rname = Symbol (String (sname)[( length (" ann_" ) + 1 ) : end ])
332334 push! (ann_names, sname)
333335 push! (arg_names, rname)
334336 push! (exprs, Expr (:(= ), rname, Expr (:call , getfield, sname, :(:val ))))
335- push! (revexprs, Expr (:(= ), rname,
336- Expr (:if ,
337- Expr (:call , Base. isa, :(cache[($ i)]), Nothing),
338- Expr (:call , getfield, sname, :(:val )),
339- :(cache[($ i)])
340- )))
337+ push! (
338+ revexprs, Expr (
339+ :(= ), rname,
340+ Expr (
341+ :if ,
342+ Expr (:call , Base. isa, :(cache[($ i)]), Nothing),
343+ Expr (:call , getfield, sname, :(:val )),
344+ :(cache[($ i)])
345+ )
346+ )
347+ )
341348 end
342349
343350 tosum0 = Vector{Tuple{Int, Symbol, Any}}[]
@@ -350,7 +357,7 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, input_names
350357 if Meta. isexpr (p, :macrocall ) && p. args[1 ] == Symbol (" @Constant" )
351358 continue
352359 end
353- push! (tosum, (i , sname, p))
360+ push! (tosum, (i, sname, p))
354361 end
355362 end
356363
@@ -361,11 +368,11 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, input_names
361368
362369 N = length (inputs)
363370
364- @strip_linenos quote
371+ return @strip_linenos quote
365372
366373 # _ is the input derivative w.r.t. function internals. since we do not
367374 # allow closures/functors with @scalar_rule, it is always ignored
368- @generated function EnzymeCore. EnzymeRules. augmented_primal ($ (esc (:config )), $ (esc (:fn )):: Const{<:$(Core.Typeof)($f)} , $ (esc (:RTA )):: Type{<:Annotation{$(esc(:RT))}} , $ (inputs... )) where $ (esc (:RT ))
375+ @generated function EnzymeCore. EnzymeRules. augmented_primal ($ (esc (:config )), $ (esc (:fn )):: Const{<:$(Core.Typeof)($f)} , $ (esc (:RTA )):: Type{<:Annotation{$(esc(:RT))}} , $ (inputs... )) where { $ (esc (:RT ))}
369376 genexprs = Expr[$ (exprs... ,). .. ]
370377 gensetup = Expr[]
371378
@@ -434,7 +441,7 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, input_names
434441 continue
435442 end
436443
437- if ! EnzymeRules. overwritten (config)[inum+ 1 ]
444+ if ! EnzymeRules. overwritten (config)[inum + 1 ]
438445 push! (caches, nothing )
439446 continue
440447 end
@@ -465,11 +472,14 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, input_names
465472 if used == nothing
466473 push! (caches, nothing )
467474 else
468- push! (caches, Expr (:if ,
469- used,
470- Expr (:call , Base. copy, Symbol (sym_name)),
471- nothing
472- ))
475+ push! (
476+ caches, Expr (
477+ :if ,
478+ used,
479+ Expr (:call , Base. copy, Symbol (sym_name)),
480+ nothing
481+ )
482+ )
473483 end
474484 end
475485 if needs_shadow (config)
@@ -556,7 +566,7 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, input_names
556566 # if eltype(RTA) <: Complex
557567 # push!(genexprs, Expr(:(=), :dΩ, Expr(:call, Base.conj, :dΩ)))
558568 # end
559- elseif RTA <: Type{<:Union{EnzymeCore.DuplicatedNoNeed,EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated, EnzymeCore.BatchDuplicatedNoNeed}}
569+ elseif RTA <: Type{<:Union{EnzymeCore.DuplicatedNoNeed, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated, EnzymeCore.BatchDuplicatedNoNeed}}
560570 push! (genexprs, Expr (:(= ), :dΩ , :(cache[end ])))
561571 else
562572 push! (genexprs, Expr (Base. throw, AssertionError (" Easy Rule should never be provided a constant reverse seed" )))
@@ -591,7 +601,7 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, input_names
591601 continue
592602 end
593603
594- pname = Symbol (" partial_" , string (o), " _" , string (i), " _" , sname)
604+ pname = Symbol (" partial_" , string (o), " _" , string (i), " _" , sname)
595605 if ! visited[o, i]
596606
597607 # Descend through the rule to see if any users require the original result, Ω
@@ -638,13 +648,12 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, input_names
638648 end
639649
640650
641-
642651 if ! seen && inp_types[inum] <: Active
643652 for w in 1 : W
644653 inexpr = Symbol (" insym_" , string (inum), " _" , string (w))
645654 insyms[inum, w] = inexpr
646655
647- push! (gensetup, Expr (:(= ), inexpr, Expr (:call , EnzymeCore. make_zero, Expr (:call , getfield, Symbol (inp_names[inum]), 1 ) )))
656+ push! (gensetup, Expr (:(= ), inexpr, Expr (:call , EnzymeCore. make_zero, Expr (:call , getfield, Symbol (inp_names[inum]), 1 ))))
648657 end
649658 end
650659
@@ -763,7 +772,7 @@ macro easy_rule(call, maybe_setup, partials...)
763772 rrule_expr = scalar_rrule_expr (__source__, f, call, setup_stmts, inputs, input_names, partials)
764773
765774 # Final return: building the expression to insert in the place of this macro
766- quote
775+ return quote
767776 EnzymeRules. has_easy_rule (:: Core.Typeof ($ f), $ (normal_inputs... )) = true
768777 $ (frule_expr)
769778 $ (rrule_expr)
0 commit comments