Skip to content

Commit 578e0e3

Browse files
committed
Run runic on changed files
1 parent 2924968 commit 578e0e3

14 files changed

Lines changed: 3750 additions & 3590 deletions

File tree

lib/EnzymeCore/src/easyrules.jl

Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33

44
function has_easy_rule end
55

6-
function has_easy_rule_from_sig(@nospecialize(TT);
7-
world::UInt=Base.get_world_counter(),
8-
method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing,
9-
caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing)
6+
function has_easy_rule_from_sig(
7+
@nospecialize(TT);
8+
world::UInt = Base.get_world_counter(),
9+
method_table::Union{Nothing, Core.Compiler.MethodTableView} = nothing,
10+
caller::Union{Nothing, Core.MethodInstance, Core.Compiler.MethodLookupResult} = nothing
11+
)
1012
return isapplicable(has_easy_rule, TT; world, method_table, caller)
1113
end
1214

@@ -157,7 +159,7 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, input_names
157159

158160
arg_names = Symbol[]
159161
for sname in input_names
160-
rname = Symbol(String(sname)[length("ann_")+1:end])
162+
rname = Symbol(String(sname)[(length("ann_") + 1):end])
161163
push!(arg_names, rname)
162164
push!(exprs, Expr(:(=), rname, :($sname.val)))
163165
end
@@ -172,7 +174,7 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, input_names
172174
if Meta.isexpr(p, :macrocall) && p.args[1] == Symbol("@Constant")
173175
continue
174176
end
175-
push!(tosum, (i , sname, p))
177+
push!(tosum, (i, sname, p))
176178
end
177179
end
178180

@@ -186,7 +188,7 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, input_names
186188
return @strip_linenos quote
187189
# _ is the input derivative w.r.t. function internals. since we do not
188190
# allow closures/functors with @easy_rule, it is always ignored
189-
@generated function EnzymeCore.EnzymeRules.forward($(esc(:config)), $(esc(:fn))::Const{<:$(Core.Typeof)($f)}, ::Type{<:Annotation{$(esc(:RT))}}, $(inputs...)) where $(esc(:RT))
191+
@generated function EnzymeCore.EnzymeRules.forward($(esc(:config)), $(esc(:fn))::Const{<:$(Core.Typeof)($f)}, ::Type{<:Annotation{$(esc(:RT))}}, $(inputs...)) where {$(esc(:RT))}
190192
genexprs = Expr[$(exprs...,)...]
191193
gensetup = Expr[$(setup_stmts...,)...]
192194

@@ -222,7 +224,7 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, input_names
222224
dval = Expr(:call, getfield, dval, w)
223225
end
224226

225-
pname = Symbol("partial_", string(o), "_", string(i), "_", sname)
227+
pname = Symbol("partial_", string(o), "_", string(i), "_", sname)
226228
if !visited[o, i]
227229

228230
# Descend through the rule to see if any users require the original result, Ω
@@ -328,16 +330,21 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, input_names
328330
ann_names = Symbol[]
329331
arg_names = Symbol[]
330332
for (i, sname) in enumerate(input_names)
331-
rname = Symbol(String(sname)[length("ann_")+1:end])
333+
rname = Symbol(String(sname)[(length("ann_") + 1):end])
332334
push!(ann_names, sname)
333335
push!(arg_names, rname)
334336
push!(exprs, Expr(:(=), rname, Expr(:call, getfield, sname, :(:val))))
335-
push!(revexprs, Expr(:(=), rname,
336-
Expr(:if,
337-
Expr(:call, Base.isa, :(cache[($i)]), Nothing),
338-
Expr(:call, getfield, sname, :(:val)),
339-
:(cache[($i)])
340-
)))
337+
push!(
338+
revexprs, Expr(
339+
:(=), rname,
340+
Expr(
341+
:if,
342+
Expr(:call, Base.isa, :(cache[($i)]), Nothing),
343+
Expr(:call, getfield, sname, :(:val)),
344+
:(cache[($i)])
345+
)
346+
)
347+
)
341348
end
342349

343350
tosum0 = Vector{Tuple{Int, Symbol, Any}}[]
@@ -350,7 +357,7 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, input_names
350357
if Meta.isexpr(p, :macrocall) && p.args[1] == Symbol("@Constant")
351358
continue
352359
end
353-
push!(tosum, (i , sname, p))
360+
push!(tosum, (i, sname, p))
354361
end
355362
end
356363

@@ -361,11 +368,11 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, input_names
361368

362369
N = length(inputs)
363370

364-
@strip_linenos quote
371+
return @strip_linenos quote
365372

