Use parallel testing infrastructure from CUDA.jl#2619
Merged
Conversation
Member
vchuravy
commented
Sep 23, 2025
- steal parallel test setup from CUDA
- add Distributed.jl
- small fixes
- fix unused test
- wrap bfloat16 in testset
Contributor
|
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.diff --git a/test/core/abi.jl b/test/core/abi.jl
index e45d5f0..881f94c 100644
--- a/test/core/abi.jl
+++ b/test/core/abi.jl
@@ -17,23 +17,23 @@ end
# GhostType -> Nothing
res = autodiff(Reverse, f, Const, Const(nothing))
@test res === ((nothing,),)
-
+
res = autodiff(Enzyme.set_abi(Reverse, NonGenABI), f, Const, Const(nothing))
@test res === ((nothing,),)
-
+
@test () === autodiff(Forward, f, Const, Const(nothing))
@test () === autodiff(Enzyme.set_abi(Forward, NonGenABI), f, Const, Const(nothing))
res = autodiff(Reverse, f, Const(nothing))
@test res === ((nothing,),)
-
+
@test () === autodiff(Forward, f, Const(nothing))
res = autodiff_deferred(Reverse, Const(f), Const, Const(nothing))
@test res === ((nothing,),)
res = autodiff_deferred(Enzyme.set_abi(Reverse, NonGenABI), Const(f), Const, Const(nothing))
@test res === ((nothing,),)
-
+
@test () === autodiff_deferred(Forward, Const(f), Const, Const(nothing))
@test () === autodiff_deferred(Enzyme.set_abi(Forward, NonGenABI), Const(f), Const, Const(nothing))
@@ -52,36 +52,36 @@ end
# Complex numbers
@test_throws ErrorException autodiff(Reverse, f, Active, Active(1.5 + 0.7im))
- cres, = autodiff(ReverseHolomorphic, f, Active, Active(1.5 + 0.7im))[1]
+ cres, = autodiff(ReverseHolomorphic, f, Active, Active(1.5 + 0.7im))[1]
@test cres ≈ 1.0 + 0.0im
- cres, = autodiff(Forward, f, Duplicated, Duplicated(1.5 + 0.7im, 1.0 + 0im))
+ cres, = autodiff(Forward, f, Duplicated, Duplicated(1.5 + 0.7im, 1.0 + 0im))
@test cres ≈ 1.0 + 0.0im
@test_throws ErrorException autodiff(Reverse, f, Active(1.5 + 0.7im))
- cres, = autodiff(ReverseHolomorphic, f, Active(1.5 + 0.7im))[1]
+ cres, = autodiff(ReverseHolomorphic, f, Active(1.5 + 0.7im))[1]
@test cres ≈ 1.0 + 0.0im
- cres, = autodiff(Forward, f, Duplicated(1.5 + 0.7im, 1.0+0im))
+ cres, = autodiff(Forward, f, Duplicated(1.5 + 0.7im, 1.0 + 0im))
@test cres ≈ 1.0 + 0.0im
@test_throws ErrorException autodiff_deferred(Reverse, Const(f), Active, Active(1.5 + 0.7im))
@test_throws ErrorException autodiff_deferred(ReverseHolomorphic, Const(f), Active, Active(1.5 + 0.7im))
- cres, = autodiff_deferred(Forward, Const(f), Duplicated, Duplicated(1.5 + 0.7im, 1.0+0im))
+ cres, = autodiff_deferred(Forward, Const(f), Duplicated, Duplicated(1.5 + 0.7im, 1.0 + 0im))
@test cres ≈ 1.0 + 0.0im
# Unused singleton argument
unused(_, y) = y
_, res0 = autodiff(Reverse, unused, Active, Const(nothing), Active(2.0))[1]
@test res0 ≈ 1.0
-
+
_, res0 = autodiff(Enzyme.set_abi(Reverse, NonGenABI), unused, Active, Const(nothing), Active(2.0))[1]
@test res0 ≈ 1.0
-
+
res0, = autodiff(Forward, unused, Duplicated, Const(nothing), Duplicated(2.0, 1.0))
@test res0 ≈ 1.0
res0, = autodiff(Forward, unused, Duplicated, Const(nothing), DuplicatedNoNeed(2.0, 1.0))
@test res0 ≈ 1.0
-
+
res0, = autodiff(Enzyme.set_abi(Forward, NonGenABI), unused, Duplicated, Const(nothing), Duplicated(2.0, 1.0))
@test res0 ≈ 1.0
@@ -125,12 +125,12 @@ end
pair = autodiff_deferred(Reverse, Const(mul), Active, Active(2.0), Active(3.0))[1]
@test pair[1] ≈ 3.0
@test pair[2] ≈ 2.0
-
+
pair, orig = autodiff(ReverseWithPrimal, mul, Active(2.0), Active(3.0))
@test pair[1] ≈ 3.0
@test pair[2] ≈ 2.0
@test orig ≈ 6.0
-
+
pair, orig = autodiff_deferred(ReverseWithPrimal, Const(mul), Active, Active(2.0), Active(3.0))
@test pair[1] ≈ 3.0
@test pair[2] ≈ 2.0
@@ -148,7 +148,7 @@ end
@test res[] ≈ 6.0
@test dres[] ≈ 2.0
@test orig == Float64
-
+
res = Ref(3.0)
dres = Ref(1.0)
pair, orig = autodiff_deferred(ReverseWithPrimal, Const(inplace), Const, Duplicated(res, dres))
@@ -156,7 +156,7 @@ end
@test res[] ≈ 6.0
@test dres[] ≈ 2.0
@test orig == Float64
-
+
function inplace2(x)
x[] *= 2
return nothing
@@ -199,55 +199,55 @@ end
end
g(x) = x.qux
- res2, = autodiff(Reverse, g, Active, Active(Foo(3, 1.2)))[1]
+ res2, = autodiff(Reverse, g, Active, Active(Foo(3, 1.2)))[1]
@test res2.qux ≈ 1.0
- @test 1.0≈ first(autodiff(Forward, g, Duplicated, Duplicated(Foo(3, 1.2), Foo(0, 1.0))))
+ @test 1.0 ≈ first(autodiff(Forward, g, Duplicated, Duplicated(Foo(3, 1.2), Foo(0, 1.0))))
- res2, = autodiff(Reverse, g, Active(Foo(3, 1.2)))[1]
+ res2, = autodiff(Reverse, g, Active(Foo(3, 1.2)))[1]
@test res2.qux ≈ 1.0
- @test 1.0≈ first(autodiff(Forward, g, Duplicated(Foo(3, 1.2), Foo(0, 1.0))))
+ @test 1.0 ≈ first(autodiff(Forward, g, Duplicated(Foo(3, 1.2), Foo(0, 1.0))))
unused2(_, y) = y.qux
_, resF = autodiff(Reverse, unused2, Active, Const(nothing), Active(Foo(3, 2.0)))[1]
@test resF.qux ≈ 1.0
- @test 1.0≈ first(autodiff(Forward, unused2, Duplicated, Const(nothing), Duplicated(Foo(3, 1.2), Foo(0, 1.0))))
+ @test 1.0 ≈ first(autodiff(Forward, unused2, Duplicated, Const(nothing), Duplicated(Foo(3, 1.2), Foo(0, 1.0))))
_, resF = autodiff(Reverse, unused2, Const(nothing), Active(Foo(3, 2.0)))[1]
@test resF.qux ≈ 1.0
- @test 1.0≈ first(autodiff(Forward, unused2, Const(nothing), Duplicated(Foo(3, 1.2), Foo(0, 1.0))))
+ @test 1.0 ≈ first(autodiff(Forward, unused2, Const(nothing), Duplicated(Foo(3, 1.2), Foo(0, 1.0))))
h(x, y) = x.qux * y.qux
res3 = autodiff(Reverse, h, Active, Active(Foo(3, 1.2)), Active(Foo(5, 3.4)))[1]
@test res3[1].qux ≈ 3.4
@test res3[2].qux ≈ 1.2
- @test 7*3.4 + 9 * 1.2 ≈ first(autodiff(Forward, h, Duplicated, Duplicated(Foo(3, 1.2), Foo(0, 7.0)), Duplicated(Foo(5, 3.4), Foo(0, 9.0))))
+ @test 7 * 3.4 + 9 * 1.2 ≈ first(autodiff(Forward, h, Duplicated, Duplicated(Foo(3, 1.2), Foo(0, 7.0)), Duplicated(Foo(5, 3.4), Foo(0, 9.0))))
res3 = autodiff(Reverse, h, Active(Foo(3, 1.2)), Active(Foo(5, 3.4)))[1]
@test res3[1].qux ≈ 3.4
@test res3[2].qux ≈ 1.2
- @test 7*3.4 + 9 * 1.2 ≈ first(autodiff(Forward, h, Duplicated(Foo(3, 1.2), Foo(0, 7.0)), Duplicated(Foo(5, 3.4), Foo(0, 9.0))))
+ @test 7 * 3.4 + 9 * 1.2 ≈ first(autodiff(Forward, h, Duplicated(Foo(3, 1.2), Foo(0, 7.0)), Duplicated(Foo(5, 3.4), Foo(0, 9.0))))
caller(f, x) = f(x)
- _, res4 = autodiff(Reverse, caller, Active, Const((x)->x), Active(3.0))[1]
+ _, res4 = autodiff(Reverse, caller, Active, Const((x) -> x), Active(3.0))[1]
@test res4 ≈ 1.0
- res4, = autodiff(Forward, caller, Duplicated, Const((x)->x), Duplicated(3.0, 1.0))
+ res4, = autodiff(Forward, caller, Duplicated, Const((x) -> x), Duplicated(3.0, 1.0))
@test res4 ≈ 1.0
- _, res4 = autodiff(Reverse, caller, Const((x)->x), Active(3.0))[1]
+ _, res4 = autodiff(Reverse, caller, Const((x) -> x), Active(3.0))[1]
@test res4 ≈ 1.0
- res4, = autodiff(Forward, caller, Const((x)->x), Duplicated(3.0, 1.0))
+ res4, = autodiff(Forward, caller, Const((x) -> x), Duplicated(3.0, 1.0))
@test res4 ≈ 1.0
struct LList
- next::Union{LList,Nothing}
+ next::Union{LList, Nothing}
val::Float64
end
@@ -261,7 +261,7 @@ end
end
regular = LList(LList(nothing, 1.0), 2.0)
- shadow = LList(LList(nothing, 0.0), 0.0)
+ shadow = LList(LList(nothing, 0.0), 0.0)
ad = autodiff(Reverse, sumlist, Active, Duplicated(regular, shadow))
@test ad === ((nothing,),)
@test shadow.val ≈ 1.0 && shadow.next.val ≈ 1.0
@@ -274,7 +274,7 @@ end
dx = Ref(0.0)
dy = Ref(0.0)
n = autodiff(Reverse, mulr, Active, Duplicated(x, dx), Duplicated(y, dy))
- @test n === ((nothing,nothing),)
+ @test n === ((nothing, nothing),)
@test dx[] ≈ 3.0
@test dy[] ≈ 2.0
@@ -282,18 +282,18 @@ end
y = Ref(3.0)
dx = Ref(5.0)
dy = Ref(7.0)
- @test 5.0*3.0 + 2.0*7.0≈ first(autodiff(Forward, mulr, Duplicated, Duplicated(x, dx), Duplicated(y, dy)))
+ @test 5.0 * 3.0 + 2.0 * 7.0 ≈ first(autodiff(Forward, mulr, Duplicated, Duplicated(x, dx), Duplicated(y, dy)))
- _, mid = Enzyme.autodiff(Reverse, (fs, x) -> fs[1](x), Active, Const((x->x*x,)), Active(2.0))[1]
+ _, mid = Enzyme.autodiff(Reverse, (fs, x) -> fs[1](x), Active, Const((x -> x * x,)), Active(2.0))[1]
@test mid ≈ 4.0
- _, mid = Enzyme.autodiff(Reverse, (fs, x) -> fs[1](x), Active, Const([x->x*x]), Active(2.0))[1]
+ _, mid = Enzyme.autodiff(Reverse, (fs, x) -> fs[1](x), Active, Const([x -> x * x]), Active(2.0))[1]
@test mid ≈ 4.0
- mid, = Enzyme.autodiff(Forward, (fs, x) -> fs[1](x), Duplicated, Const((x->x*x,)), Duplicated(2.0, 1.0))
+ mid, = Enzyme.autodiff(Forward, (fs, x) -> fs[1](x), Duplicated, Const((x -> x * x,)), Duplicated(2.0, 1.0))
@test mid ≈ 4.0
- mid, = Enzyme.autodiff(Forward, (fs, x) -> fs[1](x), Duplicated, Const([x->x*x]), Duplicated(2.0, 1.0))
+ mid, = Enzyme.autodiff(Forward, (fs, x) -> fs[1](x), Duplicated, Const([x -> x * x]), Duplicated(2.0, 1.0))
@test mid ≈ 4.0
@@ -315,12 +315,12 @@ unstable_load(x) = Base.inferencebarrier(x)[1]
x = [2.7]
dx = [0.0]
Enzyme.autodiff(Reverse, Const(unstable_load), Active, Duplicated(x, dx))
- @test dx ≈ [1.0]
+ @test dx ≈ [1.0]
x = [2.7]
dx = [0.0]
Enzyme.autodiff_deferred(Reverse, Const(unstable_load), Active, Duplicated(x, dx))
- @test dx ≈ [1.0]
+ @test dx ≈ [1.0]
end
@testset "Mutable Struct ABI" begin
@@ -329,14 +329,14 @@ end
end
function sqMStruct(domain::Vector{MStruct}, x::Float32)
- @inbounds domain[1] = MStruct(x*x)
- return nothing
+ @inbounds domain[1] = MStruct(x * x)
+ return nothing
end
- orig = [MStruct(0.0)]
+ orig = [MStruct(0.0)]
shadow = [MStruct(17.0)]
Enzyme.autodiff(Forward, sqMStruct, Duplicated(orig, shadow), Duplicated(Float32(3.14), Float32(1.0)))
- @test 2.0*3.14 ≈ shadow[1].val
+ @test 2.0 * 3.14 ≈ shadow[1].val
end
@@ -360,10 +360,10 @@ end
@test 2.0 ≈ Enzyme.autodiff(Reverse, f, Active(3.0))[1][1][1]
@test 2.0 ≈ Enzyme.autodiff(Forward, f, Duplicated(3.0, 1.0))[1]
-
+
df = clo2(0.0)
@test 2.0 ≈ Enzyme.autodiff(Reverse, Duplicated(f, df), Active(3.0))[1][1]
- @test 3.0 ≈ df.V[1]
+ @test 3.0 ≈ df.V[1]
@test 2.0 * 7.0 + 3.0 * 5.0 ≈ first(Enzyme.autodiff(Forward, Duplicated(f, df), Duplicated(5.0, 7.0)))
end
@@ -379,25 +379,25 @@ end
forward, pullback0 = Enzyme.autodiff_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val((true, true, false))), Const{typeof(fwdunion)}, Duplicated, Duplicated{Vector{Float64}}, Const{Bool})
tape, primal, shadow = forward(Const(fwdunion), Duplicated(Float64[2.0], Float64[0.0]), Const(false))
- @test primal ≈ 2.0
- @test shadow[] ≈ 0.0
-
+ @test primal ≈ 2.0
+ @test shadow[] ≈ 0.0
+
forward, pullback1 = Enzyme.autodiff_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val((true, true, false))), Const{typeof(fwdunion)}, Duplicated, Duplicated{Vector{Float64}}, Const{Bool})
tape, primal, shadow = forward(Const(fwdunion), Duplicated(Float64[2.0], Float64[0.0]), Const(true))
- @test primal == Base._InitialValue()
+ @test primal == Base._InitialValue()
@test shadow == Base._InitialValue()
@test pullback0 == pullback1
-
+
forward, pullback2 = Enzyme.autodiff_thunk(ReverseSplitModified(ReverseSplitNoPrimal, Val((true, true, false))), Const{typeof(fwdunion)}, Duplicated, Duplicated{Vector{Float64}}, Const{Bool})
tape, primal, shadow = forward(Const(fwdunion), Duplicated(Float64[2.0], Float64[0.0]), Const(false))
@test primal == nothing
- @test shadow[] ≈ 0.0
+ @test shadow[] ≈ 0.0
@test pullback0 != pullback2
-
+
forward, pullback3 = Enzyme.autodiff_thunk(ReverseSplitModified(ReverseSplitNoPrimal, Val((true, true, false))), Const{typeof(fwdunion)}, Duplicated, Duplicated{Vector{Float64}}, Const{Bool})
tape, primal, shadow = forward(Const(fwdunion), Duplicated(Float64[2.0], Float64[0.0]), Const(true))
@test primal == nothing
- @test shadow == Base._InitialValue()
+ @test shadow == Base._InitialValue()
@test pullback2 == pullback3
end
@@ -407,11 +407,11 @@ end
end
struct AFoo
- x::Float64
+ x::Float64
end
function (f::AFoo)(x::Float64)
- return f.x * x
+ return f.x * x
end
@test Enzyme.autodiff(Reverse, method, Active, Const(AFoo(2.0)), Active(3.0))[1][2] ≈ 2.0
@@ -424,7 +424,7 @@ end
end
function (f::ABar)(x::Float64)
- return 2.0 * x
+ return 2.0 * x
end
@test Enzyme.autodiff(Reverse, method, Active, Const(ABar()), Active(3.0))[1][2] ≈ 2.0
@@ -438,11 +438,11 @@ end
end
function (c::RWClos)(y)
- c.x[1] *= y
- return y
+ c.x[1] *= y
+ return y
end
- c = RWClos([4.])
+ c = RWClos([4.0])
@test_throws Enzyme.Compiler.EnzymeMutabilityException autodiff(Reverse, c, Active(3.0))
@@ -454,10 +454,10 @@ end
end
function (c::RWClos2)(y)
- return y + c.x[1]
+ return y + c.x[1]
end
- c2 = RWClos2([4.])
+ c2 = RWClos2([4.0])
@test autodiff(Reverse, c2, Active(3.0))[1][1] ≈ 1.0
@test autodiff(Reverse, Const(c2), Active(3.0))[1][1] ≈ 1.0
@@ -465,9 +465,8 @@ end
end
-
@testset "Promotion" begin
- x = [1.0, 2.0]; dx_1 = [1.0, 0.0]; dx_2 = [0.0, 1.0];
+ x = [1.0, 2.0]; dx_1 = [1.0, 0.0]; dx_2 = [0.0, 1.0]
rosenbrock_inp(x) = (1.0 - x[1])^2 + 100.0 * (x[2] - x[1]^2)^2
r = autodiff(ForwardWithPrimal, rosenbrock_inp, Duplicated, BatchDuplicated(x, (dx_1, dx_2)))
@test r[2] ≈ 100.0
@@ -491,32 +490,32 @@ end
@testset "Type inference" begin
x = ones(10)
- @inferred autodiff(Enzyme.Reverse, abssum, Duplicated(x,x))
- @inferred autodiff(Enzyme.ReverseWithPrimal, abssum, Duplicated(x,x))
- @inferred autodiff(Enzyme.ReverseHolomorphic, abssum, Duplicated(x,x))
- @inferred autodiff(Enzyme.ReverseHolomorphicWithPrimal, abssum, Duplicated(x,x))
- @inferred autodiff(Enzyme.Forward, abssum, Duplicated(x,x))
- @inferred autodiff(Enzyme.ForwardWithPrimal, abssum, Duplicated, Duplicated(x,x))
- @inferred autodiff(Enzyme.Forward, abssum, Duplicated, Duplicated(x,x))
-
+ @inferred autodiff(Enzyme.Reverse, abssum, Duplicated(x, x))
+ @inferred autodiff(Enzyme.ReverseWithPrimal, abssum, Duplicated(x, x))
+ @inferred autodiff(Enzyme.ReverseHolomorphic, abssum, Duplicated(x, x))
+ @inferred autodiff(Enzyme.ReverseHolomorphicWithPrimal, abssum, Duplicated(x, x))
+ @inferred autodiff(Enzyme.Forward, abssum, Duplicated(x, x))
+ @inferred autodiff(Enzyme.ForwardWithPrimal, abssum, Duplicated, Duplicated(x, x))
+ @inferred autodiff(Enzyme.Forward, abssum, Duplicated, Duplicated(x, x))
+
@inferred gradient(Reverse, abssum, x)
@inferred gradient!(Reverse, x, abssum, x)
@inferred gradient(ReverseWithPrimal, abssum, x)
@inferred gradient!(ReverseWithPrimal, x, abssum, x)
-
+
cx = ones(10)
- @inferred autodiff(Enzyme.ReverseHolomorphic, sum, Duplicated(cx,cx))
- @inferred autodiff(Enzyme.ReverseHolomorphicWithPrimal, sum, Duplicated(cx,cx))
- @inferred autodiff(Enzyme.Forward, sum, Duplicated(cx,cx))
-
+ @inferred autodiff(Enzyme.ReverseHolomorphic, sum, Duplicated(cx, cx))
+ @inferred autodiff(Enzyme.ReverseHolomorphicWithPrimal, sum, Duplicated(cx, cx))
+ @inferred autodiff(Enzyme.Forward, sum, Duplicated(cx, cx))
+
@inferred Enzyme.make_zero(x)
@inferred Enzyme.make_zero(cx)
-
- tx = (1.0, 2.0, 3.0)
+
+ tx = (1.0, 2.0, 3.0)
@inferred Enzyme.make_zero(tx)
-
+
@inferred gradient(Reverse, abssum, tx)
@inferred gradient(Forward, abssum, tx)
@@ -531,24 +530,24 @@ end
end
function ulogistic(x)
- return x > 36 ? one(x) : 1 / (one(x) + 1/x)
+ return x > 36 ? one(x) : 1 / (one(x) + 1 / x)
end
@noinline function u_transform_tuple(x)
yfirst = ulogistic(@inbounds x[1])
- yfirst, 2
+ return yfirst, 2
end
@noinline function mytransform(ts, x)
yfirst = ulogistic(@inbounds x[1])
yrest, _ = u_transform_tuple(x)
- (yfirst, yrest)
+ return (yfirst, yrest)
end
function undefsret(trf, x)
- p = mytransform(trf, x)
- return 1/(p[2])
+ p = mytransform(trf, x)
+ return 1 / (p[2])
end
@testset "Undef sret" begin
@@ -569,30 +568,30 @@ end
return bref.x[1] .+ bref.v[1]
end
function byrefs(x, v)
- byrefg(ByRefStruct(x, v))
+ return byrefg(ByRefStruct(x, v))
end
@testset "Batched byref struct" begin
- Enzyme.autodiff(Forward, byrefs, BatchDuplicated([1.0], ([1.0], [1.0])), BatchDuplicated([1.0], ([1.0], [1.0]) ) )
+ Enzyme.autodiff(Forward, byrefs, BatchDuplicated([1.0], ([1.0], [1.0])), BatchDuplicated([1.0], ([1.0], [1.0])))
end
-
+
function myunique0()
return Vector{Float64}(undef, 0)
end
@static if VERSION < v"1.11-"
-@testset "Forward mode array construct" begin
- autodiff(Forward, myunique0, Duplicated)
-end
+ @testset "Forward mode array construct" begin
+ autodiff(Forward, myunique0, Duplicated)
+ end
else
-function myunique()
- m = Memory{Float64}.instance
- return Core.memoryref(m)
-end
-@testset "Forward mode array construct" begin
- autodiff(Forward, myunique, Duplicated)
- autodiff(Forward, myunique0, Duplicated)
-end
+ function myunique()
+ m = Memory{Float64}.instance
+ return Core.memoryref(m)
+ end
+ @testset "Forward mode array construct" begin
+ autodiff(Forward, myunique, Duplicated)
+ autodiff(Forward, myunique0, Duplicated)
+ end
end
mutable struct EmptyStruct end
@@ -627,7 +626,7 @@ function f_wb!(c)
end
@testset "Batched Writebarrier" begin
- c = (; a=[ones(4)], b=[3.1*ones(4)])
+ c = (; a = [ones(4)], b = [3.1 * ones(4)])
dc = ntuple(_ -> make_zero(c), Val(2))
autodiff(
@@ -689,7 +688,7 @@ end
struct UnionStruct
- x::Union{Float32,Nothing}
+ x::Union{Float32, Nothing}
y::Any
end
@@ -699,7 +698,7 @@ end
function make_fsq(x)
y = UnionStruct(x, [])
- Base.inferencebarrier(fsq)(y)
+ return Base.inferencebarrier(fsq)(y)
end
@testset "UnionStruct" begin
diff --git a/test/core/activity.jl b/test/core/activity.jl
index 7fe39c9..a44c3e8 100644
--- a/test/core/activity.jl
+++ b/test/core/activity.jl
@@ -13,7 +13,7 @@ end
@testset "Activity Tests" begin
@static if VERSION < v"1.11-"
else
- @test Enzyme.Compiler.active_reg(Memory{Float64}, Base.get_world_counter()) == Enzyme.Compiler.DupState
+ @test Enzyme.Compiler.active_reg(Memory{Float64}, Base.get_world_counter()) == Enzyme.Compiler.DupState
end
@test Enzyme.Compiler.active_reg(Type{Array}, Base.get_world_counter()) == Enzyme.Compiler.AnyState
@test Enzyme.Compiler.active_reg(Ints{<:Any, Integer}, Base.get_world_counter()) == Enzyme.Compiler.AnyState
@@ -24,7 +24,7 @@ end
@test Enzyme.Compiler.active_reg(Ints{Integer, Float64}, Base.get_world_counter()) == Enzyme.Compiler.ActiveState
@test Enzyme.Compiler.active_reg(MInts{Integer, Float64}, Base.get_world_counter()) == Enzyme.Compiler.DupState
- @test Enzyme.Compiler.active_reg(Tuple{Float32,Float32,Int}, Base.get_world_counter()) == Enzyme.Compiler.ActiveState
+ @test Enzyme.Compiler.active_reg(Tuple{Float32, Float32, Int}, Base.get_world_counter()) == Enzyme.Compiler.ActiveState
@test Enzyme.Compiler.active_reg(Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}, Base.get_world_counter()) == Enzyme.Compiler.AnyState
@test Enzyme.Compiler.active_reg(Base.RefValue{Float32}, Base.get_world_counter()) == Enzyme.Compiler.DupState
@test Enzyme.Compiler.active_reg(Ptr, Base.get_world_counter()) == Enzyme.Compiler.DupState
@@ -32,13 +32,13 @@ end
@test Enzyme.Compiler.active_reg(Colon, Base.get_world_counter()) == Enzyme.Compiler.AnyState
@test Enzyme.Compiler.active_reg(Symbol, Base.get_world_counter()) == Enzyme.Compiler.AnyState
@test Enzyme.Compiler.active_reg(String, Base.get_world_counter()) == Enzyme.Compiler.AnyState
- @test Enzyme.Compiler.active_reg(Tuple{Any,Int64}, Base.get_world_counter()) == Enzyme.Compiler.DupState
- @test Enzyme.Compiler.active_reg(Tuple{S,Int64} where S, Base.get_world_counter()) == Enzyme.Compiler.DupState
- @test Enzyme.Compiler.active_reg(Union{Float64,Nothing}, Base.get_world_counter()) == Enzyme.Compiler.DupState
- @test Enzyme.Compiler.active_reg(Union{Float64,Nothing}, Base.get_world_counter(), UnionSret=true) == Enzyme.Compiler.ActiveState
+ @test Enzyme.Compiler.active_reg(Tuple{Any, Int64}, Base.get_world_counter()) == Enzyme.Compiler.DupState
+ @test Enzyme.Compiler.active_reg(Tuple{S, Int64} where {S}, Base.get_world_counter()) == Enzyme.Compiler.DupState
+ @test Enzyme.Compiler.active_reg(Union{Float64, Nothing}, Base.get_world_counter()) == Enzyme.Compiler.DupState
+ @test Enzyme.Compiler.active_reg(Union{Float64, Nothing}, Base.get_world_counter(), UnionSret = true) == Enzyme.Compiler.ActiveState
@test Enzyme.Compiler.active_reg(Tuple, Base.get_world_counter()) == Enzyme.Compiler.DupState
- @test Enzyme.Compiler.active_reg(Tuple, Base.get_world_counter(); AbstractIsMixed=true) == Enzyme.Compiler.MixedState
- @test Enzyme.Compiler.active_reg(Tuple{A,A} where A, Base.get_world_counter(), AbstractIsMixed=true) == Enzyme.Compiler.MixedState
+ @test Enzyme.Compiler.active_reg(Tuple, Base.get_world_counter(); AbstractIsMixed = true) == Enzyme.Compiler.MixedState
+ @test Enzyme.Compiler.active_reg(Tuple{A, A} where {A}, Base.get_world_counter(), AbstractIsMixed = true) == Enzyme.Compiler.MixedState
- @test Enzyme.Compiler.active_reg(Tuple, Base.get_world_counter(), AbstractIsMixed=true, justActive=true) == Enzyme.Compiler.MixedState
-end
\ No newline at end of file
+ @test Enzyme.Compiler.active_reg(Tuple, Base.get_world_counter(), AbstractIsMixed = true, justActive = true) == Enzyme.Compiler.MixedState
+end
diff --git a/test/rules/internal_rules.jl b/test/rules/internal_rules.jl
index ae853e6..6659ef9 100644
--- a/test/rules/internal_rules.jl
+++ b/test/rules/internal_rules.jl
@@ -16,7 +16,7 @@ function sorterrfn(t, x)
function lt(a, b)
return a.a < b.a
end
- return first(sortperm(t, lt=lt)) * x
+ return first(sortperm(t, lt = lt)) * x
end
@testset "Sort rules" begin
@@ -27,20 +27,20 @@ end
end
@test autodiff(Forward, f1, Duplicated(2.0, 1.0))[1] == 1
- @test autodiff(Forward, f1, BatchDuplicated(2.0, (1.0, 2.0)))[1] == (var"1"=1.0, var"2"=2.0)
+ @test autodiff(Forward, f1, BatchDuplicated(2.0, (1.0, 2.0)))[1] == (var"1" = 1.0, var"2" = 2.0)
@test autodiff(Reverse, f1, Active, Active(2.0))[1][1] == 1
@test autodiff(Forward, f1, Duplicated(4.0, 1.0))[1] == 0
- @test autodiff(Forward, f1, BatchDuplicated(4.0, (1.0, 2.0)))[1] == (var"1"=0.0, var"2"=0.0)
+ @test autodiff(Forward, f1, BatchDuplicated(4.0, (1.0, 2.0)))[1] == (var"1" = 0.0, var"2" = 0.0)
@test autodiff(Reverse, f1, Active, Active(4.0))[1][1] == 0
function f2(x)
a = [1.0, -3.0, -x, -2x, x]
- sort!(a; rev=true, lt=(x, y) -> abs(x) < abs(y) || (abs(x) == abs(y) && x < y))
+ sort!(a; rev = true, lt = (x, y) -> abs(x) < abs(y) || (abs(x) == abs(y) && x < y))
return sum(a .* [1, 2, 3, 4, 5])
end
@test autodiff(Forward, f2, Duplicated(2.0, 1.0))[1] == -3
- @test autodiff(Forward, f2, BatchDuplicated(2.0, (1.0, 2.0)))[1] == (var"1"=-3.0, var"2"=-6.0)
+ @test autodiff(Forward, f2, BatchDuplicated(2.0, (1.0, 2.0)))[1] == (var"1" = -3.0, var"2" = -6.0)
@test autodiff(Reverse, f2, Active, Active(2.0))[1][1] == -3
function f3(x)
@@ -49,7 +49,7 @@ end
end
@test autodiff(Forward, f3, Duplicated(1.5, 1.0))[1] == 1.0
- @test autodiff(Forward, f3, BatchDuplicated(1.5, (1.0, 2.0)))[1] == (var"1"=1.0, var"2"=2.0)
+ @test autodiff(Forward, f3, BatchDuplicated(1.5, (1.0, 2.0)))[1] == (var"1" = 1.0, var"2" = 2.0)
@test autodiff(Reverse, f3, Active(1.5))[1][1] == 1.0
@test autodiff(Reverse, f3, Active(2.5))[1][1] == 0.0
@@ -60,7 +60,7 @@ end
end
@test autodiff(Forward, f4, Duplicated(1.5, 1.0))[1] == 1.5
- @test autodiff(Forward, f4, BatchDuplicated(1.5, (1.0, 2.0)))[1] == (var"1"=1.5, var"2"=3.0)
+ @test autodiff(Forward, f4, BatchDuplicated(1.5, (1.0, 2.0)))[1] == (var"1" = 1.5, var"2" = 3.0)
@test autodiff(Reverse, f4, Active(1.5))[1][1] == 1.5
@test autodiff(Reverse, f4, Active(4.0))[1][1] == 0.5
@test autodiff(Reverse, f4, Active(6.0))[1][1] == 0.0
@@ -136,75 +136,75 @@ end
return nothing
end
- A = rand(2,2)
+ A = rand(2, 2)
dA = [1.0 0.0; 0.0 0.0]
Enzyme.autodiff(
Enzyme.Reverse,
test2!,
- Enzyme.Duplicated(A,dA),
+ Enzyme.Duplicated(A, dA),
)
end
function tr_solv(A, B, uplo, trans, diag, idx)
- B = copy(B)
- LAPACK.trtrs!(uplo, trans, diag, A, B)
- return @inbounds B[idx]
+ B = copy(B)
+ LAPACK.trtrs!(uplo, trans, diag, A, B)
+ return @inbounds B[idx]
end
using FiniteDifferences
@testset "Reverse triangular solve" begin
- A = [0.7550523937508613 0.7979976952197996 0.29318222271218364; 0.4416768066117529 0.4335305304334933 0.8895389673238051; 0.07752980210005678 0.05978245503334367 0.4504482683752542]
- B = [0.10527381151977078 0.5450388247476627 0.3179106723232359 0.43919576779182357 0.20974326586875847; 0.7551160501548224 0.049772782182839426 0.09284926395551141 0.07862188927391855 0.17346407477062986; 0.6258040138863172 0.5928022963567454 0.24251650865340169 0.6626410383247967 0.32752198021506784]
+ A = [0.7550523937508613 0.7979976952197996 0.29318222271218364; 0.4416768066117529 0.4335305304334933 0.8895389673238051; 0.07752980210005678 0.05978245503334367 0.4504482683752542]
+ B = [0.10527381151977078 0.5450388247476627 0.3179106723232359 0.43919576779182357 0.20974326586875847; 0.7551160501548224 0.049772782182839426 0.09284926395551141 0.07862188927391855 0.17346407477062986; 0.6258040138863172 0.5928022963567454 0.24251650865340169 0.6626410383247967 0.32752198021506784]
for idx in 1:15
- for uplo in ('L', 'U')
- for diag in ('N', 'U')
- for trans in ('N', 'T')
- dA = zero(A)
- dB = zero(B)
- Enzyme.autodiff(Reverse, tr_solv, Duplicated(A, dA), Duplicated(B, dB), Const(uplo),Const(trans), Const(diag), Const(idx))
- fA = FiniteDifferences.grad(central_fdm(5, 1), A->tr_solv(A, B, uplo, trans, diag, idx), A)[1]
- fB = FiniteDifferences.grad(central_fdm(5, 1), B->tr_solv(A, B, uplo, trans, diag, idx), B)[1]
-
- if max(abs.(dA)...) >= 1e-10 || max(abs.(fA)...) >= 1e-10
- @test dA ≈ fA
- end
- if max(abs.(dB)...) >= 1e-10 || max(abs.(fB)...) >= 1e-10
- @test dB ≈ fB
- end
- end
- end
- end
+ for uplo in ('L', 'U')
+ for diag in ('N', 'U')
+ for trans in ('N', 'T')
+ dA = zero(A)
+ dB = zero(B)
+ Enzyme.autodiff(Reverse, tr_solv, Duplicated(A, dA), Duplicated(B, dB), Const(uplo), Const(trans), Const(diag), Const(idx))
+ fA = FiniteDifferences.grad(central_fdm(5, 1), A -> tr_solv(A, B, uplo, trans, diag, idx), A)[1]
+ fB = FiniteDifferences.grad(central_fdm(5, 1), B -> tr_solv(A, B, uplo, trans, diag, idx), B)[1]
+
+ if max(abs.(dA)...) >= 1.0e-10 || max(abs.(fA)...) >= 1.0e-10
+ @test dA ≈ fA
+ end
+ if max(abs.(dB)...) >= 1.0e-10 || max(abs.(fB)...) >= 1.0e-10
+ @test dB ≈ fB
+ end
+ end
+ end
+ end
end
end
function chol_lower0(x)
- c = copy(x)
- C, info = LinearAlgebra.LAPACK.potrf!('L', c)
- return c[2,1]
+ c = copy(x)
+ C, info = LinearAlgebra.LAPACK.potrf!('L', c)
+ return c[2, 1]
end
function chol_upper0(x)
- c = copy(x)
- C, info = LinearAlgebra.LAPACK.potrf!('U', c)
- return c[1,2]
+ c = copy(x)
+ C, info = LinearAlgebra.LAPACK.potrf!('U', c)
+ return c[1, 2]
end
@testset "Cholesky PotRF" begin
x = reshape([1.0, -0.10541615131279458, 0.6219810761363638, 0.293343219811946, -0.10541615131279458, 1.0, -0.05258941747718969, 0.34629296878264443, 0.6219810761363638, -0.05258941747718969, 1.0, 0.4692436399208845, 0.293343219811946, 0.34629296878264443, 0.4692436399208845, 1.0], 4, 4)
- dL = zero(x)
- dL[2, 1] = 1.0
-
- @test Enzyme.gradient(Reverse, chol_lower0, x)[1] ≈ [0.05270807565639164 0.0 0.0 0.0; 1.0000000000000024 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0]
-
- @test Enzyme.gradient(Forward, chol_lower0, x)[1] ≈ [0.05270807565639164 0.0 0.0 0.0; 1.0000000000000024 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0]
-
- @test FiniteDifferences.grad(central_fdm(5, 1), chol_lower0, x)[1] ≈ [0.05270807565639164 0.0 0.0 0.0; 1.0000000000000024 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0]
-
- @test Enzyme.gradient(Forward, chol_upper0, x)[1] ≈ [0.05270807565639728 0.9999999999999999 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0]
- @test Enzyme.gradient(Reverse, chol_upper0, x)[1] ≈ [0.05270807565639728 0.9999999999999999 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0]
- @test FiniteDifferences.grad(central_fdm(5, 1), chol_upper0, x)[1] ≈ [0.05270807565639728 0.9999999999999999 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0]
+ dL = zero(x)
+ dL[2, 1] = 1.0
+
+ @test Enzyme.gradient(Reverse, chol_lower0, x)[1] ≈ [0.05270807565639164 0.0 0.0 0.0; 1.0000000000000024 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0]
+
+ @test Enzyme.gradient(Forward, chol_lower0, x)[1] ≈ [0.05270807565639164 0.0 0.0 0.0; 1.0000000000000024 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0]
+
+ @test FiniteDifferences.grad(central_fdm(5, 1), chol_lower0, x)[1] ≈ [0.05270807565639164 0.0 0.0 0.0; 1.0000000000000024 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0]
+
+ @test Enzyme.gradient(Forward, chol_upper0, x)[1] ≈ [0.05270807565639728 0.9999999999999999 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0]
+ @test Enzyme.gradient(Reverse, chol_upper0, x)[1] ≈ [0.05270807565639728 0.9999999999999999 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0]
+ @test FiniteDifferences.grad(central_fdm(5, 1), chol_upper0, x)[1] ≈ [0.05270807565639728 0.9999999999999999 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0]
end
@@ -224,17 +224,17 @@ end
x = [1.0 0.13147601759884564 0.5282944836504488; 0.13147601759884564 1.0 0.18506733179093515; 0.5282944836504488 0.18506733179093515 1.0]
for i in 1:size(x, 1)
for j in 1:size(x, 2)
- reverse_grad = Enzyme.gradient(Reverse, x -> tchol_lower(x, i, j), x)[1]
- forward_grad = Enzyme.gradient(Forward, x -> tchol_lower(x, i, j), x)[1]
- finite_diff = FiniteDifferences.grad(central_fdm(5, 1), x -> tchol_lower(x, i, j), x)[1]
- @test reverse_grad ≈ finite_diff
- @test forward_grad ≈ finite_diff
-
- reverse_grad = Enzyme.gradient(Reverse, x -> tchol_upper(x, i, j), x)[1]
- forward_grad = Enzyme.gradient(Forward, x -> tchol_upper(x, i, j), x)[1]
- finite_diff = FiniteDifferences.grad(central_fdm(5, 1), x -> tchol_upper(x, i, j), x)[1]
- @test reverse_grad ≈ finite_diff
- @test forward_grad ≈ finite_diff
+ reverse_grad = Enzyme.gradient(Reverse, x -> tchol_lower(x, i, j), x)[1]
+ forward_grad = Enzyme.gradient(Forward, x -> tchol_lower(x, i, j), x)[1]
+ finite_diff = FiniteDifferences.grad(central_fdm(5, 1), x -> tchol_lower(x, i, j), x)[1]
+ @test reverse_grad ≈ finite_diff
+ @test forward_grad ≈ finite_diff
+
+ reverse_grad = Enzyme.gradient(Reverse, x -> tchol_upper(x, i, j), x)[1]
+ forward_grad = Enzyme.gradient(Forward, x -> tchol_upper(x, i, j), x)[1]
+ finite_diff = FiniteDifferences.grad(central_fdm(5, 1), x -> tchol_upper(x, i, j), x)[1]
+ @test reverse_grad ≈ finite_diff
+ @test forward_grad ≈ finite_diff
end
end
end
@@ -255,86 +255,86 @@ end
x = [1.0 0.13147601759884564 0.5282944836504488; 0.13147601759884564 1.0 0.18506733179093515; 0.5282944836504488 0.18506733179093515 1.0]
for i in 1:15
- B = [3.1 2.7 5.9 2.4 1.6; 7.9 8.2 1.3 9.4 5.5; 4.7 2.9 9.8 7.1 4.3]
- reverse_grad = Enzyme.gradient(Reverse, Const(B -> tcholsolv_lower(x, B, i)), B)[1]
- # forward_grad = Enzyme.gradient(Forward, B -> tcholsolv_lower(x, B, i), B)[1]
- finite_diff = FiniteDifferences.grad(central_fdm(5, 1), B -> tcholsolv_lower(x, B, i), B)[1]
- @test reverse_grad ≈ finite_diff
- # @test forward_grad ≈ finite_diff
-
- reverse_grad = Enzyme.gradient(Reverse, Const(B -> tcholsolv_upper(x, B, i)), B)[1]
- # forward_grad = Enzyme.gradient(Forward, B -> tcholsolv_upper(x, B, i), B))[1]
- finite_diff = FiniteDifferences.grad(central_fdm(5, 1), B -> tcholsolv_upper(x, B, i), B)[1]
- @test reverse_grad ≈ finite_diff
- # @test forward_grad ≈ finite_diff
-
- reverse_grad = Enzyme.gradient(Reverse, Const(x -> tcholsolv_lower(x, B, i)), x)[1]
- #forward_grad = Enzyme.gradient(Forward, x -> tcholsolv_lower(x, B, i), x)[1]
- finite_diff = FiniteDifferences.grad(central_fdm(5, 1), x -> tcholsolv_lower(x, B, i), x)[1]
- @test reverse_grad ≈ finite_diff
- #@test forward_grad ≈ finite_diff
- #
- reverse_grad = Enzyme.gradient(Reverse, Const(x -> tcholsolv_upper(x, B, i)), x)[1]
- #forward_grad = Enzyme.gradient(Forward, x -> tcholsolv_upper(x, B, i), x)[1]
- finite_diff = FiniteDifferences.grad(central_fdm(5, 1), x -> tcholsolv_upper(x, B, i), x)[1]
- @test reverse_grad ≈ finite_diff
- #@test forward_grad ≈ finite_diff
+ B = [3.1 2.7 5.9 2.4 1.6; 7.9 8.2 1.3 9.4 5.5; 4.7 2.9 9.8 7.1 4.3]
+ reverse_grad = Enzyme.gradient(Reverse, Const(B -> tcholsolv_lower(x, B, i)), B)[1]
+ # forward_grad = Enzyme.gradient(Forward, B -> tcholsolv_lower(x, B, i), B)[1]
+ finite_diff = FiniteDifferences.grad(central_fdm(5, 1), B -> tcholsolv_lower(x, B, i), B)[1]
+ @test reverse_grad ≈ finite_diff
+ # @test forward_grad ≈ finite_diff
+
+ reverse_grad = Enzyme.gradient(Reverse, Const(B -> tcholsolv_upper(x, B, i)), B)[1]
+ # forward_grad = Enzyme.gradient(Forward, B -> tcholsolv_upper(x, B, i), B))[1]
+ finite_diff = FiniteDifferences.grad(central_fdm(5, 1), B -> tcholsolv_upper(x, B, i), B)[1]
+ @test reverse_grad ≈ finite_diff
+ # @test forward_grad ≈ finite_diff
+
+ reverse_grad = Enzyme.gradient(Reverse, Const(x -> tcholsolv_lower(x, B, i)), x)[1]
+ #forward_grad = Enzyme.gradient(Forward, x -> tcholsolv_lower(x, B, i), x)[1]
+ finite_diff = FiniteDifferences.grad(central_fdm(5, 1), x -> tcholsolv_lower(x, B, i), x)[1]
+ @test reverse_grad ≈ finite_diff
+ #@test forward_grad ≈ finite_diff
+ #
+ reverse_grad = Enzyme.gradient(Reverse, Const(x -> tcholsolv_upper(x, B, i)), x)[1]
+ #forward_grad = Enzyme.gradient(Forward, x -> tcholsolv_upper(x, B, i), x)[1]
+ finite_diff = FiniteDifferences.grad(central_fdm(5, 1), x -> tcholsolv_upper(x, B, i), x)[1]
+ @test reverse_grad ≈ finite_diff
+ #@test forward_grad ≈ finite_diff
end
end
function two_blas(a, b)
- a = copy(a)
- @inline LinearAlgebra.LAPACK.potrf!('L', a)
- @inline LinearAlgebra.LAPACK.potrf!('L', b)
- return a[1,1] + b[1,1]
+ a = copy(a)
+ @inline LinearAlgebra.LAPACK.potrf!('L', a)
+ @inline LinearAlgebra.LAPACK.potrf!('L', b)
+ return a[1, 1] + b[1, 1]
end
@testset "Forward Mode runtime activity" begin
- a = [2.7 3.5; 7.4 9.2]
- da = [7.2 5.3; 4.7 2.9]
+ a = [2.7 3.5; 7.4 9.2]
+ da = [7.2 5.3; 4.7 2.9]
- b = [3.1 5.6; 13 19]
- db = [1.3 6.5; .13 .19]
-
- res = Enzyme.autodiff(Forward, two_blas, Duplicated(a, da), Duplicated(b, db))[1]
- @test res ≈ 2.5600654222812564
+ b = [3.1 5.6; 13 19]
+ db = [1.3 6.5; 0.13 0.19]
- a = [2.7 3.5; 7.4 9.2]
- da = [7.2 5.3; 4.7 2.9]
+ res = Enzyme.autodiff(Forward, two_blas, Duplicated(a, da), Duplicated(b, db))[1]
+ @test res ≈ 2.5600654222812564
- b = [3.1 5.6; 13 19]
- db = [1.3 6.5; .13 .19]
+ a = [2.7 3.5; 7.4 9.2]
+ da = [7.2 5.3; 4.7 2.9]
- res = Enzyme.autodiff(set_runtime_activity(Forward), two_blas, Duplicated(a, da), Duplicated(b, db))[1]
- @test res ≈ 2.5600654222812564
+ b = [3.1 5.6; 13 19]
+ db = [1.3 6.5; 0.13 0.19]
- a = [2.7 3.5; 7.4 9.2]
- da = [7.2 5.3; 4.7 2.9]
+ res = Enzyme.autodiff(set_runtime_activity(Forward), two_blas, Duplicated(a, da), Duplicated(b, db))[1]
+ @test res ≈ 2.5600654222812564
- b = [3.1 5.6; 13 19]
- db = [1.3 6.5; .13 .19]
+ a = [2.7 3.5; 7.4 9.2]
+ da = [7.2 5.3; 4.7 2.9]
- @test_throws Enzyme.Compiler.EnzymeNoDerivativeError Enzyme.autodiff(set_runtime_activity(Forward), two_blas, Duplicated(a, da), Duplicated(b, b))
+ b = [3.1 5.6; 13 19]
+ db = [1.3 6.5; 0.13 0.19]
+
+ @test_throws Enzyme.Compiler.EnzymeNoDerivativeError Enzyme.autodiff(set_runtime_activity(Forward), two_blas, Duplicated(a, da), Duplicated(b, b))
end
@testset "Cholesky" begin
- function symmetric_definite(n :: Int=10)
+ function symmetric_definite(n::Int = 10)
α = one(Float64)
- A = spdiagm(-1 => α * ones(n-1), 0 => 4 * ones(n), 1 => conj(α) * ones(n-1))
+ A = spdiagm(-1 => α * ones(n - 1), 0 => 4 * ones(n), 1 => conj(α) * ones(n - 1))
b = A * Float64[1:n;]
return A, b
end
function divdriver_NC(x, fact, b)
- res = fact\b
+ res = fact \ b
x .= res
return nothing
end
-
+
function ldivdriver_NC(x, fact, b)
- ldiv!(fact,b)
+ ldiv!(fact, b)
x .= b
return nothing
end
@@ -353,7 +353,7 @@ end
fact = cholesky(Symmetric(A))
divdriver_NC(x, fact, b)
end
-
+
function ldivdriver(x, A, b)
fact = cholesky(A)
ldivdriver_NC(x, fact, b)
@@ -496,13 +496,13 @@ end
dA = BatchDuplicated(A, ntuple(i -> zeros(size(A)), n))
db = BatchDuplicated(b, ntuple(i -> zeros(length(b)), n))
dx = BatchDuplicated(zeros(length(b)), ntuple(i -> seed(i), n))
- Enzyme.autodiff(
- Reverse,
- driver,
- dx,
- dA,
- db
- )
+ Enzyme.autodiff(
+ Reverse,
+ driver,
+ dx,
+ dA,
+ db
+ )
for i in 1:n
adJ[i, :] .= db.dval[i]
end
@@ -510,13 +510,13 @@ end
end
function Jdxdb(driver, A, b)
- x = A\b
+ x = A \ b
dA = zeros(size(A))
db = zeros(length(b))
J = zeros(length(b), length(b))
for i in 1:length(b)
db[i] = 1.0
- dx = A\db
+ dx = A \ db
db[i] = 0.0
J[i, :] = dx
end
@@ -528,35 +528,35 @@ end
J = zeros(length(b), length(b))
for i in 1:length(b)
db[i] = 1.0
- dx = A\db
+ dx = A \ db
db[i] = 0.0
J[i, :] = dx
end
return J
end
-
+
@testset "Testing $op" for (op, driver, driver_NC) in (
- (:\, divdriver, divdriver_NC),
- (:\, divdriver_herm, divdriver_NC),
- (:\, divdriver_sym, divdriver_NC),
- (:ldiv!, ldivdriver, ldivdriver_NC),
- (:ldiv!, ldivdriver_herm, ldivdriver_NC),
- (:ldiv!, ldivdriver_sym, ldivdriver_NC)
- )
+ (:\, divdriver, divdriver_NC),
+ (:\, divdriver_herm, divdriver_NC),
+ (:\, divdriver_sym, divdriver_NC),
+ (:ldiv!, ldivdriver, ldivdriver_NC),
+ (:ldiv!, ldivdriver_herm, ldivdriver_NC),
+ (:ldiv!, ldivdriver_sym, ldivdriver_NC),
+ )
A, b = symmetric_definite(10)
n = length(b)
A = Matrix(A)
x = zeros(n)
x = driver(x, A, b)
- fdm = forward_fdm(2, 1);
+ fdm = forward_fdm(2, 1)
function b_one(b)
_x = zeros(length(b))
- driver(_x,A,b)
+ driver(_x, A, b)
return _x
end
- fdJ = op==:\ ? FiniteDifferences.jacobian(fdm, b_one, copy(b))[1] : nothing
+ fdJ = op == :\ ? FiniteDifferences.jacobian(fdm, b_one, copy(b))[1] : nothing
fwdJ = fwdJdxdb(driver, A, b)
revJ = revJdxdb(driver, A, b)
batchedrevJ = batchedrevJdxdb(driver, A, b)
@@ -587,35 +587,35 @@ end
end
A = [1.3 0.5; 0.5 1.5]
- b = [1., 2.]
+ b = [1.0, 2.0]
dA = zero(A)
Enzyme.autodiff(Reverse, h, Active, Duplicated(A, dA), Const(b))
# dA_fwd = Enzyme.gradient(Forward, A->h(A, b), A)[1]
- dA_fd = FiniteDifferences.grad(central_fdm(5, 1), A->h(A, b), A)[1]
+ dA_fd = FiniteDifferences.grad(central_fdm(5, 1), A -> h(A, b), A)[1]
@test isapprox(dA, dA_fd)
end
end
function chol_upper(x)
- x = reshape(x, 4, 4)
- x = parent(cholesky(Hermitian(x)).U)
- x = convert(typeof(x), UpperTriangular(x))
- return x[1,2]
+ x = reshape(x, 4, 4)
+ x = parent(cholesky(Hermitian(x)).U)
+ x = convert(typeof(x), UpperTriangular(x))
+ return x[1, 2]
end
@testset "Cholesky upper triangular v1" begin
- x = [1.0, -0.10541615131279458, 0.6219810761363638, 0.293343219811946, -0.10541615131279458, 1.0, -0.05258941747718969, 0.34629296878264443, 0.6219810761363638, -0.05258941747718969, 1.0, 0.4692436399208845, 0.293343219811946, 0.34629296878264443, 0.4692436399208845, 1.0]
+ x = [1.0, -0.10541615131279458, 0.6219810761363638, 0.293343219811946, -0.10541615131279458, 1.0, -0.05258941747718969, 0.34629296878264443, 0.6219810761363638, -0.05258941747718969, 1.0, 0.4692436399208845, 0.293343219811946, 0.34629296878264443, 0.4692436399208845, 1.0]
@test Enzyme.gradient(Forward, chol_upper, x)[1] ≈ [0.05270807565639728, 0.0, 0.0, 0.0, 0.9999999999999999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
@test Enzyme.gradient(Reverse, chol_upper, x)[1] ≈ [0.05270807565639728, 0.0, 0.0, 0.0, 0.9999999999999999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
end
-
+
using EnzymeTestUtils
@testset "Linear solve for triangular matrices" begin
@testset for T in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular),
- TE in (Float64, ComplexF64), sizeB in ((3,), (3, 3))
+ TE in (Float64, ComplexF64), sizeB in ((3,), (3, 3))
n = sizeB[1]
M = rand(TE, n, n)
B = rand(TE, sizeB...)
@@ -625,8 +625,8 @@ using EnzymeTestUtils
_A = T(A)
f!(Y, A, B, ::T) where {T} = ldiv!(Y, T(A), B)
for TY in (Const, Duplicated, BatchDuplicated),
- TM in (Const, Duplicated, BatchDuplicated),
- TB in (Const, Duplicated, BatchDuplicated)
+ TM in (Const, Duplicated, BatchDuplicated),
+ TB in (Const, Duplicated, BatchDuplicated)
are_activities_compatible(Const, TY, TM, TB) || continue
test_reverse(f!, TY, (Y, TY), (M, TM), (B, TB), (_A, Const); atol = 1.0e-5, rtol = 1.0e-5)
end
@@ -688,16 +688,16 @@ end
@test Enzyme.autodiff(Forward, f3, Duplicated(0.1, 1.0)) == (15.0,)
@test Enzyme.autodiff(Forward, f1, BatchDuplicated(0.1, (1.0, 2.0))) ==
- ((var"1" = 374.99999999999994, var"2" = 749.9999999999999),)
+ ((var"1" = 374.99999999999994, var"2" = 749.9999999999999),)
@test Enzyme.autodiff(Forward, f2, BatchDuplicated(0.1, (1.0, 2.0))) ==
- ((var"1"=25.0, var"2"=50.0),)
+ ((var"1" = 25.0, var"2" = 50.0),)
@test Enzyme.autodiff(Forward, f3, BatchDuplicated(0.1, (1.0, 2.0))) ==
- ((var"1"=15.0, var"2"=30.0),)
+ ((var"1" = 15.0, var"2" = 30.0),)
+
+ @test Enzyme.autodiff(Reverse, f1, Active, Active(0.1)) == ((375.0,),)
+ @test Enzyme.autodiff(Reverse, f2, Active, Active(0.1)) == ((25.0,),)
+ @test Enzyme.autodiff(Reverse, f3, Active, Active(0.1)) == ((15.0,),)
- @test Enzyme.autodiff(Reverse, f1, Active, Active(0.1)) == ((375.0,),)
- @test Enzyme.autodiff(Reverse, f2, Active, Active(0.1)) == ((25.0,),)
- @test Enzyme.autodiff(Reverse, f3, Active, Active(0.1)) == ((15.0,),)
-
# Batch active rule isnt setup
# @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f1(x); nothing end, Active(0.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((375.0,750.0)),)
# @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f2(x); nothing end, Active(0.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((25.0,50.0)),)
@@ -731,19 +731,19 @@ end
@test Enzyme.autodiff(Forward, f4, Duplicated(0.12, 1.0)) == (0,)
@test Enzyme.autodiff(Forward, f1, BatchDuplicated(0.1, (1.0, 2.0))) ==
- ((var"1"=25.0, var"2"=50.0),)
+ ((var"1" = 25.0, var"2" = 50.0),)
@test Enzyme.autodiff(Forward, f2, BatchDuplicated(0.1, (1.0, 2.0))) ==
- ((var"1"=25.0, var"2"=50.0),)
+ ((var"1" = 25.0, var"2" = 50.0),)
@test Enzyme.autodiff(Forward, f3, BatchDuplicated(0.1, (1.0, 2.0))) ==
- ((var"1"=75.0, var"2"=150.0),)
+ ((var"1" = 75.0, var"2" = 150.0),)
@test Enzyme.autodiff(Forward, f4, BatchDuplicated(0.12, (1.0, 2.0))) ==
- ((var"1"=0.0, var"2"=0.0),)
+ ((var"1" = 0.0, var"2" = 0.0),)
+
+ @test Enzyme.autodiff(Reverse, f1, Active, Active(0.1)) == ((25.0,),)
+ @test Enzyme.autodiff(Reverse, f2, Active, Active(0.1)) == ((25.0,),)
+ @test Enzyme.autodiff(Reverse, f3, Active, Active(0.1)) == ((75.0,),)
+ @test Enzyme.autodiff(Reverse, f4, Active, Active(0.12)) == ((0.0,),)
- @test Enzyme.autodiff(Reverse, f1, Active, Active(0.1)) == ((25.0,),)
- @test Enzyme.autodiff(Reverse, f2, Active, Active(0.1)) == ((25.0,),)
- @test Enzyme.autodiff(Reverse, f3, Active, Active(0.1)) == ((75.0,),)
- @test Enzyme.autodiff(Reverse, f4, Active, Active(0.12)) == ((0.0,),)
-
# Batch active rule isnt setup
# @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f1(x); nothing end, Active(1.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((25.0,50.0)),)
# @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f2(x); nothing end, Active(0.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((25.0,50.0)),)
@@ -760,31 +760,33 @@ function test_sparse(M, v, α, β)
end
- for Tret in (Duplicated, BatchDuplicated), TM in (Const, Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated),
- Tα in (Const, Active), Tβ in (Const, Active)
+ for Tret in (Duplicated, BatchDuplicated), TM in (Const, Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated),
+ Tα in (Const, Active), Tβ in (Const, Active)
are_activities_compatible(Tret, Tret, TM, Tv, Tα, Tβ) || continue
test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, TM), (v, Tv), (α, Tα), (β, Tβ))
end
- for Tret in (Duplicated, BatchDuplicated), TM in (Const, Duplicated, BatchDuplicated),
- Tv in (Const, Duplicated, BatchDuplicated), bα in (true, false), bβ in (true, false)
+ for Tret in (Duplicated, BatchDuplicated), TM in (Const, Duplicated, BatchDuplicated),
+ Tv in (Const, Duplicated, BatchDuplicated), bα in (true, false), bβ in (true, false)
are_activities_compatible(Tret, Tret, TM, Tv) || continue
test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (bα, Const), (bβ, Const))
end
- test_reverse(LinearAlgebra.mul!, Const, (C, Const), (M, Const), (v, Const), (α, Active), (β, Active))
+ return test_reverse(LinearAlgebra.mul!, Const, (C, Const), (M, Const), (v, Const), (α, Active), (β, Active))
end
@testset "SparseArrays spmatvec reverse rule" begin
Ts = ComplexF64
- M0 = [0.0 1.50614;
- 0.0 -0.988357;
- 0.0 0.0]
+ M0 = [
+ 0.0 1.50614;
+ 0.0 -0.988357;
+ 0.0 0.0
+ ]
- M = SparseMatrixCSC((M0 .+ 2im*M0))
+ M = SparseMatrixCSC((M0 .+ 2im * M0))
v = rand(Ts, 2)
α = rand(Ts)
β = rand(Ts)
@@ -814,17 +816,19 @@ end
@testset "SparseArrays spmatmat reverse rule" begin
Ts = ComplexF64
- M0 = [0.0 1.50614;
- 0.0 -0.988357;
- 0.0 0.0]
+ M0 = [
+ 0.0 1.50614;
+ 0.0 -0.988357;
+ 0.0 0.0
+ ]
- M = SparseMatrixCSC((M0 .+ 2im*M0))
+ M = SparseMatrixCSC((M0 .+ 2im * M0))
v = rand(Ts, 2, 2)
α = rand(Ts)
β = rand(Ts)
- # Now all the code paths are already tested in the vector case so we just make sure that
+ # Now all the code paths are already tested in the vector case so we just make sure that
# general matrix multiplication works
test_sparse(M, v, α, β)
diff --git a/test/rules/kwrrules.jl b/test/rules/kwrrules.jl
index 044273d..9d869b1 100644
--- a/test/rules/kwrrules.jl
+++ b/test/rules/kwrrules.jl
@@ -5,7 +5,7 @@ using Enzyme.EnzymeRules
using Test
function f_kw(x; kwargs...)
- x^2
+ return x^2
end
import .EnzymeRules: augmented_primal, reverse
@@ -24,22 +24,22 @@ function reverse(config::RevConfigWidth{1}, ::Const{typeof(f_kw)}, dret::Active,
# TODO do we want kwargs here?
@assert length(overwritten(config)) == 2
if needs_primal(config)
- return (10+2*x.val*dret.val,)
+ return (10 + 2 * x.val * dret.val,)
else
- return (100+2*x.val*dret.val,)
+ return (100 + 2 * x.val * dret.val,)
end
end
@test Enzyme.autodiff(Enzyme.Reverse, f_kw, Active(2.0))[1][1] ≈ 104.0
# TODO: autodiff wrapper with kwargs support
-g(x, y) = f_kw(x; val=y)
+g(x, y) = f_kw(x; val = y)
@test Enzyme.autodiff(Enzyme.Reverse, g, Active(2.0), Const(42.0))[1][1] ≈ 104.0
function f_kw2(x; kwargs...)
- x^2
+ return x^2
end
function augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(f_kw2)}, ::Type{<:Active}, x::Active)
@@ -52,22 +52,22 @@ end
function reverse(config::RevConfigWidth{1}, ::Const{typeof(f_kw2)}, dret::Active, tape, x::Active)
if needs_primal(config)
- return (10+2*x.val*dret.val,)
+ return (10 + 2 * x.val * dret.val,)
else
- return (100+2*x.val*dret.val,)
+ return (100 + 2 * x.val * dret.val,)
end
end
# Test that this errors due to missing kwargs in rule definition
-g2(x, y) = f_kw2(x; val=y)
+g2(x, y) = f_kw2(x; val = y)
@test_throws MethodError autodiff(Reverse, g2, Active(2.0), Const(42.0))[1][1]
-function f_kw3(x; val=nothing)
- x^2
+function f_kw3(x; val = nothing)
+ return x^2
end
-function augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(f_kw3)}, ::Type{<:Active}, x::Active; dval=nothing)
+function augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(f_kw3)}, ::Type{<:Active}, x::Active; dval = nothing)
if needs_primal(config)
return AugmentedReturn(func.val(x.val), nothing, nothing)
else
@@ -75,20 +75,20 @@ function augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(f_kw3)},
end
end
-function reverse(config::RevConfigWidth{1}, ::Const{typeof(f_kw3)}, dret::Active, tape, x::Active; dval=nothing)
+function reverse(config::RevConfigWidth{1}, ::Const{typeof(f_kw3)}, dret::Active, tape, x::Active; dval = nothing)
if needs_primal(config)
- return (10+2*x.val*dret.val,)
+ return (10 + 2 * x.val * dret.val,)
else
- return (100+2*x.val*dret.val,)
+ return (100 + 2 * x.val * dret.val,)
end
end
# Test that this errors due to missing kwargs in rule definition
-g3(x, y) = f_kw3(x; val=y)
+g3(x, y) = f_kw3(x; val = y)
@test_throws MethodError autodiff(Reverse, g3, Active(2.0), Const(42.0))[1][1]
-function f_kw4(x; y=2.0)
- x*y
+function f_kw4(x; y = 2.0)
+ return x * y
end
function augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(f_kw4)}, ::Type{<:Active}, x::Active; y)
@@ -102,7 +102,7 @@ end
function reverse(config::RevConfigWidth{1}, ::Const{typeof(f_kw4)}, dret::Active, tape, x::Active; y)
@assert length(overwritten(config)) == 2
- return (1000*y+2*x.val*dret.val,)
+ return (1000 * y + 2 * x.val * dret.val,)
end
# Test that this errors due to missing kwargs in rule definition
@@ -115,18 +115,20 @@ struct Closure2
str::String
end
-function (cl::Closure2)(x; width=7)
+function (cl::Closure2)(x; width = 7)
val = cl.v[1] * x * width
cl.v[1] = 0.0
return val
end
function wrapclos(cl, x)
- cl(x; width=9)
+ return cl(x; width = 9)
end
-function EnzymeRules.augmented_primal(config::RevConfigWidth{1}, func::Const{Closure2},
- ::Type{<:Active}, args::Vararg{Active,N}; width=7) where {N}
+function EnzymeRules.augmented_primal(
+ config::RevConfigWidth{1}, func::Const{Closure2},
+ ::Type{<:Active}, args::Vararg{Active, N}; width = 7
+ ) where {N}
vec = copy(func.val.v)
pval = func.val(args[1].val)
primal = if EnzymeRules.needs_primal(config)
@@ -137,8 +139,10 @@ function EnzymeRules.augmented_primal(config::RevConfigWidth{1}, func::Const{Clo
return AugmentedReturn(primal, nothing, vec)
end
-function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{Closure2},
- dret::Active, tape, args::Vararg{Active,N}; width=7) where {N}
+function EnzymeRules.reverse(
+ config::RevConfigWidth{1}, func::Const{Closure2},
+ dret::Active, tape, args::Vararg{Active, N}; width = 7
+ ) where {N}
dargs = ntuple(Val(N)) do i
7 * args[1].val * dret.val + tape[1] * 1000 + width * 100000
end
diff --git a/test/rules/kwrules.jl b/test/rules/kwrules.jl
index 9761c23..f4feb6b 100644
--- a/test/rules/kwrules.jl
+++ b/test/rules/kwrules.jl
@@ -7,55 +7,55 @@ using Test
import .EnzymeRules: forward
function f_kw(x; kwargs...)
- x^2
+ return x^2
end
function forward(config, ::Const{typeof(f_kw)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated; kwargs...)
- return 10+2*x.val*x.dval
+ return 10 + 2 * x.val * x.dval
end
@test autodiff(Forward, f_kw, Duplicated(2.0, 1.0))[1] ≈ 14.0
# TODO: autodiff wrapper with kwargs support
-g(x, y) = f_kw(x; val=y)
+g(x, y) = f_kw(x; val = y)
@test autodiff(Forward, g, Duplicated(2.0, 1.0), Const(42.0))[1] ≈ 14.0
function f_kw2(x; kwargs...)
- x^2
+ return x^2
end
function forward(config, ::Const{typeof(f_kw2)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated)
- return 10+2*x.val*x.dval
+ return 10 + 2 * x.val * x.dval
end
# Test that this errors due to missing kwargs in rule definition
-g2(x, y) = f_kw2(x; val=y)
+g2(x, y) = f_kw2(x; val = y)
@test_throws MethodError autodiff(Forward, g2, Duplicated(2.0, 1.0), Const(42.0))[1] ≈ 14.0
-function f_kw3(x; val=nothing)
- x^2
+function f_kw3(x; val = nothing)
+ return x^2
end
-function forward(config, ::Const{typeof(f_kw3)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated; dval=nothing)
- return 10+2*x.val*x.dval
+function forward(config, ::Const{typeof(f_kw3)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated; dval = nothing)
+ return 10 + 2 * x.val * x.dval
end
# Test that this errors due to missing kwargs in rule definition
-g3(x, y) = f_kw3(x; val=y)
+g3(x, y) = f_kw3(x; val = y)
@test_throws MethodError autodiff(Forward, g3, Duplicated(2.0, 1.0), Const(42.0))[1] ≈ 14.0
-function f_kw4(x; y=2.0)
- x*y
+function f_kw4(x; y = 2.0)
+ return x * y
end
function forward(config, ::Const{typeof(f_kw4)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated; y)
- return 1000*y+2*x.val*x.dval
+ return 1000 * y + 2 * x.val * x.dval
end
# Test that this errors due to missing kwargs in rule definition
g4(x, y) = f_kw4(x; y)
-@test autodiff(Forward, g4, Duplicated(2.0, 1.0), Const(42.0))[1] ≈ 42004.0
+@test autodiff(Forward, g4, Duplicated(2.0, 1.0), Const(42.0))[1] ≈ 42004.0
@test_throws Enzyme.Compiler.EnzymeRuntimeException autodiff(Forward, g4, Duplicated(2.0, 1.0), Duplicated(42.0, 1.0))[1]
end # KWForwardRules
diff --git a/test/rules/rrules.jl b/test/rules/rrules.jl
index 41c84fe..562ce01 100644
--- a/test/rules/rrules.jl
+++ b/test/rules/rrules.jl
@@ -8,8 +8,8 @@ using Test
f(x) = x^2
function f_ip(x)
- x[1] *= x[1]
- return nothing
+ x[1] *= x[1]
+ return nothing
end
import .EnzymeRules: augmented_primal, reverse, Annotation, has_rrule_from_sig
@@ -25,9 +25,9 @@ end
function reverse(config::RevConfigWidth{1}, ::Const{typeof(f)}, dret::Active, tape, x::Active)
if needs_primal(config)
- return (10+2*x.val*dret.val,)
+ return (10 + 2 * x.val * dret.val,)
else
- return (100+2*x.val*dret.val,)
+ return (100 + 2 * x.val * dret.val,)
end
end
@@ -50,13 +50,13 @@ end
@testset "Custom Reverse Rules" begin
@test Enzyme.autodiff(Enzyme.Reverse, f, Active(2.0))[1][1] ≈ 104.0
- @test Enzyme.autodiff(Enzyme.Reverse, x->f(x)^2, Active(2.0))[1][1] ≈ 42.0
+ @test Enzyme.autodiff(Enzyme.Reverse, x -> f(x)^2, Active(2.0))[1][1] ≈ 42.0
x = [2.0]
dx = [1.0]
-
+
Enzyme.autodiff(Enzyme.Reverse, f_ip, Duplicated(x, dx))
-
+
@test x ≈ [4.0]
@test dx ≈ [102.0]
end
@@ -70,12 +70,12 @@ function augmented_primal(config::RevConfigWidth{2}, func::Const{typeof(f)}, ::T
end
function reverse(config::RevConfigWidth{2}, ::Const{typeof(f)}, dret::Active, tape, x::Active)
- return ((10+2*x.val*dret.val,100+2*x.val*dret.val,))
+ return ((10 + 2 * x.val * dret.val, 100 + 2 * x.val * dret.val))
end
function fip_2(out, in)
out[] = f(in[])
- nothing
+ return nothing
end
@testset "Batch ActiveReverse Rules" begin
@@ -88,19 +88,19 @@ end
end
function alloc_sq(x)
- return Ref(x*x)
+ return Ref(x * x)
end
function h(x)
- alloc_sq(x)[]
+ return alloc_sq(x)[]
end
function h2(x)
y = alloc_sq(x)[]
- y * y
+ return y * y
end
-function augmented_primal(config, func::Const{typeof(alloc_sq)}, ::Type{<:Annotation}, x::Active{T}) where T
+function augmented_primal(config, func::Const{typeof(alloc_sq)}, ::Type{<:Annotation}, x::Active{T}) where {T}
primal = nothing
# primal
if needs_primal(config)
@@ -114,22 +114,22 @@ function augmented_primal(config, func::Const{typeof(alloc_sq)}, ::Type{<:Annota
if needs_shadow(config)
shadow = shadref
end
-
+
return AugmentedReturn(primal, shadow, shadref)
end
function reverse(config, ::Const{typeof(alloc_sq)}, ::Type{<:Annotation}, tape, x::Active)
if needs_primal(config)
- return (10*2*x.val*tape[],)
+ return (10 * 2 * x.val * tape[],)
else
- return (1000*2*x.val*tape[],)
+ return (1000 * 2 * x.val * tape[],)
end
end
@testset "Shadow" begin
@test Enzyme.autodiff(Reverse, h, Active(3.0)) == ((6000.0,),)
- @test Enzyme.autodiff(ReverseWithPrimal, h, Active(3.0)) == ((60.0,), 9.0)
- @test Enzyme.autodiff(Reverse, h2, Active(3.0)) == ((1080.0,),)
+ @test Enzyme.autodiff(ReverseWithPrimal, h, Active(3.0)) == ((60.0,), 9.0)
+ @test Enzyme.autodiff(Reverse, h2, Active(3.0)) == ((1080.0,),)
end
q(x) = x^2
@@ -146,9 +146,9 @@ function reverse(config::RevConfigWidth{1}, ::Const{typeof(q)}, dret::Active, ta
@test tape[1][] == 2.0
@test tape[2][] == 3.4
if needs_primal(config)
- return (10+2*x.val*dret.val,)
+ return (10 + 2 * x.val * dret.val,)
else
- return (100+2*x.val*dret.val,)
+ return (100 + 2 * x.val * dret.val,)
end
end
@@ -159,11 +159,11 @@ end
foo(x::Complex) = 2x
function EnzymeRules.augmented_primal(
- config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof(foo)},
- ::Type{<:Active},
- x
-)
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof(foo)},
+ ::Type{<:Active},
+ x
+ )
r = func.val(x.val)
if EnzymeRules.needs_primal(config)
primal = func.val(x.val)
@@ -180,35 +180,35 @@ function EnzymeRules.augmented_primal(
end
function EnzymeRules.reverse(
- config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof(foo)},
- dret,
- tape,
- y
-)
- return (dret.val+13.0im,)
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof(foo)},
+ dret,
+ tape,
+ y
+ )
+ return (dret.val + 13.0im,)
end
@testset "Complex values" begin
fwd, rev = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(foo)}, Active, Active{ComplexF64})
- z = 1.0+3im
+ z = 1.0 + 3im
grad_u = rev(Const(foo), Active(z), 1.0 + 0.0im, fwd(Const(foo), Active(z))[1])[1][1]
- @test grad_u ≈ 1.0+13.0im
+ @test grad_u ≈ 1.0 + 13.0im
end
_scalar_dot(x, y) = conj(x) * y
-function _dot(X::StridedArray{T}, Y::StridedArray{T}) where {T<:Union{Real,Complex}}
+function _dot(X::StridedArray{T}, Y::StridedArray{T}) where {T <: Union{Real, Complex}}
return mapreduce(_scalar_dot, +, X, Y)
end
function augmented_primal(
- config::RevConfigWidth{1},
- func::Const{typeof(_dot)},
- ::Type{<:Union{Const,Active}},
- X::Duplicated{<:StridedArray{T}},
- Y::Duplicated{<:StridedArray{T}},
-) where {T<:Union{Real,Complex}}
+ config::RevConfigWidth{1},
+ func::Const{typeof(_dot)},
+ ::Type{<:Union{Const, Active}},
+ X::Duplicated{<:StridedArray{T}},
+ Y::Duplicated{<:StridedArray{T}},
+ ) where {T <: Union{Real, Complex}}
r = func.val(X.val, Y.val)
primal = needs_primal(config) ? r : nothing
shadow = needs_shadow(config) ? zero(r) : nothing
@@ -217,13 +217,13 @@ function augmented_primal(
end
function reverse(
- ::RevConfigWidth{1},
- ::Const{typeof(_dot)},
- dret::Union{Active,Type{<:Const}},
- tape,
- X::Duplicated{<:StridedArray{T}},
- Y::Duplicated{<:StridedArray{T}},
-) where {T<:Union{Real,Complex}}
+ ::RevConfigWidth{1},
+ ::Const{typeof(_dot)},
+ dret::Union{Active, Type{<:Const}},
+ tape,
+ X::Duplicated{<:StridedArray{T}},
+ Y::Duplicated{<:StridedArray{T}},
+ ) where {T <: Union{Real, Complex}}
if !(dret isa Type{<:Const})
Xtape, Ytape = tape
X.dval .+= dret.val .* Ytape
@@ -236,8 +236,8 @@ end
@testset "Correct primal computation for custom `dot`" begin
@testset for T in (Float32, Float64, ComplexF32, ComplexF64)
n = 10
- x, y = randn(T, n), randn(T, n);
- ∂x, ∂y = map(zero, (x, y));
+ x, y = randn(T, n), randn(T, n)
+ ∂x, ∂y = map(zero, (x, y))
val_exp = _dot(x, y)
_, val = autodiff(
ReverseWithPrimal, _dot, Const, Duplicated(x, ∂x), Duplicated(y, ∂y),
@@ -246,9 +246,9 @@ end
end
end
-function cmyfunc!(y, x)
+function cmyfunc!(y, x)
y .= x
- nothing
+ return nothing
end
function cprimal(x0, y0)
@@ -261,17 +261,21 @@ function cprimal(x0, y0)
return @inbounds x[1]
end
-function EnzymeRules.augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(cmyfunc!)}, ::Type{<:Const},
- y::Duplicated, x::Duplicated)
+function EnzymeRules.augmented_primal(
+ config::RevConfigWidth{1}, func::Const{typeof(cmyfunc!)}, ::Type{<:Const},
+ y::Duplicated, x::Duplicated
+ )
cmyfunc!(y.val, x.val)
tape = (copy(x.val), 3)
return AugmentedReturn(nothing, nothing, tape)
end
const seen = Set()
-function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(cmyfunc!)}, ::Type{<:Const}, tape,
- y::Duplicated, x::Duplicated)
- xval = tape[1]
+function EnzymeRules.reverse(
+ config::RevConfigWidth{1}, func::Const{typeof(cmyfunc!)}, ::Type{<:Const}, tape,
+ y::Duplicated, x::Duplicated
+ )
+ xval = tape[1]
p = pointer(xval)
@assert !in(p, seen)
push!(seen, p)
@@ -287,11 +291,13 @@ end
end
function remultr(arg)
- arg * arg
+ return arg * arg
end
-function EnzymeRules.augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(remultr)},
- ::Type{<:Active}, args::Vararg{Active,N}) where {N}
+function EnzymeRules.augmented_primal(
+ config::RevConfigWidth{1}, func::Const{typeof(remultr)},
+ ::Type{<:Active}, args::Vararg{Active, N}
+ ) where {N}
primal = if EnzymeRules.needs_primal(config)
func.val(...*[Comment body truncated]* |
7d5ebfa to
cb17c75
Compare
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## main #2619 +/- ##
==========================================
- Coverage 74.44% 74.33% -0.12%
==========================================
Files 57 57
Lines 17941 17960 +19
==========================================
- Hits 13357 13351 -6
- Misses 4584 4609 +25 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
vchuravy
commented
Sep 23, 2025
Member
Author
|
Speeds up CI by roughly 2x 38min -> 20 min |
giordano
approved these changes
Sep 24, 2025
Contributor
Benchmark Results
Benchmark PlotsA plot of the benchmark results has been uploaded as an artifact at https://github.com/EnzymeAD/Enzyme.jl/actions/runs/18008959658/artifacts/4104661469. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.