366373
# _ is the input derivative w.r.t. function internals. since we do not
367374
# allow closures/functors with @scalar_rule, it is always ignored
368-
@generated function EnzymeCore.EnzymeRules.augmented_primal($(esc(:config)), $(esc(:fn))::Const{<:$(Core.Typeof)($f)}, $(esc(:RTA))::Type{<:Annotation{$(esc(:RT))}}, $(inputs...)) where $(esc(:RT))
375+
@generated function EnzymeCore.EnzymeRules.augmented_primal($(esc(:config)), $(esc(:fn))::Const{<:$(Core.Typeof)($f)}, $(esc(:RTA))::Type{<:Annotation{$(esc(:RT))}}, $(inputs...)) where {$(esc(:RT))}
369376
genexprs = Expr[$(exprs...,)...]
370377
gensetup = Expr[]
371378

@@ -434,7 +441,7 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, input_names
434441
continue
435442
end
436443

437-
if !EnzymeRules.overwritten(config)[inum+1]
444+
if !EnzymeRules.overwritten(config)[inum + 1]
438445
push!(caches, nothing)
439446
continue
440447
end
@@ -465,11 +472,14 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, input_names
465472
if used == nothing
466473
push!(caches, nothing)
467474
else
468-
push!(caches, Expr(:if,
469-
used,
470-
Expr(:call, Base.copy, Symbol(sym_name)),
471-
nothing
472-
))
475+
push!(
476+
caches, Expr(
477+
:if,
478+
used,
479+
Expr(:call, Base.copy, Symbol(sym_name)),
480+
nothing
481+
)
482+
)
473483
end
474484
end
475485
if needs_shadow(config)
@@ -556,7 +566,7 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, input_names
556566
#if eltype(RTA) <: Complex
557567
# push!(genexprs, Expr(:(=), :dΩ, Expr(:call, Base.conj, :dΩ)))
558568
#end
559-
elseif RTA <: Type{<:Union{EnzymeCore.DuplicatedNoNeed,EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated, EnzymeCore.BatchDuplicatedNoNeed}}
569+
elseif RTA <: Type{<:Union{EnzymeCore.DuplicatedNoNeed, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated, EnzymeCore.BatchDuplicatedNoNeed}}
560570
push!(genexprs, Expr(:(=), :dΩ, :(cache[end])))
561571
else
562572
push!(genexprs, Expr(Base.throw, AssertionError("Easy Rule should never be provided a constant reverse seed")))
@@ -591,7 +601,7 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, input_names
591601
continue
592602
end
593603

594-
pname = Symbol("partial_", string(o), "_", string(i), "_", sname)
604+
pname = Symbol("partial_", string(o), "_", string(i), "_", sname)
595605
if !visited[o, i]
596606

597607
# Descend through the rule to see if any users require the original result, Ω
@@ -638,13 +648,12 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, input_names
638648
end
639649

640650

641-
642651
if !seen && inp_types[inum] <: Active
643652
for w in 1:W
644653
inexpr = Symbol("insym_", string(inum), "_", string(w))
645654
insyms[inum, w] = inexpr
646655

647-
push!(gensetup, Expr(:(=), inexpr, Expr(:call, EnzymeCore.make_zero, Expr(:call, getfield, Symbol(inp_names[inum]), 1) )))
656+
push!(gensetup, Expr(:(=), inexpr, Expr(:call, EnzymeCore.make_zero, Expr(:call, getfield, Symbol(inp_names[inum]), 1))))
648657
end
649658
end
650659

@@ -763,7 +772,7 @@ macro easy_rule(call, maybe_setup, partials...)
763772
rrule_expr = scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, input_names, partials)
764773

765774
# Final return: building the expression to insert in the place of this macro
766-
quote
775+
return quote
767776
EnzymeRules.has_easy_rule(::Core.Typeof($f), $(normal_inputs...)) = true
768777
$(frule_expr)
769778
$(rrule_expr)

lib/EnzymeTestUtils/test/to_vec.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ end
3232

3333
@testset "array of floats" begin
3434
@testset for T in (Float32, Float64, ComplexF32, ComplexF64),
35-
sz in (2, (2, 3), (2, 3, 4))
35+
sz in (2, (2, 3), (2, 3, 4))
3636

3737
test_to_vec(randn(T, sz))
3838
end
@@ -94,7 +94,7 @@ end
9494

9595
@testset "nested array" begin
9696
@testset for T in (Float32, Float64, ComplexF32, ComplexF64),
97-
sz in (2, (2, 3), (2, 3, 4))
97+
sz in (2, (2, 3), (2, 3, 4))
9898

9999
test_to_vec([randn(T, sz) for _ in 1:10])
100100
end
@@ -121,7 +121,7 @@ end
121121
end
122122

123123
@testset "namedtuple" begin
124-
x = (x="bar", y=randn(3), z=randn(), w=TestStruct(4.0, randn(2)))
124+
x = (x = "bar", y = randn(3), z = randn(), w = TestStruct(4.0, randn(2)))
125125
test_to_vec(x)
126126
@test to_vec(x)[1] == vcat(x.y, x.z, x.w.x, x.w.a)
127127
end

0 commit comments

Comments
 (0)